@@ -90,11 +90,12 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
 
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-@pytest.mark.parametrize("param_shape", [(64, 32), (8, 64, 32), (4, 8, 64, 32)])
-def test_moe_parameter_shapes(device, dtype, param_shape):
-    """Test parametrization with MoE-style parameter shapes, especially 3D tensors."""
-    if device == "hpu" and dtype == torch.float16:
-        pytest.skip("Float16 not supported on HPU.")
+def test_moe_parameter_shape(device, dtype):
+    """Test parametrization with a MoE-style 3D parameter shape."""
+    if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    param_shape = (8, 64, 32)
 
     # Create module with custom parameter shape directly on target device
     class MoEModule(nn.Module):
@@ -106,7 +107,7 @@ def __init__(self, device, dtype):
     original_param = module.param.clone()
 
     # Apply quantization parametrization
-    replace_parameter_4bit(module, "param", quant_type="nf4", blocksize=64)
+    replace_parameter_4bit(module, "param", quant_type="nf4")
 
     # Verify reconstruction maintains all properties
     reconstructed = module.param
@@ -120,8 +121,8 @@ def __init__(self, device, dtype):
     err_mean = err.mean()
 
-    # Use slightly looser bounds for higher dimensional tensors
-    abs_bound = 0.085 if len(param_shape) > 2 else 0.08  # NF4 baseline + margin
-    rel_bound = 0.25 if len(param_shape) > 2 else 0.22  # NF4 baseline + margin
+    # NF4 baseline + margin for the fixed 3D parameter shape
+    abs_bound = 0.085
+    rel_bound = 0.25
 
     assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
     assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
@@ -177,7 +178,7 @@ def test_state_dict_functionality(device, dtype, quant_type, compress_statistics
     assert "expert_weights" in state_dict, "Quantized parameter should be in state dict"
     assert "expert_weights.absmax" in state_dict, "Quantization absmax should be saved"
     assert "expert_weights.quant_map" in state_dict, "Quantization map should be saved"
-    assert "expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved"
+    assert f"expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved"
 
     # Verify parametrization internals are NOT saved (clean state dict)
     assert "parametrizations.expert_weights.original" not in state_dict, (
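
For context, a minimal sketch of the flow the updated test exercises: a module holding a 3D expert-weight parameter is quantized in place via `replace_parameter_4bit`, and reading the parameter back dequantizes on the fly. The `MoEModule` body and the import path are assumptions (both fall outside the visible hunks); only the call signature and error bound come from the diff.

```python
import torch
import torch.nn as nn

# Assumed import path for the helper under test; adjust to wherever the PR defines it.
from bitsandbytes.nn.parametrize import replace_parameter_4bit

class MoEModule(nn.Module):
    """Hypothetical stand-in for the test's module; the real definition is not shown in the hunks."""
    def __init__(self, device, dtype, param_shape=(8, 64, 32)):
        super().__init__()
        # 3D MoE-style parameter: (num_experts, out_features, in_features)
        self.param = nn.Parameter(torch.randn(param_shape, device=device, dtype=dtype))

module = MoEModule("cpu", torch.float32)
original_param = module.param.clone()

# Quantize in place; the PR drops the explicit blocksize=64 and relies on the default.
replace_parameter_4bit(module, "param", quant_type="nf4")

# Accessing module.param now returns the dequantized reconstruction.
reconstructed = module.param
err_mean = (original_param - reconstructed).abs().float().mean()
assert err_mean < 0.085  # NF4 baseline + margin, as in the test
```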
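The last hunk fixes a missing `f` prefix: without it, the assertion looks up the literal key `expert_weights.quant_state.bitsandbytes__{quant_type}`, which is never present in the state dict, so the check could only fail. A quick illustration of the difference:

```python
quant_type = "nf4"

# Without the f prefix, the braces are literal characters in the key:
plain = "expert_weights.quant_state.bitsandbytes__{quant_type}"
# With the f prefix, the variable is interpolated into the key:
interpolated = f"expert_weights.quant_state.bitsandbytes__{quant_type}"

print(plain)         # expert_weights.quant_state.bitsandbytes__{quant_type}
print(interpolated)  # expert_weights.quant_state.bitsandbytes__nf4
```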