3 files changed: +7 −37 lines changed

First changed file (its path is not preserved in this capture; from the import in the third file below, it is plausibly fastdeploy/distributed/communication.py):
@@ -20,8 +20,6 @@
 import paddle.distributed as dist
 from paddle.distributed import fleet
 
-from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
-
 _TP_AR = None
 
 
@@ -39,10 +37,9 @@ def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     hcg = fleet.get_hybrid_communicate_group()
     model_parallel_group = hcg.get_model_parallel_group()
     global _TP_AR
-    if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
-        from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
+    from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
 
-        _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
+    _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
 
 
     try:
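Taken together, the two hunks above hoist the tensor-parallel-size / CUDA guard out of the library function. For reference, a minimal sketch of what use_custom_allreduce looks like after this change, reconstructed from the hunks alone; module paths are as shown in the diff, and everything outside the shown lines (including the trailing try block) is elided here:

# Sketch of the first file after this PR, reconstructed from the diff above.
# Only lines visible in the hunks are included; the rest of the file is elided.
from paddle.distributed import fleet

_TP_AR = None


def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
    # The TP-size / CUDA guard was removed here; callers are now expected
    # to check those conditions themselves (see the worker change below).
    hcg = fleet.get_hybrid_communicate_group()
    model_parallel_group = hcg.get_model_parallel_group()
    global _TP_AR
    from fastdeploy.distributed.custom_all_reduce import CustomAllreduce

    _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)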
Second changed file: deleted. (Its path is not preserved in this capture; given the removed import of get_tensor_model_parallel_world_size above, it is plausibly fastdeploy/distributed/parallel_state.py.)
Third changed file (a worker's init_device; its path is not preserved in this capture):

@@ -68,7 +68,11 @@ def init_device(self):
 
         gc.collect()
         paddle.device.cuda.empty_cache()
-        if self.parallel_config.enable_custom_all_reduce:
+        if (
+            self.parallel_config.enable_custom_all_reduce
+            and self.parallel_config.tensor_parallel_size > 1
+            and paddle.is_compiled_with_cuda()
+        ):
             from fastdeploy.distributed.communication import use_custom_allreduce
 
             use_custom_allreduce()
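The caller-side counterpart, sketched below: the worker now owns the guard, so custom all-reduce is only wired up when tensor parallelism is actually in use on a CUDA build. The class name and constructor are hypothetical scaffolding for illustration; only the init_device body comes from the diff.

# Caller-side sketch. GpuWorker and its constructor are hypothetical;
# the init_device body mirrors the hunk above.
import gc

import paddle


class GpuWorker:
    def __init__(self, parallel_config):
        # parallel_config is assumed to expose the two fields checked below.
        self.parallel_config = parallel_config

    def init_device(self):
        gc.collect()
        paddle.device.cuda.empty_cache()
        # Guard moved here from use_custom_allreduce: skip custom all-reduce
        # entirely for single-GPU runs and non-CUDA builds.
        if (
            self.parallel_config.enable_custom_all_reduce
            and self.parallel_config.tensor_parallel_size > 1
            and paddle.is_compiled_with_cuda()
        ):
            from fastdeploy.distributed.communication import use_custom_allreduce

            use_custom_allreduce()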