File tree Expand file tree Collapse file tree 1 file changed +15
-8
lines changed Expand file tree Collapse file tree 1 file changed +15
-8
lines changed Original file line number Diff line number Diff line change @@ -158,16 +158,23 @@ def load_torch_state_dict(
158158
159159 incompatible = model .load_state_dict (state )
160160 if (
161- incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
162- and incompatible .missing_keys
161+ isinstance (incompatible , tuple )
162+ and hasattr (incompatible , "missing_keys" )
163+ and hasattr (incompatible , "unexpected_keys" )
163164 ):
164- logger .warning ("Missing state dict keys: {}" , incompatible .missing_keys )
165+ if incompatible .missing_keys :
166+ logger .warning ("Missing state dict keys: {}" , incompatible .missing_keys )
165167
166- if (
167- incompatible is not None # pyright: ignore[reportUnnecessaryComparison]
168- and incompatible .unexpected_keys
169- ):
170- logger .warning ("Unexpected state dict keys: {}" , incompatible .unexpected_keys )
168+ if hasattr (incompatible , "unexpected_keys" ) and incompatible .unexpected_keys :
169+ logger .warning (
170+ "Unexpected state dict keys: {}" , incompatible .unexpected_keys
171+ )
172+ else :
173+ logger .warning (
174+ "`model.load_state_dict()` unexpectedly returned: {} "
175+ + "(expected named tuple with `missing_keys` and `unexpected_keys` attributes)" ,
176+ (s [:20 ] + "..." if len (s := str (incompatible )) > 20 else s ),
177+ )
171178
172179 return model
173180
You can’t perform that action at this time.
0 commit comments