@@ -168,6 +168,7 @@ def load_json(file_path: str = "qcfg.json"):
168168 assert json_file is not None , f"JSON at { file_path } was not found"
169169 return json_file
170170
171+
171172def save_json (data , file_path : str = "qcfg.json" ):
172173 """
173174 Save data object to json file
@@ -179,6 +180,7 @@ def save_json(data, file_path: str = "qcfg.json"):
179180 with open (file_path , "w" , encoding = "utf-8" ) as outfile :
180181 json .dump (data , outfile , indent = 4 )
181182
183+
182184def save_serialized_json (config : dict , file_path : str = "qcfg.json" ):
183185 """
184186 Save qconfig by serializing it first
@@ -195,3 +197,44 @@ def save_serialized_json(config: dict, file_path: str = "qcfg.json"):
195197
196198 serialize_config (config ) # Only remove stuff necessary to dump
197199 save_json (config , file_path )
200+
201+
202+ ################################
203+ # General state dict functions #
204+ ################################
205+
206+
207+ def load_state_dict (fname : str = "qmodel_for_aiu.pt" ) -> dict :
208+ """
209+ Load a model state dict .pt file.
210+
211+ Args:
212+ fname (str, optional): File for state dict of model. Defaults to "qmodel_for_aiu.pt".
213+
214+ Returns:
215+ dict: Model state dictionary
216+ """
217+ return torch .load (fname )
218+
219+
220+ def check_linear_dtypes (state_dict : dict , linear_names : list ):
221+ """
222+ Checks a state dict for proper dtypes of linear names saved for AIU
223+
224+ Args:
225+ state_dict (dict): Saved model state dict
226+ linear_names (list): List of layer names that correspond to torch.nn.Linear layers
227+ """
228+ assert state_dict is not None
229+
230+ # Check all quantized linear layers are int8 and everything else is fp16
231+ assert all (
232+ v .dtype == torch .int8
233+ for k , v in state_dict .items ()
234+ if any (n in k for n in linear_names ) and k .endswith (".weight" )
235+ )
236+ assert all (
237+ v .dtype == torch .float16
238+ for k , v in state_dict .items ()
239+ if all (n not in k for n in linear_names ) or not k .endswith (".weight" )
240+ )
0 commit comments