Skip to content

Commit d36e61f

Browse files
committed
feat: Added recompute test in test_save_aiu
Signed-off-by: Brandon Groth <[email protected]>
1 parent 158dcb2 commit d36e61f

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

tests/models/test_model_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,4 @@ def check_linear_dtypes(state_dict: dict, linear_names: list):
238238
for k, v in state_dict.items()
239239
if all(n not in k for n in linear_names) or not k.endswith(".weight")
240240
)
241+

tests/models/test_save_aiu.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import pytest
44

55
# Local
6-
from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict
6+
from .test_model_utils import (
7+
check_linear_dtypes,
8+
delete_config,
9+
load_state_dict,
10+
)
711
from fms_mo import qmodel_prep
812
from fms_mo.utils.aiu_utils import save_for_aiu
913

@@ -42,6 +46,50 @@ def test_save_model_bert(
4246
check_linear_dtypes(state_dict, bert_linear_names)
4347

4448

49+
def test_large_outlier_bert(
    model_tiny_bert: BertModel,
    input_tiny: BatchEncoding,
    qcfg_bert: dict,
    bert_linear_names: list,
):
    """
    Test that the recomputation mode increases the per-channel standard
    deviation of a quantized weight tensor that contains a large outlier.

    Args:
        model_tiny_bert (BertModel): Bert Tiny model
        input_tiny (BatchEncoding): Fake tiny input batch
        qcfg_bert (dict): Quantized config for Bert
        bert_linear_names (list): Names of Bert linear layers to check
    """
    # Local import kept deliberately: torch is only needed by this test.
    import torch

    # Break every tensor channel with a large-magnitude outlier so plain
    # Qmax quantization would produce a skinny per-channel distribution.
    for key, weight in model_tiny_bert.state_dict().items():
        if key.endswith(".weight") and any(n in key for n in bert_linear_names):
            weight[:, 0] = 1.21

    # Set recomputation for narrow weights and prep
    qcfg_bert["recompute_narrow_weights"] = True
    qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)

    # Record per-channel stdev of the prepped model; the outlier should have
    # broken quantization into a skinny distribution at this point.
    layer2stdev: dict[str, torch.Tensor] = {}
    for key, weight in model_tiny_bert.state_dict().items():
        if key.endswith(".weight") and any(n in key for n in bert_linear_names):
            layer2stdev[key] = weight.to(torch.float32).std(dim=-1)

    save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
    state_dict = load_state_dict()

    # Loaded model w/ recomputed SAWB should have widened (>=) per-channel
    # quantization stdev for every tracked linear weight.
    for key, weight in state_dict.items():
        if key.endswith(".weight") and any(n in key for n in bert_linear_names):
            # Direct indexing (not .get) so a renamed/missing layer fails
            # with a clear KeyError instead of a TypeError on None below.
            per_ch_stdev_model = layer2stdev[key]
            per_ch_stdev_loaded = weight.to(torch.float32).std(dim=-1)

            assert torch.all(per_ch_stdev_loaded >= per_ch_stdev_model)
91+
92+
4593
def test_save_model_llama(
4694
model_tiny_llama: LlamaModel,
4795
input_tiny: BatchEncoding,
@@ -88,3 +136,4 @@ def test_save_model_granite(
88136
# Fetch saved state dict
89137
state_dict = load_state_dict()
90138
check_linear_dtypes(state_dict, granite_linear_names)
139+

0 commit comments

Comments
 (0)