Commit 97e60cb

[checkpointio] gather tensor before unpad it if the tensor is both padded and distributed (#6168)
Parent commit: 5b094a8

File tree: 1 file changed (+2, -2 lines)

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 2 additions & 2 deletions
@@ -107,9 +107,9 @@ def _model_sharder(
             if param is None:
                 continue
             # Gather tensor pieces when using tensor parallel.
-            if is_padded_tensor(param):
-                param = to_unpadded_tensor(param)
             param_ = gather_distributed_param(param, keep_vars=False)
+            if is_padded_tensor(param_):
+                param_ = to_unpadded_tensor(param_)
             if pinned_state_dicts is not None:
                 if (prefix + name) not in pinned_state_dicts:
                     pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
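
Why the order matters: under hybrid parallelism a parameter can be both padded (so its length divides evenly across tensor-parallel ranks) and sharded, with each rank holding only a slice of the padded tensor. The padding metadata describes the full tensor, so applying to_unpadded_tensor to a local shard presumably slices the wrong region; in the toy model below it is a no-op, and the padding survives the subsequent gather into the checkpoint. The fix gathers the full tensor first and unpads it afterwards. The sketch uses hypothetical gather/unpad helpers standing in for gather_distributed_param and to_unpadded_tensor; it is an illustration of the failure mode, not ColossalAI code.

    # Minimal sketch (hypothetical helpers, not ColossalAI code) showing why
    # a tensor that is both padded and distributed must be gathered before
    # it is unpadded.
    import torch

    def gather(shards):
        # Stand-in for gather_distributed_param: reassemble TP shards.
        return torch.cat(shards, dim=0)

    def unpad(t, origin_length):
        # Stand-in for to_unpadded_tensor: drop trailing padding rows.
        return t[:origin_length]

    # A 10-element parameter padded to 12 so it splits evenly over 2 ranks.
    full = torch.arange(10.0)
    padded = torch.cat([full, torch.zeros(2)])  # length 12, 2 padding rows
    shards = list(padded.chunk(2))              # each rank holds 6 elements

    # Old order: unpad each local shard, then gather. Narrowing a 6-element
    # shard to origin_length=10 is a no-op, so the padding survives.
    old = gather([unpad(s, 10) for s in shards])
    print(old.shape)  # torch.Size([12]) -- 2 bogus rows reach the checkpoint

    # Fixed order: gather the full padded tensor first, then unpad it.
    new = unpad(gather(shards), 10)
    print(new.shape)  # torch.Size([10])
    assert torch.equal(new, full)

Consistent with this, the fixed hunk moves the padding check onto param_, the gathered tensor, so is_padded_tensor inspects the reassembled parameter rather than the local shard.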
