@@ -420,10 +420,12 @@ def test_target_bits(self):
         model = prepare(model=fp32_model, quant_config=quant_config)
         model = convert(model)
         # mxfp4/8 model inference relys on autoround extension for vLLM.
-        assert "MXFP8" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__, \
+        assert ("MXFP8" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+                "MXFP4" in model.model.decoder.layers[1].fc1.__class__.__name__) \
+            or \
+            ("MXFP4" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+                "MXFP8" in model.model.decoder.layers[1].fc1.__class__.__name__), \
             "model is not quantized correctly, please check."
-        assert "MXFP4" in model.model.decoder.layers[1].fc1.__class__.__name__, \
-            "model is not quantized correctly, please check."


     @pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.")
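In this hunk the assertion in test_target_bits is relaxed so that the target-bits pass may assign MXFP8 and MXFP4 to the two probed layers in either order. Below is a minimal, order-insensitive sketch of the same check; the helper name and the set-based comparison are illustrative only and are not part of this change.

```python
# Illustrative only, not part of the PR: an equivalent order-insensitive check.
def _probed_mx_dtypes(model):
    # Collect the MX dtype tag ("MXFP8"/"MXFP4") from the two probed modules.
    probed = (
        model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__,
        model.model.decoder.layers[1].fc1.__class__.__name__,
    )
    return {tag for name in probed for tag in ("MXFP8", "MXFP4") if tag in name}

# One probed module should be MXFP8 and the other MXFP4, in either order.
assert _probed_mx_dtypes(model) == {"MXFP8", "MXFP4"}, \
    "model is not quantized correctly, please check."
```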
@@ -461,9 +463,11 @@ def eval_acc_fn(model) -> float:
         )
         best_model = autotune(model=fp32_model, tune_config=custom_tune_config, eval_fn=eval_acc_fn)
         # mxfp4/8 model inference relys on autoround extension for vLLM.
-        assert "MXFP8" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__, \
-            "model is not quantized correctly, please check."
-        assert "MXFP8" in best_model.model.decoder.layers[1].fc1.__class__.__name__, \
+        assert ("MXFP8" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+                "MXFP8" in best_model.model.decoder.layers[1].fc1.__class__.__name__) \
+            or \
+            ("MXFP4" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+                "MXFP4" in best_model.model.decoder.layers[1].fc1.__class__.__name__), \
             "model is not quantized correctly, please check."

     def test_static_attention_dtype(self):
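In the autotune hunk, the relaxed assertion instead expects both probed modules to share the same MX dtype, whichever of MXFP8 or MXFP4 the tuner selects. A hedged equivalent sketch, assuming best_model is the model returned by autotune as in the test:

```python
# Illustrative only, not part of the PR: both probed modules should share one MX dtype.
cls_names = [
    best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__,
    best_model.model.decoder.layers[1].fc1.__class__.__name__,
]
assert (
    all("MXFP8" in n for n in cls_names) or all("MXFP4" in n for n in cls_names)
), "model is not quantized correctly, please check."
```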