Commit 6e1f6d3

fix target_bits ut
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 6889568 commit 6e1f6d3

2 files changed: 11 additions, 7 deletions
neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             model, weight_config = rounder.quantize()
             model.autoround_config = weight_config
             return rounder.save_quantized(output_dir=output_dir, inplace=True)
-        elif "itrex" in export_format:
+        elif "itrex" in export_format: # TODO: remove itrex related code later
             model, weight_config = rounder.quantize()
             model.autoround_config = weight_config
             model = pack_model(model, weight_config, device=device, inplace=True)

test/3x/torch/quantization/weight_only/test_autoround.py

Lines changed: 10 additions & 6 deletions
@@ -420,10 +420,12 @@ def test_target_bits(self):
         model = prepare(model=fp32_model, quant_config=quant_config)
         model = convert(model)
         # mxfp4/8 model inference relys on autoround extension for vLLM.
-        assert "MXFP8" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__, \
+        assert ("MXFP8" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+            "MXFP4" in model.model.decoder.layers[1].fc1.__class__.__name__) \
+            or \
+            ("MXFP4" in model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+            "MXFP8" in model.model.decoder.layers[1].fc1.__class__.__name__), \
             "model is not quantized correctly, please check."
-        assert "MXFP4" in model.model.decoder.layers[1].fc1.__class__.__name__, \
-            "model is not quantized correctly, please check."
 
 
     @pytest.mark.skipif(not ct_installed, reason="The compressed-tensors module is not installed.")
@@ -461,9 +463,11 @@ def eval_acc_fn(model) -> float:
         )
         best_model = autotune(model=fp32_model, tune_config=custom_tune_config, eval_fn=eval_acc_fn)
         # mxfp4/8 model inference relys on autoround extension for vLLM.
-        assert "MXFP8" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__, \
-            "model is not quantized correctly, please check."
-        assert "MXFP8" in best_model.model.decoder.layers[1].fc1.__class__.__name__, \
+        assert ("MXFP8" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+            "MXFP8" in best_model.model.decoder.layers[1].fc1.__class__.__name__) \
+            or \
+            ("MXFP4" in best_model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__ and \
+            "MXFP4" in best_model.model.decoder.layers[1].fc1.__class__.__name__), \
             "model is not quantized correctly, please check."
 
     def test_static_attention_dtype(self):
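
Note: the commit itself only rewrites the assertions as shown above. As a minimal sketch (not part of this change), the ordering-tolerant checks could be factored into a small helper; the name assert_target_bits and its acceptable parameter are hypothetical, and the sketch assumes the same probed layers (layers[0].self_attn.k_proj and layers[1].fc1) used in the test.

# Hypothetical helper (a sketch, not part of this commit): checks that the two
# probed layers were quantized to one of the acceptable MX dtype pairs,
# regardless of which precision target_bits assigned to which layer.
def assert_target_bits(model, acceptable):
    names = (
        model.model.decoder.layers[0].self_attn.k_proj.__class__.__name__,
        model.model.decoder.layers[1].fc1.__class__.__name__,
    )
    assert any(a in names[0] and b in names[1] for a, b in acceptable), \
        "model is not quantized correctly, please check."

# Usage mirroring the two updated assertions:
# the convert() test accepts MXFP8/MXFP4 assigned to the probed layers in either order,
assert_target_bits(model, [("MXFP8", "MXFP4"), ("MXFP4", "MXFP8")])
# while the autotune() test accepts either an all-MXFP8 or an all-MXFP4 result.
assert_target_bits(best_model, [("MXFP8", "MXFP8"), ("MXFP4", "MXFP4")])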
