@@ -173,6 +173,7 @@ def load_json(file_path: str = "qcfg.json"):
173173 assert json_file is not None , f"JSON at { file_path } was not found"
174174 return json_file
175175
176+
176177def save_json (data , file_path : str = "qcfg.json" ):
177178 """
178179 Save data object to json file
@@ -184,6 +185,7 @@ def save_json(data, file_path: str = "qcfg.json"):
184185 with open (file_path , "w" , encoding = "utf-8" ) as outfile :
185186 json .dump (data , outfile , indent = 4 )
186187
188+
187189def save_serialized_json (config : dict , file_path : str = "qcfg.json" ):
188190 """
189191 Save qconfig by serializing it first
@@ -200,3 +202,44 @@ def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
200202
201203 serialize_config (config ) # Only remove stuff necessary to dump
202204 save_json (config , file_path )
205+
206+
207+ ################################
208+ # General state dict functions #
209+ ################################
210+
211+
def load_state_dict(fname: str = "qmodel_for_aiu.pt") -> dict:
    """
    Read a saved model state dict back from a .pt file.

    Args:
        fname (str, optional): Path of the file to read.
            Defaults to "qmodel_for_aiu.pt".

    Returns:
        dict: The object deserialized by torch.load, expected to be the
            model state dictionary.
    """
    # NOTE(review): torch.load deserializes via pickle; presumably the file
    # always comes from a trusted source -- confirm with callers.
    state_dict = torch.load(fname)
    return state_dict
223+
224+
def check_linear_dtypes(state_dict: dict, linear_names: list):
    """
    Check that a state dict saved for AIU uses the expected dtypes.

    An entry is treated as a quantized linear weight when its key ends with
    ".weight" and contains at least one of the given linear names as a
    substring. Such entries must be torch.int8; every other entry must be
    torch.float16.

    Args:
        state_dict (dict): Saved model state dict. Every value must expose a
            ``dtype`` attribute.
        linear_names (list): List of layer names that correspond to
            torch.nn.Linear layers.

    Raises:
        AssertionError: If state_dict is None, or if any entry has a dtype
            other than the expected one. The message lists each offending
            key with its expected and actual dtype.
    """
    # Raise explicitly rather than using `assert` statements so the check
    # still runs when Python is started with -O; the exception type is kept
    # as AssertionError for callers that rely on it.
    if state_dict is None:
        raise AssertionError("state_dict is None; expected a model state dict")

    mismatches = []
    for key, value in state_dict.items():
        is_linear_weight = key.endswith(".weight") and any(
            name in key for name in linear_names
        )
        # Quantized linear weights are int8; everything else is fp16.
        expected = torch.int8 if is_linear_weight else torch.float16
        if value.dtype != expected:
            mismatches.append(f"{key}: expected {expected}, got {value.dtype}")

    if mismatches:
        raise AssertionError(
            "State dict has unexpected dtypes: " + "; ".join(mismatches)
        )
0 commit comments