
float16 & where to hide complexity from scaling tricks #25

@f-dangel


During one of our internal discussions, we realized that the code has started to accumulate multiple rescaling operations that are required to avoid overflow/underflow when using float16 (e.g. using the average trace #24, or splitting the 1 / batch_size and grad_scale factors when computing H_K and H_C). The problem with these tricks is that their implementation is often non-local, i.e. they affect multiple functions and are therefore hard to understand. In the long term, accumulating such tricks will make the code unusably complex.
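
To make the kind of trick mentioned above concrete, here is a toy scalar sketch (not the optimizer's actual H_K / H_C code) of why folding 1 / batch_size and grad_scale into each factor of a product, instead of dividing afterwards, avoids overflow in float16. The helper names (to_f16, mean_square_naive, mean_square_split) are hypothetical. float16 arithmetic is simulated with Python's struct 'e' (half-precision) format by rounding after every operation, and the loss scale of 1024 and batch size of 4 are arbitrary assumptions.

```python
import math
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest float16 value, returning +/-inf on overflow.

    Uses the stdlib half-precision struct format 'e' to simulate float16
    storage after every operation, so no numpy/torch is needed.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def mean_square_naive(grads, grad_scale):
    """Accumulate the loss-scaled squares first, undo all scaling at the end."""
    acc = 0.0
    for g in grads:
        acc = to_f16(acc + to_f16(g * g))
    acc = to_f16(acc / len(grads))  # 1 / batch_size
    acc = to_f16(acc / grad_scale)  # undo grad_scale once per factor...
    return to_f16(acc / grad_scale)  # ...and once more for the second factor


def mean_square_split(grads, grad_scale):
    """Fold 1/grad_scale and 1/sqrt(batch_size) into each factor before squaring."""
    factor = to_f16(1.0 / (grad_scale * math.sqrt(len(grads))))
    acc = 0.0
    for g in grads:
        g_scaled = to_f16(g * factor)
        acc = to_f16(acc + to_f16(g_scaled * g_scaled))
    return acc


grads = [3.0 * 1024] * 4  # gradient 3.0 under loss scale 1024, batch size 4
print(mean_square_naive(grads, grad_scale=1024))  # inf
print(mean_square_split(grads, grad_scale=1024))  # 9.0
```

The naive version already overflows at the first squared entry (3072^2 ≈ 9.4e6, while the largest float16 value is 65504), whereas the split version returns 9.0. It also illustrates the non-locality: whoever forms the product has to know about grad_scale and the batch size.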

In my opinion, we should try to keep the optimizer's update step as free of scaling tricks as possible and move such complexity into the StructuredMatrix interface. We will also need a strong motivation to implement this, as well as examples demonstrating the effectiveness of such tricks, since it could double the size of the code base for the structured matrices. One possible outcome is that float16 is simply too unstable, in which case we will not support it.

  • Idea 1 (stable structured matrices): We could add a DenseStableMatrix class, the equivalent of DenseMatrix but with special implementations that are more stable (but also more costly) in float16. The optimizer could then easily access this structure by adding a 'dense_stable' option to the entries it supports in structures. One simple way to keep the computation stable might be to treat each Tensor t of a structured matrix internally as a (float, Tensor) tuple (scale, t_scaled), that is, a scaling factor and a normalized tensor with t = scale * t_scaled. For instance, multiplication by a scalar alpha then simply corresponds to using (alpha * scale, t_scaled) internally.

    However, it is currently unclear to me whether such a simple heuristic will consistently make all operations stable.

  • Alternative ideas go here
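
Idea 1 can be sketched without any framework. The following is a minimal, hypothetical ScaledTensor (not an existing class in the code base) that stores a float64 scale next to a float16 payload normalized to max-abs 1, again simulating float16 with Python's struct 'e' format. Scalar multiplication only touches the scale, as described in Idea 1; the addition rule (express both payloads relative to the larger scale, add in float16, renormalize) is an assumption about how other operations could be handled.

```python
import math
import struct
from dataclasses import dataclass
from typing import List


def to_f16(x: float) -> float:
    """Round to the nearest float16 value (+/-inf on overflow) via struct's 'e' format."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


@dataclass
class ScaledTensor:
    """Hypothetical (scale, t_scaled) pair representing scale * t_scaled.

    scale is a Python float (float64); t_scaled holds float16-rounded values
    normalized so that max(abs(t_scaled)) == 1 (or all zeros).
    """

    scale: float
    t_scaled: List[float]

    @classmethod
    def from_values(cls, values: List[float]) -> "ScaledTensor":
        m = max((abs(v) for v in values), default=0.0)
        if m == 0.0:
            return cls(1.0, [0.0] * len(values))
        return cls(m, [to_f16(v / m) for v in values])

    def mul_scalar(self, alpha: float) -> "ScaledTensor":
        # (alpha * scale, t_scaled): only the float64 scale changes.
        return ScaledTensor(alpha * self.scale, list(self.t_scaled))

    def add(self, other: "ScaledTensor") -> "ScaledTensor":
        # Express both payloads relative to the larger scale, add in float16,
        # then renormalize so the payload stays within [-1, 1].
        s = max(abs(self.scale), abs(other.scale))
        if s == 0.0:
            return ScaledTensor(1.0, [0.0] * len(self.t_scaled))
        ra, rb = self.scale / s, other.scale / s  # |ra|, |rb| <= 1
        summed = [
            to_f16(to_f16(ra * a) + to_f16(rb * b))
            for a, b in zip(self.t_scaled, other.t_scaled)
        ]
        m = max((abs(v) for v in summed), default=0.0)
        if m == 0.0:
            return ScaledTensor(1.0, [0.0] * len(summed))
        return ScaledTensor(s * m, [to_f16(v / m) for v in summed])

    def to_dense(self) -> List[float]:
        # Materialize in float64 for inspection only.
        return [self.scale * v for v in self.t_scaled]


t = ScaledTensor.from_values([1000.0, -2000.0, 500.0])
print(t.mul_scalar(2.0**20).mul_scalar(2.0**-20).to_dense())  # [1000.0, -2000.0, 500.0]
print(to_f16(to_f16(2000.0 * 2.0**20) * 2.0**-20))  # inf: plain float16 overflows
big = ScaledTensor.from_values([60000.0])
print(big.add(big).to_dense())  # [120000.0]
```

In this sketch, adding two tensors whose scales differ by many orders of magnitude still rounds the smaller operand away in the float16 payload, so it does not by itself answer whether the heuristic keeps all operations stable.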
