@@ -244,33 +244,24 @@ def __init__(self, **kwargs):
         )

         missing, unexpected = None, None
-        try:
-            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-            # Because we are using device="meta", tensors do not have memory associated with them
-            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-
-            # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
-            # by default initialized to fp32. This is fine because every other supported type
-            # losslessly converts to fp32, so we don't lose precision here.
-            if checkpoint:
-                missing, unexpected = self.model_.load_state_dict(
-                    checkpoint,
-                    strict=False,
-                    assign=True,
-                )  # self.model_ = Transformer(gptconf)
-            else:
-                print("Checkpoint not provided, defaulting weights to zeros.")
-                self.model_.to_empty(device="cpu")
-                for p in self.model_.parameters():
-                    p.data.fill_(0)
-                for b in self.model_.buffers():
-                    b.data.fill_(0)
-        except RuntimeError as e:
-            print(
-                f"Could not load checkpoint into mode and will defaulting weights to zeros due to error: {e}."
-            )
-            # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
+        if checkpoint:
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )  # self.model_ = Transformer(gptconf)
+        else:
+            print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
+            # Need to provide concrete values for meta-initialized tensors for quantization;
+            # otherwise they are just filled with NaNs.
             for p in self.model_.parameters():
                 p.data.fill_(0)
             for b in self.model_.buffers():
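For reference, a minimal standalone sketch (not part of this commit) of the meta-device loading pattern the change relies on: build the module on the meta device so no storage is allocated, then either assign checkpoint tensors via load_state_dict(assign=True) or materialize the tensors with to_empty() and zero-fill them. The nn.Linear module and its sizes here are placeholders, not the actual Transformer from the PR.

import torch
import torch.nn as nn

# Stand-in for a real checkpoint: a state_dict from an ordinary CPU module.
checkpoint = nn.Linear(4, 4).state_dict()

# Build the module on the meta device: parameters exist but own no memory.
with torch.device("meta"):
    model = nn.Linear(4, 4)

# assign=True swaps in the checkpoint tensors instead of copying into the
# storage-less meta tensors, which is why the commit passes it above.
missing, unexpected = model.load_state_dict(checkpoint, strict=False, assign=True)
print(model.weight.device)  # cpu -- the checkpoint tensors were assigned directly

# No checkpoint: materialize the meta tensors (uninitialized memory) and zero
# them so later quantization sees concrete values rather than garbage/NaNs.
with torch.device("meta"):
    model = nn.Linear(4, 4)
model.to_empty(device="cpu")
for p in model.parameters():
    p.data.fill_(0)
for b in model.buffers():
    b.data.fill_(0)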