
[RFC] NVFP4 Rounding Modes #3264

@syed-ahmed


The current rounding mode for NVFP4 tensors in TorchAO is round-to-nearest. The purpose of this issue is to discuss support for other rounding modes.

What rounding modes are available?

  • Stochastic rounding (RS)
  • Round-to-nearest (RN)
  • Round-toward-zero (RZ)
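
Round-to-nearest picks the closest representable value, round-toward-zero truncates, and stochastic rounding picks the upper neighbor with probability equal to the fractional distance, which makes it unbiased in expectation. A minimal illustration of the stochastic-rounding semantics on a uniform grid (illustrative only, not TorchAO code):

```python
import torch

def stochastic_round_to_grid(x: torch.Tensor, step: float) -> torch.Tensor:
    # Lower grid point and fractional position within the step.
    lo = torch.floor(x / step)
    frac = x / step - lo
    # Round up with probability `frac`, down otherwise; unbiased in expectation.
    up = (torch.rand_like(x) < frac).to(x.dtype)
    return (lo + up) * step
```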

Where do we need different rounding modes?

  • NVFP4 Training Recipe (NVFP4 MoE Training Status torchtitan#1962)
    • RS for gradients
    • RN for weights and activation
  • _AdamW in torchao.optim supports BF16 stochastic rounding:
    ```python
    # a clone of torch.optim.AdamW with extra features
    from torchao.optim import _AdamW
    model = ...
    model_bf16 = model.bfloat16()
    optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
    ```
  • INT8 quantization in TorchAO has a stochastic rounding mode:
    ```python
    def quantize_int8_rowwise(
        tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12
    ):
    ```
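
A rough sketch of how a `stochastic_rounding` flag like the one above typically works for rowwise INT8 quantization (illustrative only; the body below is an assumption, not the exact TorchAO implementation):

```python
import torch
from torch import Tensor

def quantize_int8_rowwise_sketch(
    tensor: Tensor, stochastic_rounding: bool = False, eps: float = 1e-12
):
    # Per-row absmax scale mapping values into [-127, 127].
    scale = tensor.abs().amax(dim=1, keepdim=True) / 127
    x = tensor / scale.clamp(min=eps)
    if stochastic_rounding:
        # Add uniform noise in [0, 1) before flooring: rounds up with probability
        # equal to the fractional part, so quantization is unbiased in expectation.
        x = torch.floor(x + torch.rand_like(x))
    else:
        x = torch.round(x)
    return x.clamp(-128, 127).to(torch.int8), scale
```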

Existing RN Kernels

  • Eager path (a hypothetical NVFP4 call is sketched after this list):
    ```python
    def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
        """Convert FP32 numbers to sub-byte floating point numbers with the given
        number of exponent and mantissa bits."""
    ```
  • torch.compile:
    ```python
    def convert_fp32_to_fp4_packed(x_pairs):
        """Convert FP32 pairs to packed FP4 format."""
    ```
    • Uses cvt.rn.satfinite.e2m1x2.f32 inline asm
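
Since NVFP4 elements are e2m1, the eager path above maps to two exponent bits and one mantissa bit; a hypothetical call (argument values inferred from the format, not quoted from TorchAO):

```python
# x: float32 tensor; result holds unpacked e2m1 codes, one sub-byte value per element.
fp4_unpacked = _f32_to_floatx_unpacked(x, ebits=2, mbits=1)
```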

Possible RS Kernel Implementations

  • Emulated implementation from @slayton58. I've quickly written this in Triton syntax below, but it probably makes the most sense to write it in PyTorch eager, similar to the RN path (_f32_to_floatx_unpacked); see the eager sketch after this list.
    ```python
    @triton.jit
    def float_rs(x, seed, offset):
        """
        Apply stochastic rounding when casting from float32 to NVFP4.
        
        Args:
            x: Input tensor (float32)
            seed: Random seed for the random number generator
            offset: Offset for random number generation (should be unique per element)
        
        Returns:
            Stochastically rounded tensor
        """
        
        # Scale down by 2^(-125) to normalize range
        downscale_factor = tl.math.exp2(-125.0)
        x = x * downscale_factor
        
        # Create a 32-bit pseudorandom value; bitcast to uint32 so the
        # right shift below is a logical shift rather than an arithmetic one
        rnd = tl.randint(seed, offset).to(tl.uint32, bitcast=True)

        # Isolate the lower 22 bits for randomness injection
        # (left-shift by 10, then right-shift by 10 clears the top 10 bits)
        rnd_shifted = (rnd << 10) >> 10
        
        # Reinterpret float bits as unsigned integer
        xb = x.to(tl.uint32, bitcast=True)
        
        # Inject randomness into the discarded precision bits
        yb = xb + rnd_shifted
        
        # Clear the lower 22 bits to perform rounding
        yb = (yb >> 22) << 22
        
        # Reinterpret integer bits back as floating point
        y = yb.to(tl.float32, bitcast=True)
        
        # Restore original magnitude by scaling up
        upscale_factor = tl.math.exp2(125.0)
        y = y * upscale_factor
        
        return y
    ```
  • Use an inline-asm Triton kernel with cvt.rs.satfinite.e2m1x4.f32 for stochastic rounding, analogous to the RN path.
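
A rough PyTorch eager sketch of the emulated bit trick above, assuming contiguous float32 inputs; the function name and signature are illustrative, not an existing TorchAO API:

```python
import torch

def f32_to_e2m1_rs_emulated(x: torch.Tensor, generator=None) -> torch.Tensor:
    # Mirror of the Triton sketch: inject random bits into the 22 mantissa bits
    # that e2m1 discards, then truncate. Returns float32 values rounded onto the
    # FP4 grid; packing and scaling are handled elsewhere.
    x = (x * 2.0 ** -125).contiguous()          # normalize range as in the Triton version
    xb = x.view(torch.int32)                    # reinterpret float bits as integers
    rnd = torch.randint(0, 1 << 22, x.shape, dtype=torch.int32,
                        generator=generator, device=x.device)
    yb = xb + rnd                               # inject randomness into the discarded bits
    yb = (yb >> 22) << 22                       # truncate the lower 22 bits
    return yb.view(torch.float32) * 2.0 ** 125  # restore magnitude
```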

Integration

  • A possible integration point for the NVFP4 training recipe use case is to specify the rounding mode in to_nvfp4 calls (a usage sketch follows this list):
    ```python
    class RoundingMode(Enum):
        RN = "round_nearest"
        RS = "round_stochastic"
        RZ = "round_zero"

    def to_nvfp4(
        data_hp: torch.Tensor,
        block_size: int = 16,
        per_tensor_scale: Optional[torch.Tensor] = None,
        act_per_tensor_scale: Optional[torch.Tensor] = None,
        is_swizzled_scales: bool = False,
        use_triton_kernel: bool = False,
        act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
        rounding_mode: RoundingMode = RoundingMode.RN,
    ):
        ...
        if use_triton_kernel:
            blockwise_scales, data_lp = triton_quantize_nvfp4(
                data_hp, per_tensor_scale, rounding_mode
            )
        else:
            blockwise_scales, data_lp = nvfp4_quantize(
                data_hp, block_size, per_tensor_scale, rounding_mode
            )
    ```
  • We should discuss whether we need to support rounding modes more generically, to cover other use cases such as _AdamW and INT8 training.
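
A hypothetical call sketch for the training-recipe split above (RN for weights and activations, RS for gradients), assuming the rounding_mode argument lands as proposed:

```python
# Weights and activations: round-to-nearest (current default).
weight_nvfp4 = to_nvfp4(weight, rounding_mode=RoundingMode.RN)

# Gradients: stochastic rounding keeps the quantization unbiased in expectation.
grad_nvfp4 = to_nvfp4(grad, rounding_mode=RoundingMode.RS, use_triton_kernel=True)
```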

Test Plan

  • TODO

CC: @slayton58, @ngimel, @supriyar, @Priyadlfw, @ptrblck, @eqy
