Skip to content

Commit 1fdebcc

Browse files
committed
feat: Added testing helpers for state dicts
Signed-off-by: Brandon Groth <[email protected]>
1 parent 9028ad5 commit 1fdebcc

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/models/test_model_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def load_json(file_path: str = "qcfg.json"):
173173
assert json_file is not None, f"JSON at {file_path} was not found"
174174
return json_file
175175

176+
176177
def save_json(data, file_path: str = "qcfg.json"):
177178
"""
178179
Save data object to json file
@@ -184,6 +185,7 @@ def save_json(data, file_path: str = "qcfg.json"):
184185
with open(file_path, "w", encoding="utf-8") as outfile:
185186
json.dump(data, outfile, indent=4)
186187

188+
187189
def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
188190
"""
189191
Save qconfig by serializing it first
@@ -200,3 +202,44 @@ def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
200202

201203
serialize_config(config) # Only remove stuff necessary to dump
202204
save_json(config, file_path)
205+
206+
207+
################################
208+
# General state dict functions #
209+
################################
210+
211+
212+
def load_state_dict(fname: str = "qmodel_for_aiu.pt") -> dict:
    """
    Deserialize a model state dict from a .pt file.

    Args:
        fname (str, optional): Path to the saved state dict file.
            Defaults to "qmodel_for_aiu.pt".

    Returns:
        dict: Model state dictionary
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only call this
    # on trusted, locally-produced files (as the tests here do).
    state_dict = torch.load(fname)
    return state_dict
223+
224+
225+
def check_linear_dtypes(state_dict: dict, linear_names: list):
    """
    Check a state dict saved for AIU for proper dtypes: quantized linear
    weights must be int8 and every other tensor must be fp16.

    Args:
        state_dict (dict): Saved model state dict
        linear_names (list): List of layer names that correspond to
            torch.nn.Linear layers

    Raises:
        AssertionError: If state_dict is None or any tensor has an
            unexpected dtype. The message lists the offending keys, so a
            failing test points directly at the bad entries.
    """
    assert state_dict is not None

    def _is_quantized_weight(key: str) -> bool:
        # A quantized linear weight is a ".weight" entry whose key mentions
        # one of the known linear layer names. Single predicate shared by
        # both checks below, so the two can never drift out of sync.
        return key.endswith(".weight") and any(n in key for n in linear_names)

    # Check all quantized linear layers are int8 ...
    bad_int8 = [
        k
        for k, v in state_dict.items()
        if _is_quantized_weight(k) and v.dtype != torch.int8
    ]
    assert not bad_int8, f"Expected torch.int8 for quantized linear weights: {bad_int8}"

    # ... and everything else is fp16.
    bad_fp16 = [
        k
        for k, v in state_dict.items()
        if not _is_quantized_weight(k) and v.dtype != torch.float16
    ]
    assert not bad_fp16, f"Expected torch.float16 for non-quantized tensors: {bad_fp16}"

0 commit comments

Comments
 (0)