Skip to content

Commit 515b2a5

Browse files
authored
Permissive load
1 parent 8324609 commit 515b2a5

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

fms_fsdp/utils/dataset_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,16 +180,20 @@ def load_state_dict(self, state_dicts, sharded_input=False):
180180
self.load_worldsize = len(state_dicts)
181181
state_dicts = _shard_inclusive(state_dicts, self.rank, self.worldsize)
182182
if self.load_worldsize == self.worldsize:
183-
[
184-
setattr(self, flag, state_dicts[0][self.statename(flag)])
185-
for flag in self.state_params + self.reshard_params
186-
]
183+
for flag in self.state_params + self.reshard_params:
184+
if self.statename(flag) in state_dicts[0]:
185+
setattr(self, flag, state_dicts[0][self.statename(flag)])
186+
elif self.rank == 0:
187+
logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!")
187188
else:
188189
for flag in self.reshard_params:
189-
reshard = self._reshard(
190-
[sd[self.statename(flag)] for sd in state_dicts]
191-
)
192-
setattr(self, flag, reshard)
190+
if self.statename(flag) in state_dicts[0]:
191+
reshard = self._reshard(
192+
[sd[self.statename(flag)] for sd in state_dicts]
193+
)
194+
setattr(self, flag, reshard)
195+
elif self.rank == 0:
196+
logging.warning(f"Dataloader state key {self.statename(flag)} not present in checkpoint!")
193197
return state_dicts
194198

195199
def load_from_path(self, path: str):

0 commit comments

Comments
 (0)