
Commit 1e86fb5

fix: Updated check_linear_dtypes for zero_point=fp32
Signed-off-by: Brandon Groth <[email protected]>
1 parent d4e713a commit 1e86fb5

File tree

1 file changed (+12 −11 lines)


tests/models/test_model_utils.py

Lines changed: 12 additions & 11 deletions
@@ -227,14 +227,15 @@ def check_linear_dtypes(state_dict: dict, linear_names: list):
     """
     assert state_dict is not None
 
-    # Check all quantized linear layers are int8 and everything else is fp16
-    assert all(
-        v.dtype == torch.int8
-        for k, v in state_dict.items()
-        if any(n in k for n in linear_names) and k.endswith(".weight")
-    )
-    assert all(
-        v.dtype == torch.float16
-        for k, v in state_dict.items()
-        if all(n not in k for n in linear_names) or not k.endswith(".weight")
-    )
+    for k, v in state_dict.items():
+        # If k is a quantized layer, check weights (int8) and zero_point (fp32)
+        if any(n in k for n in linear_names):
+            if k.endswith(".weight"):
+                assert v.dtype == torch.int8
+            elif k.endswith(".zero_point"):
+                assert v.dtype == torch.float32
+            else:
+                assert v.dtype == torch.float16
+        else:
+            # Everything else should be fp16
+            assert v.dtype == torch.float16
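
Below is a minimal sketch (not part of the commit) of the kind of state_dict the updated check now accepts: a quantized linear layer carrying an int8 weight plus an fp32 zero_point, with everything else in fp16. The key names and tensor shapes are made up for illustration, and the loop simply mirrors the new per-key logic from the diff above.

import torch

# Hypothetical layer names and state_dict entries, for illustration only.
linear_names = ["q_proj"]
state_dict = {
    "model.layers.0.q_proj.weight": torch.zeros(4, 4, dtype=torch.int8),
    "model.layers.0.q_proj.zero_point": torch.zeros(1, dtype=torch.float32),
    "model.layers.0.q_proj.scale": torch.zeros(1, dtype=torch.float16),
    "model.embed_tokens.weight": torch.zeros(8, 4, dtype=torch.float16),
}

# Same per-key checks as the new loop in the diff above.
for k, v in state_dict.items():
    if any(n in k for n in linear_names):
        if k.endswith(".weight"):
            assert v.dtype == torch.int8
        elif k.endswith(".zero_point"):
            assert v.dtype == torch.float32
        else:
            assert v.dtype == torch.float16
    else:
        assert v.dtype == torch.float16
print("all dtypes match the expected layout")

Under the old pair of assert all(...) checks, the fp32 zero_point tensor would have tripped the second assertion, since every entry not ending in ".weight" was required to be fp16; the per-key loop handles that case explicitly.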
