Commit dc9ac0e

Fix the issue of loading ckpt when retraining (#2645)
1 parent b82939d commit dc9ac0e

File tree

  • paddleformers/trainer/unified_checkpoint

1 file changed: +11 -6 lines changed

paddleformers/trainer/unified_checkpoint/utils.py

Lines changed: 11 additions & 6 deletions
@@ -327,10 +327,15 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst):
     split_tensors = []
     for i in range(num_splits):
         if get_env_device() == "xpu":
-            ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False)
+            ret = distributed_allgather(
+                tensor[split_parts[i] : split_parts[i + 1], :].contiguous(), group=tp_group, offload=False
+            )
         else:
             ret = distributed_gather(
-                tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False
+                tensor[split_parts[i] : split_parts[i + 1], :].contiguous(),
+                dst=dst_rank,
+                group=tp_group,
+                offload=False,
             )
         # Copy to CPUPlace temporarily, may lower speed.
         if ret is not None:
@@ -383,9 +388,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                     tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst)
                 else:
                     if get_env_device() == "xpu":
-                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                        ret = distributed_allgather(tensor.contiguous(), group=tp_group, offload=False)
                     else:
-                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                        ret = distributed_gather(tensor.contiguous(), dst=j, group=tp_group, offload=False)
                     action = tp_actions.pop(key)
                     tensor = action(ret) if is_dst else None
             else:
@@ -439,9 +444,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, model_state_dict, tp_actions
                     tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst)
                 else:
                     if get_env_device() == "xpu":
-                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                        ret = distributed_allgather(tensor.contiguous(), group=tp_group, offload=False)
                     else:
-                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                        ret = distributed_gather(tensor.contiguous(), dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
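The change passes .contiguous() copies of the (possibly strided) tensor slices into distributed_allgather / distributed_gather, so the communication kernels always serialize a densely packed buffer when checkpoint shards are merged for retraining. Below is a minimal sketch of the idea; it is not part of the commit, and it assumes a Paddle build where slicing can return a strided view and Tensor.contiguous() / Tensor.is_contiguous() are available (behavior depends on the Paddle version and whether the stride mechanism is enabled).

import paddle

# Hypothetical illustration, not code from this repository.
full = paddle.arange(12, dtype="float32").reshape([3, 4])

# A column slice may be a strided view over the original buffer, so the data
# is not necessarily laid out as one dense block in memory.
view = full[:, 0:2]
print(view.is_contiguous())    # can be False when the stride mechanism is on

# .contiguous() materializes a densely packed copy, which is what collective
# ops such as distributed_gather / distributed_allgather expect to read.
packed = view.contiguous()
print(packed.is_contiguous())  # True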
