Skip to content

Commit a67fe12

Browse files
committed
feat: Added testing helpers for state dicts
Signed-off-by: Brandon Groth <[email protected]>
1 parent bc55331 commit a67fe12

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/models/test_model_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def load_json(file_path: str = "qcfg.json"):
168168
assert json_file is not None, f"JSON at {file_path} was not found"
169169
return json_file
170170

171+
171172
def save_json(data, file_path: str = "qcfg.json"):
172173
"""
173174
Save data object to json file
@@ -179,6 +180,7 @@ def save_json(data, file_path: str = "qcfg.json"):
179180
with open(file_path, "w", encoding="utf-8") as outfile:
180181
json.dump(data, outfile, indent=4)
181182

183+
182184
def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
183185
"""
184186
Save qconfig by serializing it first
@@ -195,3 +197,44 @@ def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
195197

196198
serialize_config(config) # Only remove stuff necessary to dump
197199
save_json(config, file_path)
200+
201+
202+
################################
203+
# General state dict functions #
204+
################################
205+
206+
207+
def load_state_dict(fname: str = "qmodel_for_aiu.pt", map_location=None) -> dict:
    """
    Load a model state dict .pt file.

    Args:
        fname (str, optional): File for state dict of model. Defaults to "qmodel_for_aiu.pt".
        map_location (optional): Device remapping forwarded to ``torch.load``
            (e.g. "cpu" to load a GPU-saved checkpoint on a CPU-only test
            machine). Defaults to None, preserving torch's native placement.

    Returns:
        dict: Model state dictionary

    NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    (older torch versions execute arbitrary pickle payloads on load).
    """
    return torch.load(fname, map_location=map_location)
218+
219+
220+
def check_linear_dtypes(state_dict: dict, linear_names: list):
    """
    Checks a state dict for proper dtypes of linear names saved for AIU

    Args:
        state_dict (dict): Saved model state dict
        linear_names (list): List of layer names that correspond to torch.nn.Linear layers
    """
    assert state_dict is not None

    # Quantized linear weights must be int8; every other tensor must be fp16.
    for key, tensor in state_dict.items():
        is_quantized_linear_weight = key.endswith(".weight") and any(
            name in key for name in linear_names
        )
        if is_quantized_linear_weight:
            assert tensor.dtype == torch.int8
        else:
            assert tensor.dtype == torch.float16

0 commit comments

Comments
 (0)