@@ -327,10 +327,15 @@ def merge_large_tensor_parallel(tensor, tp_group, tp_action, dst_rank, is_dst):
     split_tensors = []
     for i in range(num_splits):
         if get_env_device() == "xpu":
-            ret = distributed_allgather(tensor[split_parts[i] : split_parts[i + 1], :], group=tp_group, offload=False)
+            ret = distributed_allgather(
+                tensor[split_parts[i] : split_parts[i + 1], :].contiguous(), group=tp_group, offload=False
+            )
         else:
             ret = distributed_gather(
-                tensor[split_parts[i] : split_parts[i + 1], :], dst=dst_rank, group=tp_group, offload=False
+                tensor[split_parts[i] : split_parts[i + 1], :].contiguous(),
+                dst=dst_rank,
+                group=tp_group,
+                offload=False,
             )
         # Copy to CPUPlace temporarily, may lower speed.
         if ret is not None:
@@ -383,9 +388,9 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                     tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[key], j, is_dst)
                 else:
                     if get_env_device() == "xpu":
-                        ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                        ret = distributed_allgather(tensor.contiguous(), group=tp_group, offload=False)
                     else:
-                        ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                        ret = distributed_gather(tensor.contiguous(), dst=j, group=tp_group, offload=False)
                     action = tp_actions.pop(key)
                     tensor = action(ret) if is_dst else None
             else:
@@ -439,9 +444,9 @@ def merge_tensor_parallel_for_optimizer(state_dict, model_state_dict, tp_actions
                         tensor = merge_large_tensor_parallel(tensor, tp_group, tp_actions[model_key], j, is_dst)
                     else:
                         if get_env_device() == "xpu":
-                            ret = distributed_allgather(tensor, group=tp_group, offload=False)
+                            ret = distributed_allgather(tensor.contiguous(), group=tp_group, offload=False)
                         else:
-                            ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
+                            ret = distributed_gather(tensor.contiguous(), dst=j, group=tp_group, offload=False)
                         action = tp_actions[model_key]
                         tensor = action(ret) if is_dst else None
             else:
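
A note on what the diff is doing, with a standalone sketch below (not part of the commit): each tensor is now routed through .contiguous() before it reaches the gather/all-gather helpers, since a sliced tensor-parallel weight can be a strided, non-contiguous view and the communication kernels expect a packed buffer. The sketch assumes an initialized Paddle distributed environment (paddle.distributed.init_parallel_env()) and uses the public paddle.distributed.all_gather collective in place of the in-repo distributed_allgather helper; the function allgather_row_block is hypothetical and only illustrates the pattern.

import paddle
import paddle.distributed as dist


def allgather_row_block(tensor, start, end, group=None):
    # Hypothetical illustration only: gather one row block of a
    # tensor-parallel weight from every rank and stitch the pieces together.
    # Slicing may hand back a strided view; .contiguous() materializes a
    # dense copy so the collective receives packed memory.
    block = tensor[start:end, :].contiguous()
    gathered = []
    dist.all_gather(gathered, block, group=group)  # one tensor per rank
    return paddle.concat(gathered, axis=0)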