1 change: 1 addition & 0 deletions modelopt/torch/quantization/calib/mse.py
@@ -82,6 +82,7 @@ def collect(self, x: torch.Tensor):
multipliers = torch.linspace(
self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
)
print(f"Multipliers: {multipliers}")

# Get reduce axis for per-channel quantization
reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
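For orientation, here is a minimal, hypothetical sketch of the amax sweep that MseCalibrator.collect performs with these multipliers (quant_func stands in for the quantizer's fake-quant call; this is a simplification, not the library code):

import torch

def mse_search_amax(x, initial_amax, quant_func,
                    start_multiplier=0.25, stop_multiplier=2.0, num_steps=8):
    # Candidate amax values are initial_amax scaled by an evenly spaced grid,
    # exactly as in the linspace call above.
    multipliers = torch.linspace(start_multiplier, stop_multiplier,
                                 steps=num_steps, device=x.device)
    best_amax, best_err = initial_amax, float("inf")
    for m in multipliers:
        candidate = initial_amax * m
        # Mean squared error between the fake-quantized tensor and the original
        err = (quant_func(x, candidate) - x).float().pow(2).mean().item()
        if err < best_err:
            best_amax, best_err = candidate, err
    return best_amax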
122 changes: 114 additions & 8 deletions modelopt/torch/quantization/config.py
@@ -387,6 +387,112 @@
"algorithm": "max",
}

NVFP4_WEIGHT_MAX_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
}

NVFP4_WEIGHT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"num_steps": 8,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}

NVFP4_WEIGHT_MSE_4_6_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"enable": False,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"num_steps": 8,
"start_multiplier": 0.375,
"stop_multiplier": 3.0,
},
}
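For reference, the two MSE parameterizations above produce the following candidate-amax multiplier grids (the same torch.linspace call used in calib/mse.py):

import torch

# NVFP4_WEIGHT_MSE_CFG: start_multiplier=0.25, stop_multiplier=2.0, num_steps=8
torch.linspace(0.25, 2.0, steps=8)
# -> [0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00]

# NVFP4_WEIGHT_MSE_4_6_CFG: start_multiplier=0.375, stop_multiplier=3.0, num_steps=8
torch.linspace(0.375, 3.0, steps=8)
# -> [0.375, 0.75, 1.125, 1.50, 1.875, 2.25, 2.625, 3.00]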

NVFP4_WEIGHT_ACT_MSE_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"num_steps": 8,
"start_multiplier": 0.25,
"stop_multiplier": 2.0,
},
}

NVFP4_WEIGHT_ACT_MSE_4_6_CFG = {
"quant_cfg": {
"*weight_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
"*input_quantizer": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
},
**_default_disabled_quantizer_cfg,
},
"algorithm": {
"method": "mse",
"num_steps": 8,
"start_multiplier": 0.375,
"stop_multiplier": 3.0,
},
}
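A hedged usage sketch for these configs, assuming they are re-exported at the package level the way the existing NVFP4 configs are (model and calib_dataloader are placeholder names):

import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run a small, representative calibration set through the model
    for batch in calib_dataloader:
        model(batch)

# Weight-and-activation NVFP4 with MSE-searched scales; the weight-only
# configs (e.g. NVFP4_WEIGHT_MSE_CFG) may not need a calibration forward loop.
model = mtq.quantize(model, mtq.NVFP4_WEIGHT_ACT_MSE_CFG, forward_loop)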


NVFP4_AWQ_LITE_CFG = {
"quant_cfg": {
@@ -720,7 +826,7 @@ def validate_num_bits(self):
if not all(x > 0 for x in num_bits):
raise ValueError("num_bits must be a positive integer or a tuple of positive integers.")

block_sizes = self.block_sizes
# block_sizes = self.block_sizes
if num_bits not in [
(4, 3),
(5, 2),
@@ -734,13 +840,13 @@ def validate_num_bits(self):
raise ValueError(
"Supported FPx quantization formats: FP8 (E4M3, E5M2), FP6(E3M2, E2M3), FP4(E2M1)."
)
elif num_bits != (4, 3) and (
block_sizes is None or block_sizes.get("type", None) != "dynamic"
):
raise ValueError(
"Only blockwise dynamic quantization is supported with quantization "
"formats E{num_bis[0]}M{num_bits[1]}."
)
# elif num_bits != (4, 3) and (
# block_sizes is None or block_sizes.get("type", None) != "dynamic"
# ):
# raise ValueError(
# "Only blockwise dynamic quantization is supported with quantization "
# "formats E{num_bis[0]}M{num_bits[1]}."
# )
return self

axis: int | tuple[int, ...] | None = ModeloptField(
69 changes: 61 additions & 8 deletions modelopt/torch/quantization/model_calib.py
@@ -32,6 +32,7 @@
from .calib import MseCalibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import QuantModule, SequentialQuantizer, TensorQuantizer
from .tensor_quant import scaled_e4m3_impl
from .utils import (
disable_calib,
enable_fake_quant,
@@ -41,6 +42,7 @@
is_quantized_linear,
is_quantized_row_parallel_linear,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)

@@ -216,14 +218,18 @@ def mse_calibrate(
max_calibrate(model, forward_loop, distributed_sync)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

for name, module in model.named_modules():
if isinstance(module, TensorQuantizer) and not module._disabled:
# Static block quantization is not supported by MseCalibrator
if module.is_static_block_quant:
raise ValueError(
f"MSE calibration does not support static block quantization. "
f"Found static block quantization at {name}."
)
# if module.is_static_block_quant:
# raise ValueError(
# f"MSE calibration does not support static block quantization. "
# f"Found static block quantization at {name}."
# )
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
@@ -237,7 +243,20 @@ def quant_func(x, amax, quantizer=module):
disable_calib(quantizer),
enable_fake_quant(quantizer),
):
quantizer._keep_shape = True
Review comment from @realAsma (Contributor), Dec 3, 2025, on line 233:

if is_nvfp4_static:  # static per-block
    scale = amax / 6.0
    global_scale = weight.amax() / 6.0
    scale_fp8 = scaled_e4m3_impl(amax / 6.0, weight.amax() / 6.0)  # FP8 quantization
    amax_equivalent = scale_fp8 * 6.0

xq = quantizer(x)
quantizer._keep_shape = False

# FP8 quantization of NVFP4 static per-block scales
if (
quantizer.is_static_block_quant
and quantizer._num_bits == (2, 1)
and quantizer._block_sizes.get("scale_bits") == (4, 3)
):
weight_amax = reduce_amax(
x, axis=None, keepdims=False, squeeze_scalar=True
)
quantizer._amax = scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0

if original_amax is not None:
quantizer._amax = original_amax
@@ -256,14 +275,48 @@ def quant_func(x, amax, quantizer=module):
quant_func=quant_func,
)

# Step 3: Collect data with MSE calibrators
# Identify weight quantizers by checking if they have corresponding weight parameters
for name, parent_module in model.named_modules():
if parent_module in seen_modules:
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and not weight_quantizer._disabled:
if weight_quantizer._calibrator is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers once with MSE calibration
# This ensures weights are only calibrated once, not during every forward pass
for parent_module, weight_name, weight_quantizer in weight_quantizers:
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()

with enable_weight_access_and_writeback(parent_module, model):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# Step 4: Disable weight quantizers during forward loop
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.disable()

# Step 5: Collect data with MSE calibrators for activation quantizers only
enable_stats_collection(model)
if forward_loop is None:
weight_only_quantize(model)
# If no forward loop, nothing else to do since weights are already calibrated
pass
else:
# Run forward loop - only activation quantizers will collect data
forward_loop(model)

# Step 4: Compute optimal amax and load it
# Step 6: Re-enable weight quantizers before finalizing calibration
# This ensures finish_stats_collection processes them correctly
for _, _, weight_quantizer in weight_quantizers:
weight_quantizer.enable()

# Step 7: Compute optimal amax and load it for all quantizers (weights + activations)
finish_stats_collection(model, method="mse")

# TODO: Sync amax across distributed processes
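As a standalone illustration of the scale-of-scales step added in quant_func above (and sketched in the review comment), here is a hypothetical reference that uses PyTorch's float8_e4m3fn cast in place of scaled_e4m3_impl; block_size=16 and the divisibility assumption are simplifications:

import torch

def nvfp4_amax_after_fp8_scales(weight: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Per-block amax over the last dimension (assumes it divides evenly by block_size)
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    amax = blocks.abs().amax(dim=-1)
    scale = amax / 6.0                        # NVFP4 per-block scale; 6.0 is the FP4 (E2M1) max magnitude
    global_scale = weight.abs().amax() / 6.0  # tensor-wide scale of scales
    # Round the per-block scales to FP8 E4M3 relative to the global scale
    scale_fp8 = (scale / global_scale).to(torch.float8_e4m3fn).to(weight.dtype) * global_scale
    # Equivalent amax after FP8 rounding, mirroring scaled_e4m3_impl(amax / 6.0, weight_amax / 6.0) * 6.0
    return scale_fp8 * 6.0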
21 changes: 15 additions & 6 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -128,6 +128,7 @@ def __init__(
self._enable_pre_quant_scale = True
self._dequantize = False
self._input_dtype = None
self._keep_shape = False

# Lazy initialize the bias calibrator for KV cache quantization
self._bias_calibrator = None
@@ -653,6 +654,14 @@ def _fake_quantize(self, inputs):
getattr(self, "_onnx_quantizer_type", None),
self._pass_through_bwd,
)
elif self._num_bits == (2, 1) and self.is_static_block_quant:
from modelopt.torch.quantization.triton.fp4_kernel import (
launch_static_blockwise_fp4_fake_quant,
)

outputs = launch_static_blockwise_fp4_fake_quant(
inputs, amax / 6.0, out_dtype=inputs.dtype
)
elif isinstance(self._num_bits, tuple):
# Float-point quantization, e.g., FP8
E, M = self._num_bits # noqa: N806
@@ -783,11 +792,11 @@ def _process_for_blockquant(self, inputs: torch.Tensor):
if hasattr(self, "_padding"):
inputs = F.pad(inputs, self._padding, "constant", 0)

if inputs.shape != self._original_shape:
raise ValueError(
f"Input shape has changed from {self._original_shape} to {inputs.shape}."
" Block-quantization requires a fixed input shape."
)
# if inputs.shape != self._original_shape:
# print(
# f"Input shape has changed from {self._original_shape} to {inputs.shape}."
# " Block-quantization requires a fixed input shape."
# )
inputs = inputs.reshape(self._block_reshape_size)
return inputs

@@ -941,7 +950,7 @@ def forward(self, inputs):
"This case should have been handled."
)

if self.is_static_block_quant:
if self.is_static_block_quant and not self._keep_shape:
outputs = self._reset_to_original_shape(outputs)

return outputs
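For context on the new _fake_quantize branch above, here is a hypothetical pure-PyTorch reference of static blockwise FP4 (E2M1) fake quantization; launch_static_blockwise_fp4_fake_quant is expected to be an optimized Triton equivalent, and the block size, per-block amax layout, and nearest-value rounding used here are simplifying assumptions:

import torch

_E2M1_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_static_block_fake_quant(x: torch.Tensor, amax: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Reshape into blocks along the last dimension and scale so that
    # amax maps to 6.0, the largest representable E2M1 magnitude.
    orig_shape = x.shape
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    scale = (amax / 6.0).reshape(*blocks.shape[:-1], 1).clamp(min=1e-12)
    levels = _E2M1_LEVELS.to(device=x.device, dtype=x.dtype)
    scaled = blocks / scale
    # Snap each value to the nearest representable magnitude, keeping the sign
    idx = (scaled.abs().unsqueeze(-1) - levels).abs().argmin(dim=-1)
    quantized = levels[idx] * torch.sign(scaled)
    return (quantized * scale).reshape(orig_shape)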