Commit f0f0361

Merge remote-tracking branch 'origin/prefill_overlap'
2 parents 63b9053 + 14d30b9

File tree

3 files changed: +71 -2 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 1 deletion
@@ -437,7 +437,6 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
 
     @torch.no_grad()
     def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch):
-        assert batch.batch_size == batch1.batch_size
         assert batch.mem_indexes.is_cuda
         assert batch1.mem_indexes.is_cuda
         input_ids, input_ids1 = batch.input_ids, batch1.input_ids
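The only change here drops the requirement that the two overlapped prefill micro-batches carry the same number of requests. Each batch keeps its own metadata (input_ids, mem_indexes, and so on), so their sizes are free to differ. A minimal sketch of why this is safe, using a toy embedding in place of the real transformer forward (toy_overlap_prefill and embed are illustrative, not lightllm APIs):

import torch

def toy_overlap_prefill(input_ids0, input_ids1, embed):
    # Fuse both flat token streams into one forward pass, then split the
    # hidden states back out by each micro-batch's own token count.
    fused = torch.cat([input_ids0, input_ids1], dim=0)
    hidden = embed(fused)  # stand-in for the real transformer forward
    return hidden[:input_ids0.numel()], hidden[input_ids0.numel():]

embed = torch.nn.Embedding(1000, 16)
h0, h1 = toy_overlap_prefill(
    torch.randint(0, 1000, (5,)), torch.randint(0, 1000, (3,)), embed
)
assert h0.shape[0] == 5 and h1.shape[0] == 3  # unequal batch sizes are fine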

test/model/model_infer.py

Lines changed: 63 additions & 0 deletions
@@ -8,6 +8,7 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.deepseek2.model import Deepseek2TpPartModel
 from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
+from torch.profiler import profile, record_function, ProfilerActivity
 
 
 def test_model_inference(args, model_class):
@@ -116,6 +117,16 @@ def decode(
     return logits
 
 
+def torch_profile(fn, log_dir=None):
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        record_shapes=False,
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir)
+    ) as prof:
+        fn()
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
 def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     args = get_env_start_args()
     import triton.profiler as proton
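The new helper wraps torch.profiler: it traces CPU and CUDA activity while fn runs, writes a TensorBoard trace through tensorboard_trace_handler, and prints the top ten ops by CUDA time. Note that the trace handler is installed unconditionally, so callers should pass a real log_dir; the None default would fail once the trace is flushed. A self-contained usage sketch, assuming a CUDA-capable machine (my_forward and the log path are made up for illustration):

import torch
from torch.profiler import profile, ProfilerActivity

def torch_profile(fn, log_dir=None):
    # Same shape as the helper added in this commit.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=False,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir)
    ) as prof:
        fn()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

def my_forward():
    # Any workload to trace; stands in for a model forward pass.
    x = torch.randn(1024, 1024)
    (x @ x).sum().item()

torch_profile(my_forward, log_dir="./logs/toy")

The resulting trace can then be inspected with tensorboard --logdir ./logs/toy (viewing traces requires the torch-tb-profiler TensorBoard plugin).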
@@ -244,6 +255,28 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
     if args.profile:
         proton.start(name="forward_prefill", context="python")
 
+    if args.torch_profile:
+        print("Profile Prefill")
+        try:
+            torch_profile(
+                lambda: model_part.forward(
+                    batch_size,
+                    total_token_num,
+                    input_len,
+                    test_data,
+                    mem_indexes,
+                    b_req_idx,
+                    b_start_loc,
+                    b_seq_len,
+                    b_ready_cache_len=b_ready_cache_len,
+                    is_prefill=True,
+                ),
+                log_dir=f"./logs_decode_overlap/forward_prefill_{model_kvargs['rank_id']}",
+            )
+        except Exception as e:
+            print(str(e))
+            raise
+
     logics = model_part.forward(
         batch_size,
         total_token_num,
@@ -291,6 +324,21 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
                 b_seq_len,
                 total_token_num,
             )
+            if i == 0 and args.torch_profile:
+                torch_profile(
+                    lambda: overlap_decode(
+                        model_part,
+                        batch_size,
+                        max_len_in_batch,
+                        torch.from_numpy(predict_ids).cuda().reshape(-1),
+                        mem_indexes,
+                        b_req_idx,
+                        b_start_loc,
+                        b_seq_len,
+                        total_token_num,
+                    ),
+                    log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}",
+                )
         else:
             logits = decode(
                 model_part,
@@ -303,6 +351,21 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
                 b_seq_len,
                 total_token_num,
             )
+            if i == 0 and args.torch_profile:
+                torch_profile(
+                    lambda: decode(
+                        model_part,
+                        batch_size,
+                        max_len_in_batch,
+                        torch.from_numpy(predict_ids).cuda().reshape(-1),
+                        mem_indexes,
+                        b_req_idx,
+                        b_start_loc,
+                        b_seq_len,
+                        total_token_num,
+                    ),
+                    log_dir=f"./logs_decode_overlap/forward_decode_{model_kvargs['rank_id']}",
+                )
 
         prob_out = torch.softmax(logits, dim=-1)
         predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
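In both decode paths the trace is captured only on the first step (i == 0), so steady-state iterations run unprofiled; note also that the profiled lambda re-executes the step, so the trace covers a duplicate forward rather than the one whose logits are consumed. The general idiom, assuming the torch_profile helper above is in scope (run_decode_loop and step_fn are hypothetical stand-ins):

def run_decode_loop(step_fn, steps, enable_profile=False):
    # Profile only the first iteration; later steps keep full speed.
    for i in range(steps):
        if i == 0 and enable_profile:
            torch_profile(step_fn, log_dir="./logs/first_decode_step")
        logits = step_fn()  # the step whose output is actually used
    return logits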

test/model/test_model.py

Lines changed: 8 additions & 1 deletion
@@ -24,7 +24,7 @@
 from lightllm.models.cohere.model import CohereTpPartModel
 from lightllm.models.mixtral.model import MixtralTpPartModel
 from lightllm.models.qwen2.model import Qwen2TpPartModel
-from lightllm.utils.config_utils import get_config_json
+from lightllm.utils.config_utils import get_config_json, get_dtype
 
 
 def get_model(weight_dir):
@@ -71,6 +71,8 @@ def get_model(weight_dir):
 class TestModelInfer(unittest.TestCase):
     def test_model_infer(self):
         args = get_env_start_args()
+        if args.data_type is None:
+            args.data_type = get_dtype(args.model_dir)
         model_dir = args.model_dir
         model_class = get_model(model_dir)
         test_model_inference(args, model_class)
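With this fallback, omitting --data_type no longer leaves the dtype unset: it is inferred from the checkpoint directory via get_dtype. A hedged sketch of the intent (guess_dtype is illustrative, not lightllm's implementation, and assumes a Hugging Face-style config.json with a torch_dtype field):

import json
import os

def guess_dtype(model_dir: str) -> str:
    # Read torch_dtype from the model's config.json, defaulting to float16.
    with open(os.path.join(model_dir, "config.json")) as f:
        return json.load(f).get("torch_dtype", "float16")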
@@ -89,6 +91,11 @@ def test_model_infer(self):
         action="store_true",
         help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
     )
+    parser.add_argument(
+        "--torch_profile",
+        action="store_true",
+        help="Enable the torch profiler to profile the model.",
+    )
     args = parser.parse_args()
     set_env_start_args(args)
     torch.multiprocessing.set_start_method("spawn")
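To exercise the new flag, the test entrypoint can be invoked with --torch_profile alongside its existing arguments, for example python test/model/test_model.py --model_dir /path/to/model --torch_profile (assuming the script's existing --model_dir flag; other required flags are omitted here). Per-rank traces are then written under ./logs_decode_overlap/ and can be opened in TensorBoard.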
