Skip to content

Commit daf9f97

Browse files
committed
feat: Added guards for save_for_aiu and added a test for 0 clip vals
Signed-off-by: Brandon Groth <[email protected]>
1 parent ae67c8c commit daf9f97

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

fms_mo/quant/quantizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3029,7 +3029,7 @@ def __init__(
30293029
self.register_buffer("clip_valn", torch.zeros(perGp[0]))
30303030
else:
30313031
self.register_buffer(
3032-
"clip_val", torch.zeros(perCh) if perCh else torch.Tensor([1.0])
3032+
"clip_val", torch.zeros(perCh) if perCh else torch.Tensor([0.0])
30333033
)
30343034
self.register_buffer(
30353035
"clip_valn", torch.zeros(perCh) if perCh else torch.Tensor([0.0])

fms_mo/utils/aiu_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,14 @@ def process_weight(
232232
is_w_recomputed = False
233233
if layer_name + ".quantize_weight.clip_val" in model.state_dict():
234234
w_cv = model.state_dict()[layer_name + ".quantize_weight.clip_val"]
235+
236+
# Check that clip values are initialized
237+
if torch.any(w_cv.isclose(torch.tensor(0.0))):
238+
raise ValueError(
239+
f"Quantization clip values for {layer_name=} have near-zero values and "
240+
"are likely uninitialized."
241+
)
242+
235243
if w_cv.numel() > 1:
236244
w_cv = w_cv.unsqueeze(dim=1)
237245
weight_int_as_fp = torch.clamp(127 / w_cv * weight_pre_quant, -127, 127).round()

tests/models/test_save_aiu.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def test_save_model_bert(
2929
3030
Args:
3131
model_tiny_bert (BertModel): Bert Tiny Model
32-
config_tiny_bert (BertConfig): Bert Tiny config
3332
input_tiny (BatchEncoding): Fake tiny input
34-
qcfg_bert (dict): Quantized config for Bert
33+
qcfg_bert (dict): Quantized config for Tiny Bert
34+
bert_linear_names (list): Names of linear layers for Bert
3535
"""
3636
# Quantize model and save state dict
3737
qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)
@@ -54,8 +54,8 @@ def test_large_outlier_bert(
5454
Args:
5555
model_tiny_bert (BertModel): Bert Tiny Model
5656
input_tiny (BatchEncoding): Fake tiny input
57-
qcfg_bert (dict): Fake tiny input
58-
bert_linear_names (list): Quantized config for Bert
57+
qcfg_bert (dict): Quantized config for Tiny Bert
58+
bert_linear_names (list): Names of linear layers for Bert
5959
"""
6060
# Third Party
6161
import torch
@@ -87,6 +87,27 @@ def test_large_outlier_bert(
8787
assert torch.all(perCh_stdev_loaded >= perCh_stdev_model)
8888

8989

90+
def test_clip_vals_zero_bert(
91+
model_tiny_bert: BertModel,
92+
input_tiny: BatchEncoding,
93+
qcfg_bert: dict,
94+
):
95+
"""
96+
Test that uninitialized clip vals raise an error
97+
98+
Args:
99+
model_tiny_bert (BertModel): Bert Tiny Model
100+
input_tiny (BatchEncoding): Fake tiny input
101+
qcfg_bert (dict): Quantized config for Tiny Bert
102+
"""
103+
# Turn off calibration -> clip vals are init as 0
104+
qcfg_bert["qmodel_calibration"] = 0
105+
qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)
106+
107+
with pytest.raises(ValueError):
108+
save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)
109+
110+
90111
def test_save_model_llama(
91112
model_tiny_llama: LlamaModel,
92113
input_tiny: BatchEncoding,

0 commit comments

Comments
 (0)