Skip to content

Commit 794c6b1

Browse files
Authored merge commit 794c6b1: "Merge branch 'main' into dist-ckp" (2 parents: c5b0882 + 97e60cb).

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,11 @@ def _model_sharder(
118118
for name, param in model.named_parameters():
119119
if param is None:
120120
continue
121-
122-
if gather_dtensor:
123-
# Gather tensor pieces when using tensor parallel.
124-
if is_padded_tensor(param):
125-
param = to_unpadded_tensor(param)
126-
param_ = gather_distributed_param(param, keep_vars=False)
127-
else:
128-
param_ = param
129-
121+
122+
# Gather tensor pieces when using tensor parallel.
123+
param_ = gather_distributed_param(param, keep_vars=False)
124+
if is_padded_tensor(param_):
125+
param_ = to_unpadded_tensor(param_)
130126
if pinned_state_dicts is not None:
131127
if (prefix + name) not in pinned_state_dicts:
132128
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")

0 commit comments

Comments (0)