Commit c7844c7
Enable gpt-oss mxfp4 on older hardware (sm75+) (#39940)
Co-authored-by: Marc Sun <[email protected]>
Parent: dd70a8c
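
In practice, this means gpt-oss MXFP4 checkpoints can now be loaded on Turing-class GPUs (compute capability 7.5, e.g. a T4), provided triton >= 3.4.0 and triton_kernels are installed; on other setups the quantizer falls back to dequantizing to bf16. A minimal, hedged loading sketch (the checkpoint id is illustrative):

import torch
from transformers import AutoModelForCausalLM

# On an sm75+ GPU with triton >= 3.4.0 and triton_kernels installed, the MXFP4
# weights are used directly; otherwise the quantizer now dequantizes to bf16.
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",  # illustrative gpt-oss checkpoint id
    torch_dtype="auto",
    device_map="auto",
)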

File tree (3 files changed: +38 -12 lines changed)

src/transformers/integrations/mxfp4.py
src/transformers/quantizers/quantizer_mxfp4.py
tests/quantization/mxfp4/test_mxfp4.py

src/transformers/integrations/mxfp4.py (4 additions, 1 deletion)

@@ -280,7 +280,10 @@ def mlp_forward(self, hidden_states):
     batch_size = hidden_states.shape[0]
     hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
     router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
-    routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
+
+    with torch.cuda.device(router_logits.device):
+        routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
+
     routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
     routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
     return routed_out, router_logits
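
The functional change is the new `with torch.cuda.device(router_logits.device):` context around the triton `routing(...)` call, which pins kernel launches to the GPU that actually holds `router_logits` (presumably relevant when the experts are spread across several devices). A standalone sketch of that pattern, illustrative only and not the library code:

import torch

def on_tensor_device(t: torch.Tensor) -> torch.Tensor:
    # Pin the active CUDA device to t's device for the duration of the block,
    # so kernels launched inside target the correct GPU in multi-GPU setups.
    with torch.cuda.device(t.device):
        return t * 2.0  # stand-in for a device-sensitive kernel launch

if torch.cuda.is_available():
    x = torch.ones(8, device="cuda:0")
    print(on_tensor_device(x).device)  # cuda:0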

src/transformers/quantizers/quantizer_mxfp4.py (19 additions, 9 deletions)

@@ -67,23 +67,33 @@ def validate_environment(self, *args, **kwargs):
             raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
 
         compute_capability = torch.cuda.get_device_capability()
-        major, minor = compute_capability
+        gpu_is_supported = compute_capability >= (7, 5)
+        kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble()
 
-        if not is_triton_available("3.4.0") or not is_triton_kernels_availalble():
-            if self.pre_quantized and not self.quantization_config.dequantize:
+        if self.pre_quantized:
+            # On unsupported GPUs or without kernels, we will dequantize the model to bf16
+            if not gpu_is_supported:
                 logger.warning_once(
-                    "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
+                    "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200). "
+                    "We will default to dequantizing the model to bf16."
                 )
                 self.quantization_config.dequantize = True
                 return
-            else:
-                # we can't quantize the model in this case so we raise an error
-                raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
 
-        if major < 9:
+            if not kernels_available:
+                logger.warning_once(
+                    "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
+                )
+                self.quantization_config.dequantize = True
+                return
+        elif not gpu_is_supported:
+            # we can't quantize the model in this case so we raise an error
             raise ValueError(
-                "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100)"
+                "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200)"
             )
+        elif not kernels_available:
+            # we can't quantize the model in this case so we raise an error
+            raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")
 
         device_map = kwargs.get("device_map", None)
         if device_map is None:
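
Net effect: loading a pre-quantized MXFP4 checkpoint on an unsupported GPU, or without the triton kernels, now degrades gracefully to bf16, while attempting to quantize from scratch in the same situations still fails fast. A minimal sketch of triggering the fallback (assumes a CUDA machine; mirrors the new test below):

from transformers import Mxfp4Config
from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

config = Mxfp4Config()
quantizer = Mxfp4HfQuantizer(config)
quantizer.pre_quantized = True  # i.e. loading an existing MXFP4 checkpoint

# Below compute capability 7.5, or without triton >= 3.4.0 + triton_kernels,
# this now warns and enables dequantization instead of raising.
quantizer.validate_environment()
print(quantizer.quantization_config.dequantize)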

tests/quantization/mxfp4/test_mxfp4.py (15 additions, 2 deletions)

@@ -107,18 +107,31 @@ def test_quantizer_validation_no_cuda(self):
 
     def test_quantizer_validation_low_compute_capability(self):
         """Test quantizer validation with low compute capability"""
-        with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
+        with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
             from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
 
             config = Mxfp4Config()
             quantizer = Mxfp4HfQuantizer(config)
+            quantizer.pre_quantized = False
 
             with self.assertRaises(ValueError):
                 quantizer.validate_environment()
 
+    def test_quantizer_validation_low_compute_capability_with_prequantized(self):
+        """Test quantizer validation with low compute capability"""
+        with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
+            from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+
+            config = Mxfp4Config()
+            quantizer = Mxfp4HfQuantizer(config)
+
+            # Should automatically set dequantize=True and warn
+            quantizer.validate_environment()
+            self.assertTrue(quantizer.quantization_config.dequantize)
+
     def test_quantizer_validation_low_compute_capability_with_dequantize(self):
         """Test quantizer validation with low compute capability but dequantize enabled"""
-        with patch("torch.cuda.get_device_capability", return_value=(8, 0)):
+        with patch("torch.cuda.get_device_capability", return_value=(7, 0)):
             from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
 
             config = Mxfp4Config(dequantize=True)
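
These can be run directly with pytest (standard transformers test setup assumed), e.g. `pytest tests/quantization/mxfp4/test_mxfp4.py -k low_compute_capability`.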
