@@ -67,23 +67,33 @@ def validate_environment(self, *args, **kwargs):
67
67
raise ImportError ("Using mxfp4 requires Accelerate: `pip install accelerate`" )
68
68
69
69
compute_capability = torch .cuda .get_device_capability ()
70
- major , minor = compute_capability
70
+ gpu_is_supported = compute_capability >= (7 , 5 )
71
+ kernels_available = is_triton_available ("3.4.0" ) and is_triton_kernels_availalble ()
71
72
72
- if not is_triton_available ("3.4.0" ) or not is_triton_kernels_availalble ():
73
- if self .pre_quantized and not self .quantization_config .dequantize :
73
+ if self .pre_quantized :
74
+ # On unsupported GPUs or without kernels, we will dequantize the model to bf16
75
+ if not gpu_is_supported :
74
76
logger .warning_once (
75
- "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
77
+ "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200). "
78
+ "We will default to dequantizing the model to bf16."
76
79
)
77
80
self .quantization_config .dequantize = True
78
81
return
79
- else :
80
- # we can't quantize the model in this case so we raise an error
81
- raise ValueError ("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" )
82
82
83
- if major < 9 :
83
+ if not kernels_available :
84
+ logger .warning_once (
85
+ "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
86
+ )
87
+ self .quantization_config .dequantize = True
88
+ return
89
+ elif not gpu_is_supported :
90
+ # we can't quantize the model in this case so we raise an error
84
91
raise ValueError (
85
- "MXFP4 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100, or B100 )"
92
+ "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200 )"
86
93
)
94
+ elif not kernels_available :
95
+ # we can't quantize the model in this case so we raise an error
96
+ raise ValueError ("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed" )
87
97
88
98
device_map = kwargs .get ("device_map" , None )
89
99
if device_map is None :
0 commit comments