 import pytest

 # Local
-from .test_model_utils import (
-    check_linear_dtypes,
-    delete_config,
-    load_state_dict,
-)
+from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict
 from fms_mo import qmodel_prep
 from fms_mo.utils.aiu_utils import save_for_aiu

@@ -61,32 +57,33 @@ def test_large_outlier_bert( |
         qcfg_bert (dict): Quantized config for Bert
         bert_linear_names (list): Names of the quantized Bert Linear layers
     """
+    # Third Party
     import torch

     # Break every tensor channel with a large magnitude outlier
-    for k,v in model_tiny_bert.state_dict().items():
+    for k, v in model_tiny_bert.state_dict().items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
-            v[:,0] = 1.21
+            v[:, 0] = 1.21

     # Set recomputation for narrow weights and prep
     qcfg_bert["recompute_narrow_weights"] = True
     qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)

     # The outlier should break Qmax quantization, leaving a skinny per-channel distribution
     layer2stdev: dict[str, torch.Tensor] = {}
-    for k,v in model_tiny_bert.state_dict().items():
+    for k, v in model_tiny_bert.state_dict().items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
             layer2stdev[k] = v.to(torch.float32).std(dim=-1)

     save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
     state_dict = load_state_dict()

     # Loaded model w/ recomputed SAWB should have widened per-channel quantization stdev
-    for k,v in state_dict.items():
+    for k, v in state_dict.items():
         if k.endswith(".weight") and any(n in k for n in bert_linear_names):
             perCh_stdev_model = layer2stdev.get(k)
             perCh_stdev_loaded = v.to(torch.float32).std(dim=-1)
-
+
             assert torch.all(perCh_stdev_loaded >= perCh_stdev_model)

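Reviewer note on the assertion in this hunk: the sketch below is a minimal, self-contained illustration (plain Python, not fms_mo code; `to_int8` and the percentile clip are assumptions standing in for the recomputed SAWB scale) of why a planted outlier makes absmax/Qmax-scaled channels skinny, and why a clipped scale widens them under the same `std(dim=-1)` check the test uses:

```python
import torch

torch.manual_seed(0)
w = torch.randn(8, 64) * 0.02  # well-behaved weight channels
w[:, 0] = 1.21                 # planted per-channel outlier, as in the test

def to_int8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Symmetric per-channel quantization onto the int8 grid (illustrative)."""
    step = scale.unsqueeze(-1) / 127
    return (weight / step).round().clamp(-127, 127)

# Qmax-style scale: the raw per-channel absmax is dominated by the outlier,
# so the bulk of each channel collapses into a few integer bins (skinny).
q_narrow = to_int8(w, w.abs().amax(dim=-1))

# Clipped scale (an assumption standing in for the recomputed SAWB scale):
# ignoring the extreme tail spreads the channel bulk across the int8 range.
q_wide = to_int8(w, w.abs().quantile(0.95, dim=-1))

# Same shape of check as the test's assert: widened per-channel stdev.
assert torch.all(q_wide.std(dim=-1) >= q_narrow.std(dim=-1))
```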
@@ -136,4 +133,3 @@ def test_save_model_granite( |
     # Fetch saved state dict
     state_dict = load_state_dict()
     check_linear_dtypes(state_dict, granite_linear_names)
-