Prove i64-accumulate + fused_rebase + clamp for einsum, mul, cube, etc. #190

@Forpee

Description

In fixed-point quantization at scale s, a value x is represented as x_hat = x · 2^s. After a multiplication x_hat * y_hat, the result is at scale 2s and must be "rebased" back to scale s by dividing by 2^s.
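A minimal sketch of this representation (the helper names here are illustrative, not from the codebase):

```rust
// Encode a real value at scale s: x_hat = round(x * 2^s).
fn to_fixed(x: f64, scale: u32) -> i32 {
    (x * (1u64 << scale) as f64).round() as i32
}

// Decode back to a real value: x = x_hat / 2^s.
fn from_fixed(x_hat: i32, scale: u32) -> f64 {
    x_hat as f64 / (1u64 << scale) as f64
}

fn main() {
    let s = 12;
    let x_hat = to_fixed(3.25, s); // 3.25 * 4096 = 13_312
    let y_hat = to_fixed(2.0, s);  // 2.0  * 4096 = 8_192
    // The raw product is at scale 2s; rebasing by >> s returns it to scale s.
    let p = ((x_hat as i64 * y_hat as i64) >> s) as i32;
    assert_eq!(from_fixed(p, s), 6.5); // 3.25 * 2.0
    println!("{p}"); // 26_624 = 6.5 * 4096
}
```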

The old path did this in two separate steps:

  1. Multiply in i32: a_i32 * b_i32 -> product_i32
  2. Rebase in a follow-up node: product_i32 >> scale

The problem is step 1. Two scaled i32 values multiplied together easily exceed the i32 range (±2^31). For example at scale 12, two modest values like x_hat = 50,000 and y_hat = 50,000 produce 2.5 × 10^9 - already past i32::MAX (≈2.15 × 10^9). In an einsum (matrix multiply) you're summing many such products, so overflows are essentially guaranteed. In Rust release builds this wraps silently, producing garbage values with flipped signs.

The subsequent rebase divides by 2^12, but it's dividing already-corrupted wrapped values — the damage is irreversible. Worse, this corruption compounds through every layer: LayerNorm, attention scores, FFN - each multiplication re-wraps, and each rebase re-divides garbage.
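The old path's failure mode reproduces in a few lines (an illustrative sketch, not the actual graph code; `wrapping_mul` is used to mirror release-build wraparound):

```rust
// Old two-node path: multiply in i32, then rebase in a follow-up node.
fn old_mul_rebase(a: i32, b: i32, scale: u32) -> i32 {
    // Step 1: i32 multiply -- silently wraps on overflow in release builds.
    let product = a.wrapping_mul(b);
    // Step 2: rebase -- divides the already-wrapped product.
    product >> scale
}

fn main() {
    // scale 12: 50_000 * 50_000 = 2.5e9 exceeds i32::MAX,
    // wrapping to -1_794_967_296 before the rebase ever runs.
    let r = old_mul_rebase(50_000, 50_000, 12);
    // Correct answer would be 2_500_000_000 >> 12 = 610_351;
    // instead we get a large negative value.
    println!("{r}"); // -438_225
}
```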

The new path fuses multiply + accumulate + rebase into a single operation:

  1. Upcast to i64: a_i32 as i64 * b_i32 as i64 → product_i64
  2. Accumulate in i64 (for einsum: sum all products in i64)
  3. Rebase in i64: accumulated_i64 / (1 << scale) — the division happens while the full-precision result is still intact
  4. Clamp to i32 range: result.clamp(i32::MIN, i32::MAX) as i32 — saturates instead of wrapping, so extreme values hit the rail rather than flipping sign

This matters because in i64 there's ~9.2 × 10^18 of headroom (i64::MAX ≈ 9.22 × 10^18). A typical GPT-2 matrix multiply accumulates ~768 products of two scale-12 values — the worst-case sum is well within i64 range. So the rebase division now operates on the exact accumulated result, and the final clamp ensures no sign-flip corruption leaks into downstream layers.
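A back-of-envelope check of that headroom claim (assuming activations bounded by roughly |8.0|, so scale-12 magnitudes stay under 2^15; the bound is an assumption for illustration):

```rust
fn main() {
    let max_operand: i64 = 1 << 15;                // |x_hat| ≤ 8 * 2^12 = 32_768
    let worst_product = max_operand * max_operand; // 2^30 per product
    let worst_sum = worst_product * 768;           // ~2^39.6 for a 768-wide dot
    // Even this pessimistic sum is ~10 million times smaller than i64::MAX.
    assert!(worst_sum < i64::MAX / 1_000_000);
    println!("{worst_sum}"); // 824_633_720_832
}
```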

The effect is multiplicative across layers: fixing the precision at each multiply/einsum means LayerNorm inputs are correct, which means attention scores are correct, which means softmax outputs are correct, etc. That's why the end-to-end metrics jump from 0% top-1 agreement to 100%.
