Skip to content

Commit 561f144

Browse files
committed
fix: Updated stddev dim for test_large_outlier_bert
Signed-off-by: Brandon Groth <[email protected]>
1 parent 1e86fb5 commit 561f144

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/models/test_save_aiu.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def test_large_outlier_bert(
6060
# Third Party
6161
import torch
6262

63-
# Break every tensor channel with a large magnitude outlier
63+
# Break every tensor channel with a large magnitude outlier - should work for per tensor too
6464
for k, v in model_tiny_bert.state_dict().items():
6565
if k.endswith(".weight") and any(n in k for n in bert_linear_names):
6666
v[:, 0] = 1.21
@@ -69,11 +69,15 @@ def test_large_outlier_bert(
6969
qcfg_bert["recompute_narrow_weights"] = True
7070
qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)
7171

72+
# Reduce perCh or perTensor
73+
stddev_dim = -1 if "perCh" in qcfg_bert["qw_mode"] else None
74+
7275
# Qmax should break the quantization with an outlier to have skinny distribution
7376
layer2stdev: dict[str, torch.Tensor] = {}
7477
for k, v in model_tiny_bert.state_dict().items():
7578
if k.endswith(".weight") and any(n in k for n in bert_linear_names):
76-
layer2stdev[k] = v.to(torch.float32).std(dim=-1)
79+
# Collect perCh or perTensor std dev
80+
layer2stdev[k] = v.to(torch.float32).std(dim=stddev_dim)
7781

7882
save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
7983
state_dict = load_state_dict()
@@ -82,8 +86,9 @@ def test_large_outlier_bert(
8286
for k, v in state_dict.items():
8387
if k.endswith(".weight") and any(n in k for n in bert_linear_names):
8488
perCh_stdev_model = layer2stdev.get(k)
85-
perCh_stdev_loaded = v.to(torch.float32).std(dim=-1)
89+
perCh_stdev_loaded = v.to(torch.float32).std(dim=stddev_dim)
8690

91+
# SAWB stddev should be at least as good as Qmax stddev w/ outlier
8792
assert torch.all(perCh_stdev_loaded >= perCh_stdev_model)
8893

8994

0 commit comments

Comments
 (0)