@@ -38,38 +38,17 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_


 def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     import torch
-    from lightllm.distributed import (
-        get_tp_group,
-        init_distributed_environment,
-        initialize_model_parallel,
-        get_tensor_model_parallel_world_size,
-        get_tensor_model_parallel_rank,
-        all_reduce,
-    )
+    from lightllm.distributed import set_custom_reduce
     import torch.distributed as dist

     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]

     torch.cuda.set_device(rank_id)
-    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in [
-        "ON",
-        "TRUE",
-        "1",
-    ]
-    if LIGHTLLM_PYNCCL_ENABLE:
-        init_distributed_environment(
-            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-        )
-        initialize_model_parallel(tensor_model_parallel_size=world_size)
-        tp_group = get_tp_group()
-        dist.all_reduce = all_reduce
-        dist.get_rank = get_tensor_model_parallel_rank
-        dist.get_world_size = get_tensor_model_parallel_world_size
-        tp_group.barrier()
-    else:
-        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
-        dist.barrier()
+
+    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+    set_custom_reduce()
+    dist.barrier()

     torch.cuda.empty_cache()

@@ -137,7 +116,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     b_start_loc = None
     b_seq_len = None

-    tp_group.barrier()
+    dist.barrier()
     import time

     torch.cuda.synchronize()
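With this change the test no longer branches on `LIGHTLLM_PYNCCL_ENABLE`: every rank initializes plain `torch.distributed` and then calls `set_custom_reduce()` to install lightllm's custom reduce path before the usual `dist.barrier()`. Below is a minimal standalone sketch of that init flow, assuming at least two visible CUDA devices; the two-process spawn scaffolding and the all-reduce sanity check are illustrative assumptions, not part of this PR.

```python
# Sketch of the simplified init path adopted above (assumes 2 CUDA devices).
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from lightllm.distributed import set_custom_reduce


def worker(rank_id, world_size):
    torch.cuda.set_device(rank_id)
    # Plain torch.distributed init, mirroring the test's new code path.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size
    )
    set_custom_reduce()  # from the diff; installs lightllm's custom reduce
    dist.barrier()

    # Illustrative sanity check (not in the PR): all-reduce across ranks.
    x = torch.ones(4, device="cuda") * (rank_id + 1)
    dist.all_reduce(x)
    print(f"rank {rank_id}: {x.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```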