diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index a9c8b45a2..9bc45a438 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -162,6 +162,37 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )
 
+MXFP4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+    )
+)
+
+MXFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        dynamic=True,
+        symmetric=True,
+        group_size=32,
+    ),
+)
+
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -313,4 +344,6 @@ def is_preset_scheme(name: str) -> bool:
     "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
+    "MXFP4": MXFP4,
+    "MXFP4A16": MXFP4A16,
 }
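
For reference, a minimal usage sketch (not part of the diff) of how the new presets could be resolved once this lands. It assumes the existing preset helpers exported from quant_scheme.py: is_preset_scheme appears in the hunk headers above, while preset_name_to_scheme and its (name, targets) signature are assumed from the surrounding file rather than shown in this diff.

from compressed_tensors.quantization import (
    is_preset_scheme,
    preset_name_to_scheme,
)

# The new presets become discoverable by name via PRESET_SCHEMES.
assert is_preset_scheme("MXFP4")
assert is_preset_scheme("MXFP4A16")

# Resolve the weight-only MX FP4 preset into a QuantizationScheme
# targeting Linear modules (helper signature assumed, see note above).
scheme = preset_name_to_scheme("MXFP4A16", ["Linear"])
print(scheme.weights.num_bits, scheme.weights.group_size)  # -> 4 32

The two presets differ only in activation handling: MXFP4A16 quantizes weights alone (activations stay in 16 bit), while MXFP4 additionally applies dynamic FP4 group quantization to input activations, both with the group size of 32 used by the MX microscaling formats.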