|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | # Local |
6 | | -from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict |
| 6 | +from .test_model_utils import ( |
| 7 | + check_linear_dtypes, |
| 8 | + delete_config, |
| 9 | + load_state_dict, |
| 10 | +) |
7 | 11 | from fms_mo import qmodel_prep |
8 | 12 | from fms_mo.utils.aiu_utils import save_for_aiu |
9 | 13 |
|
@@ -42,6 +46,50 @@ def test_save_model_bert( |
42 | 46 | check_linear_dtypes(state_dict, bert_linear_names) |
43 | 47 |
|
44 | 48 |
|
def test_large_outlier_bert(
    model_tiny_bert: BertModel,
    input_tiny: BatchEncoding,
    qcfg_bert: dict,
    bert_linear_names: list,
):
    """
    Test if the recomputation mode increases standard deviation of a tensor with an outlier.

    Args:
        model_tiny_bert (BertModel): Bert Tiny model
        input_tiny (BatchEncoding): Fake tiny input for the model
        qcfg_bert (dict): Quantized config for Bert
        bert_linear_names (list): Names of the Bert linear layers to check
    """
    import torch

    # Break every tensor channel with a large magnitude outlier
    for k, v in model_tiny_bert.state_dict().items():
        if k.endswith(".weight") and any(n in k for n in bert_linear_names):
            v[:, 0] = 1.21

    # Set recomputation for narrow weights and prep
    qcfg_bert["recompute_narrow_weights"] = True
    qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)

    # Qmax should break the quantization with an outlier to have skinny distribution.
    # Record the per-channel stdev of each quantized linear weight before saving.
    layer2stdev: dict[str, torch.Tensor] = {}
    for k, v in model_tiny_bert.state_dict().items():
        if k.endswith(".weight") and any(n in k for n in bert_linear_names):
            layer2stdev[k] = v.to(torch.float32).std(dim=-1)

    save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
    state_dict = load_state_dict()

    # Loaded model w/ recomputed SAWB should have widened channel quantization stdev
    for k, v in state_dict.items():
        if k.endswith(".weight") and any(n in k for n in bert_linear_names):
            # Fail with a clear message if the saved state dict contains a weight
            # key we did not record pre-save (a bare .get() would yield None and
            # the comparison below would raise an opaque TypeError instead).
            assert k in layer2stdev, f"unexpected weight key in saved state dict: {k}"
            per_ch_stdev_model = layer2stdev[k]
            per_ch_stdev_loaded = v.to(torch.float32).std(dim=-1)

            assert torch.all(per_ch_stdev_loaded >= per_ch_stdev_model)
| 92 | + |
45 | 93 | def test_save_model_llama( |
46 | 94 | model_tiny_llama: LlamaModel, |
47 | 95 | input_tiny: BatchEncoding, |
@@ -88,3 +136,4 @@ def test_save_model_granite( |
88 | 136 | # Fetch saved state dict |
89 | 137 | state_dict = load_state_dict() |
90 | 138 | check_linear_dtypes(state_dict, granite_linear_names) |
| 139 | + |
0 commit comments