
Commit f9e853d

feat: Added test_save_aiu.py test
Signed-off-by: Brandon Groth <[email protected]>
1 parent 1fdebcc commit f9e853d

File tree

1 file changed (+90, -0 lines)

tests/models/test_save_aiu.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
# Third Party
from transformers import BatchEncoding, BertModel, GraniteModel, LlamaModel
import pytest

# Local
from .test_model_utils import check_linear_dtypes, delete_config, load_state_dict
from fms_mo import qmodel_prep
from fms_mo.utils.aiu_utils import save_for_aiu


@pytest.fixture(autouse=True)
def delete_files():
    """
    Delete any known lingering files before each test starts.
    """
    delete_config("qcfg.json")
    delete_config("keys_to_save.json")
    delete_config("qmodel_for_aiu.pt")


def test_save_model_bert(
    model_tiny_bert: BertModel,
    input_tiny: BatchEncoding,
    qcfg_bert: dict,
    bert_linear_names: list,
):
    """
    Save a BERT state dictionary and attempt to reload it into a fresh model.

    Args:
        model_tiny_bert (BertModel): Tiny BERT model
        input_tiny (BatchEncoding): Fake tiny input
        qcfg_bert (dict): Quantization config for BERT
        bert_linear_names (list): Names of the Linear modules to check
    """
    # Quantize model and save state dict
    qmodel_prep(model_tiny_bert, input_tiny, qcfg_bert, use_dynamo=True)
    save_for_aiu(model_tiny_bert, qcfg=qcfg_bert, verbose=True)

    # Fetch saved state dict
    state_dict = load_state_dict()
    check_linear_dtypes(state_dict, bert_linear_names)


def test_save_model_llama(
    model_tiny_llama: LlamaModel,
    input_tiny: BatchEncoding,
    qcfg_llama: dict,
    llama_linear_names: list,
):
    """
    Save a Llama state dictionary and attempt to reload it into a fresh model.

    Args:
        model_tiny_llama (LlamaModel): Tiny Llama model
        input_tiny (BatchEncoding): Fake tiny input
        qcfg_llama (dict): Quantization config for Llama
        llama_linear_names (list): Names of the Linear modules to check
    """
    # Quantize model and save state dict
    qmodel_prep(model_tiny_llama, input_tiny, qcfg_llama, use_dynamo=True)
    save_for_aiu(model_tiny_llama, qcfg=qcfg_llama, verbose=True)

    # Fetch saved state dict
    state_dict = load_state_dict()
    check_linear_dtypes(state_dict, llama_linear_names)


def test_save_model_granite(
    model_tiny_granite: GraniteModel,
    input_tiny: BatchEncoding,
    qcfg_granite: dict,
    granite_linear_names: list,
):
    """
    Save a Granite state dictionary and attempt to reload it into a fresh model.

    Args:
        model_tiny_granite (GraniteModel): Tiny Granite model
        input_tiny (BatchEncoding): Fake tiny input
        qcfg_granite (dict): Quantization config for Granite
        granite_linear_names (list): Names of the Linear modules to check
    """
    # Quantize model and save state dict
    qmodel_prep(model_tiny_granite, input_tiny, qcfg_granite, use_dynamo=True)
    save_for_aiu(model_tiny_granite, qcfg=qcfg_granite, verbose=True)

    # Fetch saved state dict
    state_dict = load_state_dict()
    check_linear_dtypes(state_dict, granite_linear_names)
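
For context, a minimal sketch of what the shared .test_model_utils helpers imported above could look like. This is an illustration under stated assumptions, not the repo's actual implementation: it presumes save_for_aiu writes its checkpoint to qmodel_for_aiu.pt in the working directory (the same file the delete_files fixture removes) and that quantized Linear weights are serialized as int8.

# Hypothetical sketch of the .test_model_utils helpers; the real versions
# live in the repo and may differ. Assumes save_for_aiu() writes
# "qmodel_for_aiu.pt" to the working directory and stores quantized
# Linear weights as int8.
import os

import torch


def delete_config(filename: str) -> None:
    # Remove a leftover artifact from a previous run, if present
    if os.path.exists(filename):
        os.remove(filename)


def load_state_dict(path: str = "qmodel_for_aiu.pt") -> dict:
    # Reload the saved state dict from disk onto the CPU
    return torch.load(path, map_location="cpu")


def check_linear_dtypes(state_dict: dict, linear_names: list) -> None:
    # Each expected Linear module should have a weight saved in the
    # AIU-ready (assumed int8) dtype
    for name in linear_names:
        weight = state_dict[f"{name}.weight"]
        assert weight.dtype == torch.int8, f"{name}: got {weight.dtype}"

With helpers along these lines in place, the suite runs with a plain pytest tests/models/test_save_aiu.py.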
