Commit dd70a8c
Fix MXFP4 quantizer validation to allow CPU inference with dequantize option (#39953)
* Fix MXFP4 quantizer validation to enable CPU dequantization

  Move the dequantize check before the CUDA availability check to allow CPU inference when quantization_config.dequantize is True. This enables users to run MXFP4 models on CPU by automatically converting them to BF16 format.

* Add tests for MXFP4 quantizer CPU dequantization validation

* fix: format mxfp4 test file with ruff
1 parent 82eb67e commit dd70a8c
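What this enables in practice, as a minimal sketch: loading an MXFP4 checkpoint on a CPU-only machine. `Mxfp4Config(dequantize=True)` comes from this commit's tests; the model id is a placeholder and the surrounding loading flow is a standard transformers pattern, not part of this change.

import torch
from transformers import AutoModelForCausalLM, Mxfp4Config

# dequantize=True now passes validate_environment() without CUDA, so the
# MXFP4 weights are unpacked to BF16 and the model can run on CPU.
config = Mxfp4Config(dequantize=True)
model = AutoModelForCausalLM.from_pretrained(
    "org/mxfp4-model",  # placeholder: any MXFP4-quantized checkpoint
    quantization_config=config,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)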

File tree: 2 files changed (+50, −3 lines)

src/transformers/quantizers/quantizer_mxfp4.py
tests/quantization/mxfp4/test_mxfp4.py

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 4 additions & 3 deletions
@@ -56,15 +56,16 @@ def validate_environment(self, *args, **kwargs):
                 "Using mxfp4 quantization requires torch"
                 "Please install the latest version of torch ( pip install --upgrade torch )"
             )
+
+        if self.quantization_config.dequantize:
+            return
+
         if not torch.cuda.is_available():
             raise RuntimeError("Using MXFP4 quantized models requires a GPU")

         if not is_accelerate_available():
             raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")

-        if self.quantization_config.dequantize:
-            return
-
         compute_capability = torch.cuda.get_device_capability()
         major, minor = compute_capability

tests/quantization/mxfp4/test_mxfp4.py

Lines changed: 46 additions & 0 deletions
@@ -131,6 +131,52 @@ def test_quantizer_validation_low_compute_capability_with_dequantize(self):
                 if "compute capability" in str(e):
                     self.fail("Should not raise compute capability error when dequantize=True")

+    def test_quantizer_validation_dequantize_on_cpu(self):
+        """Test quantizer validation with dequantize enabled on CPU-only environment"""
+        with patch("torch.cuda.is_available", return_value=False):
+            from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+            config = Mxfp4Config(dequantize=True)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            # Should not raise error when dequantize=True even without CUDA
+            try:
+                quantizer.validate_environment()
+            except RuntimeError as e:
+                if "requires a GPU" in str(e):
+                    self.fail("Should not raise GPU requirement error when dequantize=True on CPU")
+
+    def test_quantizer_validation_order_dequantize_before_cuda_check(self):
+        """Test that dequantize check happens before CUDA availability check"""
+        # Mock both torch.cuda.is_available and is_accelerate_available to return False
+        with (
+            patch("torch.cuda.is_available", return_value=False),
+            patch(
+                "transformers.quantizers.quantizer_mxfp4.is_accelerate_available",
+                return_value=False,
+            ),
+        ):
+            from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+            # Test with dequantize=True - should pass even without CUDA and accelerate
+            config = Mxfp4Config(dequantize=True)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            # This should not raise any error because dequantize check comes first
+            try:
+                quantizer.validate_environment()
+            except (RuntimeError, ImportError) as e:
+                if "requires a GPU" in str(e) or "requires Accelerate" in str(e):
+                    self.fail(f"Should not raise error when dequantize=True: {e}")
+
+            # Test with dequantize=False - should still fail due to missing CUDA
+            config = Mxfp4Config(dequantize=False)
+            quantizer = Mxfp4HfQuantizer(config)
+
+            with self.assertRaises(RuntimeError) as context:
+                quantizer.validate_environment()
+            self.assertIn("requires a GPU", str(context.exception))
+
     def test_quantizer_validation_missing_triton(self):
         """Test quantizer validation when triton is not available"""
         with (
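These tests can be run on their own with standard pytest selection, e.g. `python -m pytest tests/quantization/mxfp4/test_mxfp4.py -k dequantize` (the `-k` filter matches the new test names; no GPU is required since torch.cuda.is_available is mocked out).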
