In fixed-point quantization at scale s, a value x is represented as x_hat = x · 2^s. After a multiplication x_hat * y_hat, the result is at scale `2s` and must be "rebased" back to scale `s` by dividing by `2^s`.
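As a minimal sketch of this representation (the `encode`/`decode` helper names are illustrative, not from the codebase):

```rust
// Fixed-point encoding at scale s = 12: x_hat = round(x * 2^12).
const SCALE: u32 = 12;

fn encode(x: f64) -> i64 {
    (x * (1i64 << SCALE) as f64).round() as i64
}

fn decode(x_hat: i64) -> f64 {
    x_hat as f64 / (1i64 << SCALE) as f64
}

fn main() {
    let a = encode(3.5);            // 3.5 * 2^12 = 14336
    let b = encode(2.0);            // 2.0 * 2^12 = 8192
    let product = a * b;            // now at scale 2s = 24
    let rebased = product >> SCALE; // rebase back to scale s = 12
    assert_eq!(decode(rebased), 7.0);
}
```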
The old path did this in two separate steps:
- Multiply in `i32`: `a_i32 * b_i32 -> product_i32`
- Rebase in a follow-up node: `product_i32 >> scale`
The problem is step 1. Two scaled i32 values multiplied together easily exceed the i32 range (±2^31). For example, at scale 12, two modest values like x_hat = 50,000 and y_hat = 50,000 produce 2.5 × 10^9 - already past i32::MAX (≈2.1 × 10^9). In an einsum (matrix multiply) you're summing many such products, so overflows are essentially guaranteed. In Rust release builds this wraps silently, producing garbage values with flipped signs.
The subsequent rebase divides by 2^12, but it's dividing already-corrupted wrapped values — the damage is irreversible. Worse, this corruption compounds through every layer: LayerNorm, attention scores, FFN - each multiplication re-wraps, and each rebase re-divides garbage.
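The failure mode can be reproduced directly; `wrapping_mul` is used here to mirror release-build wrapping behavior (a plain `*` would panic in a debug build):

```rust
fn main() {
    let x_hat: i32 = 50_000; // ~12.2 at scale 12
    let y_hat: i32 = 50_000;

    // True product is 2_500_000_000, past i32::MAX (2_147_483_647),
    // so the wrapped result comes out negative: -1_794_967_296.
    let product = x_hat.wrapping_mul(y_hat);
    assert!(product < 0);

    // The follow-up rebase divides the already-corrupted value;
    // the sign flip survives the shift.
    let rebased = product >> 12;
    assert!(rebased < 0);
}
```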
The new path fuses multiply + accumulate + rebase into a single operation:
- Upcast to `i64`: `a_i32 as i64 * b_i32 as i64 → product_i64`
- Accumulate in `i64` (for einsum: sum all products in `i64`)
- Rebase in `i64`: `accumulated_i64 / (1 << scale)` - the division happens while the full-precision result is still intact
- Clamp to `i32` range: `result.clamp(i32::MIN as i64, i32::MAX as i64) as i32` - saturates instead of wrapping, so extreme values hit the rail rather than flipping sign
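Put together, the fused path looks roughly like this (the function name is hypothetical; the actual node implementation may differ):

```rust
// Fused multiply + accumulate + rebase + clamp for a dot product of
// two scale-s fixed-point vectors.
fn fused_dot_rebase(a: &[i32], b: &[i32], scale: u32) -> i32 {
    let acc: i64 = a.iter()
        .zip(b)
        .map(|(&x, &y)| x as i64 * y as i64) // upcast BEFORE multiplying
        .sum();                              // accumulate in i64
    let rebased = acc / (1i64 << scale);     // rebase while full precision is intact
    rebased.clamp(i32::MIN as i64, i32::MAX as i64) as i32 // saturate, never wrap
}

fn main() {
    // Two scale-12 values whose product alone would overflow i32:
    let out = fused_dot_rebase(&[50_000, 50_000], &[50_000, 50_000], 12);
    // 2 * 2_500_000_000 / 4096 = 1_220_703
    assert_eq!(out, 1_220_703);
}
```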
This matters because i64 has ~9.2 × 10^18 of headroom. A typical GPT-2 matrix multiply accumulates ~768 products of two scale-12 values, and for realistic activation magnitudes the accumulated sum stays well within `i64` range. So the rebase division now operates on the exact accumulated result, and the final clamp ensures no sign-flip corruption leaks into downstream layers.
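A back-of-envelope check of that headroom, using the illustrative numbers from above (not measured activations):

```rust
fn main() {
    let n_products: i64 = 768;              // GPT-2 hidden dimension
    let per_product: i64 = 50_000 * 50_000; // 2.5e9, already past i32::MAX
    let sum = n_products * per_product;     // ~1.92e12

    // i64::MAX is ~9.2e18, leaving more than six orders of magnitude of slack.
    assert!(sum < i64::MAX / 1_000_000);
    println!("accumulated sum: {sum}, i64::MAX: {}", i64::MAX);
}
```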
The effect is multiplicative across layers: fixing the precision at each multiply/einsum means LayerNorm inputs are correct, which means attention scores are correct, which means softmax outputs are correct, etc. That's why the end-to-end metrics jump from 0% top-1 agreement to 100%.