
Commit 0d5cda7

Fix tests
1 parent 60725f2 commit 0d5cda7

File tree

1 file changed: +10 additions, -9 deletions


tests/test_parametrize.py

Lines changed: 10 additions & 9 deletions
@@ -90,11 +90,12 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
 
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-@pytest.mark.parametrize("param_shape", [(64, 32), (8, 64, 32), (4, 8, 64, 32)])
-def test_moe_parameter_shapes(device, dtype, param_shape):
-    """Test parametrization with MoE-style parameter shapes, especially 3D tensors."""
-    if device == "hpu" and dtype == torch.float16:
-        pytest.skip("Float16 not supported on HPU.")
+def test_moe_parameter_shape(device, dtype):
+    """Test parametrization with MoE-style parameter shape"""
+    if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+        pytest.skip("This configuration is not supported on HPU.")
+
+    param_shape = (8, 64, 32)
 
     # Create module with custom parameter shape directly on target device
     class MoEModule(nn.Module):
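
After this hunk the test works with a single fixed 3D MoE-style shape, (8, 64, 32), and the HPU skip is gated on is_supported_on_hpu("nf4", dtype) instead of a bare float16 check. The body of the nested MoEModule lies outside the hunk; a minimal sketch, assuming it simply holds one randomly initialized parameter of param_shape on the target device (the randn initialization and the default argument are assumptions, not the test's actual code):

    import torch
    import torch.nn as nn

    class MoEModule(nn.Module):
        # Hypothetical stand-in: one MoE-style 3D weight created directly on the target device.
        def __init__(self, device, dtype, param_shape=(8, 64, 32)):
            super().__init__()
            self.param = nn.Parameter(torch.randn(param_shape, device=device, dtype=dtype))
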
@@ -106,7 +107,7 @@ def __init__(self, device, dtype):
     original_param = module.param.clone()
 
     # Apply quantization parametrization
-    replace_parameter_4bit(module, "param", quant_type="nf4", blocksize=64)
+    replace_parameter_4bit(module, "param", quant_type="nf4")
 
     # Verify reconstruction maintains all properties
     reconstructed = module.param
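
This hunk drops the explicit blocksize=64, so replace_parameter_4bit falls back to whatever default the helper defines. A rough sketch of the flow the surrounding context lines imply, assuming the MoEModule sketch above and that replace_parameter_4bit (the repository's own helper, with the signature shown in the diff) is already imported in the test module:

    import torch

    module = MoEModule(device="cpu", dtype=torch.float32)  # device/dtype choice is illustrative only
    original_param = module.param.clone()

    # Register the 4-bit parametrization; blocksize is left at the helper's default.
    replace_parameter_4bit(module, "param", quant_type="nf4")

    # Reading module.param now goes through dequantization and yields a reconstruction.
    reconstructed = module.param
    assert reconstructed.shape == original_param.shape
    assert reconstructed.dtype == original_param.dtype
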
@@ -120,8 +121,8 @@ def __init__(self, device, dtype):
     err_mean = err.mean()
 
     # Use slightly looser bounds for higher dimensional tensors
-    abs_bound = 0.085 if len(param_shape) > 2 else 0.08  # NF4 baseline + margin
-    rel_bound = 0.25 if len(param_shape) > 2 else 0.22  # NF4 baseline + margin
+    abs_bound = 0.085  # NF4 baseline + margin
+    rel_bound = 0.25  # NF4 baseline + margin
 
     assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
     assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
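
With only one parametrized shape left, the shape-dependent bounds collapse to the plain NF4 bounds. The err/relerr computation itself sits above this hunk and is not shown; one plausible version, assuming element-wise absolute error against the original parameter and a clamped relative error (the clamp epsilon is an assumption):

    # Hypothetical reconstruction of the error metrics the asserts rely on.
    param_shape = (8, 64, 32)  # fixed shape from the first hunk
    err = (original_param - reconstructed).abs().float()
    relerr = (err / original_param.abs().float().clamp_min(1e-8)).mean()
    err_mean = err.mean()

    abs_bound = 0.085  # NF4 baseline + margin
    rel_bound = 0.25   # NF4 baseline + margin
    assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
    assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
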
@@ -177,7 +178,7 @@ def test_state_dict_functionality(device, dtype, quant_type, compress_statistics
     assert "expert_weights" in state_dict, "Quantized parameter should be in state dict"
     assert "expert_weights.absmax" in state_dict, "Quantization absmax should be saved"
     assert "expert_weights.quant_map" in state_dict, "Quantization map should be saved"
-    assert "expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved"
+    assert f"expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved"
 
     # Verify parametrization internals are NOT saved (clean state dict)
     assert "parametrizations.expert_weights.original" not in state_dict, (
