
Commit b622fe3

complete test module (#634)
1 parent 8d936d3 commit b622fe3


2 files changed (+11, -54 lines)


test/model/model_infer.py

Lines changed: 6 additions & 27 deletions
@@ -38,38 +38,17 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_

 def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     import torch
-    from lightllm.distributed import (
-        get_tp_group,
-        init_distributed_environment,
-        initialize_model_parallel,
-        get_tensor_model_parallel_world_size,
-        get_tensor_model_parallel_rank,
-        all_reduce,
-    )
+    from lightllm.distributed import set_custom_reduce
     import torch.distributed as dist

     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]

     torch.cuda.set_device(rank_id)
-    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in [
-        "ON",
-        "TRUE",
-        "1",
-    ]
-    if LIGHTLLM_PYNCCL_ENABLE:
-        init_distributed_environment(
-            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-        )
-        initialize_model_parallel(tensor_model_parallel_size=world_size)
-        tp_group = get_tp_group()
-        dist.all_reduce = all_reduce
-        dist.get_rank = get_tensor_model_parallel_rank
-        dist.get_world_size = get_tensor_model_parallel_world_size
-        tp_group.barrier()
-    else:
-        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
-        dist.barrier()
+
+    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+    set_custom_reduce()
+    dist.barrier()

     torch.cuda.empty_cache()

@@ -137,7 +116,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
     b_start_loc = None
     b_seq_len = None

-    tp_group.barrier()
+    dist.barrier()
     import time

     torch.cuda.synchronize()
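
Both test files now share the same simplified per-rank setup. A minimal sketch of that flow is below, assuming lightllm is installed; the helper name init_tp_worker is invented for illustration, while dist.init_process_group, set_custom_reduce, dist.barrier, and the TCP address are taken from the diff.

import torch
import torch.distributed as dist
from lightllm.distributed import set_custom_reduce  # import introduced by this commit


def init_tp_worker(rank_id: int, world_size: int) -> None:
    # Hypothetical helper; the tests inline this logic inside tppart_model_infer.
    # Bind this rank to its GPU before creating the process group.
    torch.cuda.set_device(rank_id)

    # Plain NCCL process group; the address mirrors the one hard-coded in the tests.
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size
    )

    # Hook in lightllm's custom reduce path (this replaces the old
    # LIGHTLLM_PYNCCL_ENABLE branch), then synchronize all ranks.
    set_custom_reduce()
    dist.barrier()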

test/model/test_settings/model_infer_batchs.py

Lines changed: 5 additions & 27 deletions
@@ -64,38 +64,16 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output_
         return

     import torch
-    from lightllm.distributed import (
-        get_tp_group,
-        init_distributed_environment,
-        initialize_model_parallel,
-        get_tensor_model_parallel_world_size,
-        get_tensor_model_parallel_rank,
-        all_reduce,
-    )
+    from lightllm.distributed import set_custom_reduce
     import torch.distributed as dist

     rank_id = model_kvargs["tp_rank"]
     world_size = model_kvargs["world_size"]

     torch.cuda.set_device(rank_id)
-    LIGHTLLM_PYNCCL_ENABLE = os.getenv("LIGHTLLM_PYNCCL_ENABLE", "False").upper() in [
-        "ON",
-        "TRUE",
-        "1",
-    ]
-    if LIGHTLLM_PYNCCL_ENABLE:
-        init_distributed_environment(
-            backend="nccl", world_size=world_size, rank=rank_id, distributed_init_method="tcp://127.0.0.1:28765"
-        )
-        initialize_model_parallel(tensor_model_parallel_size=world_size)
-        tp_group = get_tp_group()
-        dist.all_reduce = all_reduce
-        dist.get_rank = get_tensor_model_parallel_rank
-        dist.get_world_size = get_tensor_model_parallel_world_size
-        tp_group.barrier()
-    else:
-        dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
-        dist.barrier()
+    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
+    set_custom_reduce()
+    dist.barrier()

     torch.cuda.empty_cache()

@@ -154,7 +132,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output_
         b_start_loc = None
         b_seq_len = None

-        tp_group.barrier()
+        dist.barrier()
         if rank_id == 0:
             new_log_path = log_path.replace("batch_size", str(batch_size))
             fp_file = open(new_log_path, "w+")
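
With the process group created this way, there is no separate tp_group handle to keep around, which is why both files also swap tp_group.barrier() for dist.barrier(). A hypothetical per-rank smoke test of that collective path, assuming the setup sketched after the first file has already run in the process (check_collectives is an invented name):

import torch
import torch.distributed as dist


def check_collectives(rank_id: int, world_size: int) -> None:
    # Assumes init_process_group / set_custom_reduce from the diffs already ran.
    x = torch.ones(8, device=f"cuda:{rank_id}")
    dist.all_reduce(x)  # sums the tensor across all ranks
    assert torch.allclose(x, torch.full_like(x, float(world_size)))
    dist.barrier()  # same call the tests now use in place of tp_group.barrier()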

0 commit comments
