The current rounding mode for NVFP4 tensors in TorchAO is round-to-nearest. The purpose of this issue is to discuss support for other rounding modes.
What rounding modes are available?
- Stochastic Rounding (RS)
- Round-to-Nearest (RN)
- Round-toward-Zero (RZ)
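For intuition, here is a minimal scalar sketch of the three modes (the helpers below are illustrative only, not torchao APIs). The key property of RS is that it is unbiased in expectation, which is why it is attractive for gradients:

```python
import random

def round_nearest(x: float, ulp: float) -> float:
    # RN: snap to the closest representable multiple of `ulp`
    return round(x / ulp) * ulp

def round_toward_zero(x: float, ulp: float) -> float:
    # RZ: truncate toward zero
    return int(x / ulp) * ulp

def round_stochastic(x: float, ulp: float) -> float:
    # RS: round up with probability proportional to the distance from the
    # lower representable value, so E[round_stochastic(x)] == x
    lo = (x // ulp) * ulp
    p_up = (x - lo) / ulp
    return lo + ulp if random.random() < p_up else lo

# Example with ulp = 0.5: x = 0.7 gives RN -> 0.5, RZ -> 0.5,
# and RS -> 1.0 with probability 0.4, else 0.5
```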
Where do we need different rounding modes?
- NVFP4 Training Recipe (NVFP4 MoE Training Status, torchtitan#1962)
  - RS for gradients
  - RN for weights and activations
- _AdamW in torchao.optim supports BF16 stochastic rounding (lines 71 to 77 in 1e473ed):

```python
# a clone of torch.optim.AdamW with extra features
from torchao.optim import _AdamW

model = ...
model_bf16 = model.bfloat16()
optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
```

- INT8 quantization in TorchAO has a stochastic rounding mode:
ao/torchao/prototype/quantized_training/int8.py, lines 24 to 26 in 1e473ed:

```python
def quantize_int8_rowwise(
    tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12
):
```
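A brief usage sketch of that prototype API (the weight tensor below is a placeholder):

```python
import torch
from torchao.prototype.quantized_training.int8 import quantize_int8_rowwise

w = torch.randn(1024, 512)  # placeholder weight
# row-wise INT8 quantization with stochastic rounding enabled
w_int8 = quantize_int8_rowwise(w, stochastic_rounding=True)
```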
Existing RN Kernels
- Eager path (a quick usage sketch follows this list):

ao/torchao/prototype/custom_fp_utils.py, lines 27 to 30 in 1e473ed:

```python
def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.
```
- torch.compile path:

ao/torchao/prototype/mx_formats/kernels.py, lines 1492 to 1493 in 1e473ed:

```python
def convert_fp32_to_fp4_packed(x_pairs):
    """Convert FP32 pairs to packed FP4 format.
```

  - Uses `cvt.rn.satfinite.e2m1x2.f32` inline asm
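As a usage sketch for the eager path: NVFP4 elements use the E2M1 format (2 exponent bits, 1 mantissa bit), so the round-to-nearest conversion can be exercised like this (the input tensor is a placeholder):

```python
import torch
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked

x = torch.randn(128, 64)  # placeholder high-precision (FP32) input
# E2M1 = 2 exponent bits, 1 mantissa bit (the NVFP4 element format)
x_fp4_unpacked = _f32_to_floatx_unpacked(x, ebits=2, mbits=1)
```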
Possible RS Kernel Implementations
- Emulated implementation from @slayton58. I've quickly written this in Triton syntax, but it probably makes most sense to write it in PyTorch eager, similar to RN (`_f32_to_floatx_unpacked`); a rough eager-mode port is sketched after this list.

```python
@triton.jit
def float_rs(x, seed, offset):
    """
    Apply stochastic rounding when casting from float32 to NVFP4.

    Args:
        x: Input tensor (float32)
        seed: Random seed for the random number generator
        offset: Offset for random number generation (should be unique per element)

    Returns:
        Stochastically rounded tensor
    """
    # Scale down by 2^(-125) to normalize range
    downscale_factor = tl.math.exp2(-125.0)
    x = x * downscale_factor

    # Create 32-bit pseudorandom value
    rnd = tl.randint(seed, offset)

    # Isolate lower 22 bits for randomness injection
    # Process: left-shift by 10, then right-shift by 10
    rnd_shifted = (rnd << 10) >> 10

    # Reinterpret float bits as unsigned integer
    xb = x.to(tl.uint32, bitcast=True)

    # Inject randomness into the discarded precision bits
    yb = xb + rnd_shifted

    # Clear the lower 22 bits to perform rounding
    yb = (yb >> 22) << 22

    # Reinterpret integer bits back as floating point
    y = yb.to(tl.float32, bitcast=True)

    # Restore original magnitude by scaling up
    upscale_factor = tl.math.exp2(125.0)
    y = y * upscale_factor

    return y
```
- Use an inline asm Triton kernel with `cvt.rs.satfinite.e2m1x4.f32` for stochastic rounding, similar to the RN path.
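To make the "write it in PyTorch eager" option concrete, here is a rough eager port of the same bit-manipulation trick. The helper name is hypothetical (not an existing torchao API), and it only illustrates the stochastic-rounding step itself, not the full NVFP4 cast and packing:

```python
import torch

def _f32_stochastic_round_emulated(x: torch.Tensor) -> torch.Tensor:
    # Scale down by 2^-125 so the bit trick below operates in the intended range,
    # mirroring the Triton sketch above
    x = (x * 2.0**-125).contiguous()

    # Pseudorandom values restricted to the low 22 bits
    rnd = torch.randint(0, 1 << 22, x.shape, dtype=torch.int32, device=x.device)

    # Reinterpret float bits as int32 and inject randomness into the bits
    # that will be discarded
    yb = x.view(torch.int32) + rnd

    # Clear the low 22 bits (the truncation that realizes the rounding)
    yb = (yb >> 22) << 22

    # Reinterpret back to float32 and restore the original magnitude
    return yb.view(torch.float32) * 2.0**125
```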
Integration
- A possible integration point for the NVFP4 Training Recipe use case is to specify the rounding mode in `to_nvfp4` calls (a hypothetical call site follows this list):

```python
class RoundingMode(Enum):
    RN = "round_nearest"
    RS = "round_stochastic"
    RZ = "round_zero"


def to_nvfp4(
    data_hp: torch.Tensor,
    block_size: int = 16,
    per_tensor_scale: Optional[torch.Tensor] = None,
    act_per_tensor_scale: Optional[torch.Tensor] = None,
    is_swizzled_scales: bool = False,
    use_triton_kernel: bool = False,
    act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
    rounding_mode: RoundingMode = RoundingMode.RN,
):
    ...
    if use_triton_kernel:
        blockwise_scales, data_lp = triton_quantize_nvfp4(
            data_hp, per_tensor_scale, rounding_mode
        )
    else:
        blockwise_scales, data_lp = nvfp4_quantize(
            data_hp, block_size, per_tensor_scale, rounding_mode
        )
```
- We should discuss whether to support rounding modes more generically, to cover other use cases like `_AdamW` and INT8 training.
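A hypothetical call site under the proposed signature, following the recipe above (RS for gradients, RN for weights and activations); the variable names are placeholders:

```python
grad_nvfp4 = to_nvfp4(grad_hp, rounding_mode=RoundingMode.RS)      # gradients: stochastic
weight_nvfp4 = to_nvfp4(weight_hp, rounding_mode=RoundingMode.RN)  # weights/activations: nearest
```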
Test Plan
- TODO
CC: @slayton58, @ngimel, @supriyar, @Priyadlfw, @ptrblck, @eqy