Skip to content

Commit aedb8d1

Browse files
authored
Fix gather_state_dict_fast
1 parent d322ff8 commit aedb8d1

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

colossalai/checkpoint_io/utils.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1132,18 +1132,20 @@ def gather_state_dict_fast(
11321132
if rank == dst:
11331133
returned_state_dict = state_dict.copy()
11341134
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
1135+
ks, ops = [], []
11351136
for i, target_metadata in enumerate(all_meta_data):
11361137
if i == dst:
11371138
continue
1138-
ops = []
11391139
for k, shape, dtype in target_metadata:
11401140
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
11411141
returned_state_dict[k] = buffer
1142+
ks.append(k)
11421143
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
1143-
reqs = dist.batch_isend_irecv(ops)
1144-
for req, (k, *_) in zip(reqs, target_metadata):
1145-
req.wait()
1146-
returned_state_dict[k] = returned_state_dict[k].to(device)
1144+
reqs = dist.batch_isend_irecv(ops)
1145+
for req in reqs: # len(reqs) may be different from len(ops) because of coalescing
1146+
req.wait()
1147+
for k in ks:
1148+
returned_state_dict[k] = returned_state_dict[k].to(device)
11471149
return returned_state_dict
11481150
else:
11491151
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)

0 commit comments

Comments
 (0)