We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a28fdde commit f388bbeCopy full SHA for f388bbe
colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -408,7 +408,7 @@ def load_dist_model(
408
file_path = os.path.join(checkpoint, file)
409
state_dict_shard = load_state_dict(file_path)
410
for key, weight in state_dict_shard.items():
411
- if key not in covered_shards:
+ if key not in covered_shards or rank not in covered_shards[key]:
412
continue
413
if dtype == None:
414
dtype = weight.dtype
0 commit comments