@@ -180,16 +180,20 @@ def load_state_dict(self, state_dicts, sharded_input=False):
180180 self .load_worldsize = len (state_dicts )
181181 state_dicts = _shard_inclusive (state_dicts , self .rank , self .worldsize )
182182 if self .load_worldsize == self .worldsize :
183- [
184- setattr (self , flag , state_dicts [0 ][self .statename (flag )])
185- for flag in self .state_params + self .reshard_params
186- ]
183+ for flag in self .state_params + self .reshard_params :
184+ if self .statename (flag ) in state_dicts [0 ]:
185+ setattr (self , flag , state_dicts [0 ][self .statename (flag )])
186+ elif self .rank == 0 :
187+ logging .warning (f"Dataloader state key { self .statename (flag )} not present in checkpoint!" )
187188 else :
188189 for flag in self .reshard_params :
189- reshard = self ._reshard (
190- [sd [self .statename (flag )] for sd in state_dicts ]
191- )
192- setattr (self , flag , reshard )
190+ if self .statename (flag ) in state_dicts [0 ]:
191+ reshard = self ._reshard (
192+ [sd [self .statename (flag )] for sd in state_dicts ]
193+ )
194+ setattr (self , flag , reshard )
195+ elif self .rank == 0 :
196+ logging .warning (f"Dataloader state key { self .statename (flag )} not present in checkpoint!" )
193197 return state_dicts
194198
195199 def load_from_path (self , path : str ):
0 commit comments