Skip to content

Commit f5f54fd

Browse files
fix mtp static bench (#1009)
Fix the MTP parameters passed by the static benchmark script model_infer.py.
1 parent db83cce commit f5f54fd

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

test/benchmark/static_inference/model_infer.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def overlap_prefill(
6666
input_ids,
6767
mem_indexes,
6868
b_req_idx,
69+
b_mtp_index,
6970
b_seq_len,
7071
total_token_num,
7172
b_ready_cache_len,
@@ -76,16 +77,18 @@ def overlap_prefill(
7677
_0_input_ids = input_ids[: total_token_num // 2]
7778
_0_mem_indexes = mem_indexes[: total_token_num // 2]
7879
_0_b_req_idx = b_req_idx[: batch_size // 2]
80+
_0_b_mtp_index = b_mtp_index[: batch_size // 2]
7981
_0_b_seq_len = b_seq_len[: batch_size // 2]
8082
_o_b_ready_cache_len = b_ready_cache_len[: batch_size // 2]
8183
micro_batch1 = ModelInput(
8284
_0_batch_size,
8385
_0_total_token_num,
8486
_0_max_len_in_batch,
8587
_0_input_ids,
86-
_0_mem_indexes,
8788
_0_b_req_idx,
89+
_0_b_mtp_index,
8890
_0_b_seq_len,
91+
_0_mem_indexes,
8992
True,
9093
_o_b_ready_cache_len,
9194
{},
@@ -97,6 +100,7 @@ def overlap_prefill(
97100
_1_input_ids = input_ids[total_token_num // 2 :]
98101
_1_mem_indexes = mem_indexes[total_token_num // 2 :]
99102
_1_b_req_idx = b_req_idx[batch_size // 2 :]
103+
_1_b_mtp_index = b_mtp_index[batch_size // 2 :]
100104
_1_b_seq_len = b_seq_len[batch_size // 2 :]
101105
_1_b_ready_cache_len = b_ready_cache_len[batch_size // 2 :]
102106

@@ -105,9 +109,10 @@ def overlap_prefill(
105109
_1_total_token_num,
106110
_1_max_len_in_batch,
107111
_1_input_ids,
108-
_1_mem_indexes,
109112
_1_b_req_idx,
113+
_1_b_mtp_index,
110114
_1_b_seq_len,
115+
_1_mem_indexes,
111116
True,
112117
_1_b_ready_cache_len,
113118
{},
@@ -120,23 +125,25 @@ def overlap_prefill(
120125

121126

122127
def overlap_decode(
123-
model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_seq_len, total_token_num
128+
model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_mtp_index, b_seq_len, total_token_num
124129
):
125130
_0_batch_size = batch_size // 2
126131
_0_total_token_num = total_token_num // 2
127132
_0_max_len_in_batch = max_len_in_batch
128133
_0_input_ids = input_ids[: batch_size // 2]
129134
_0_mem_indexes = mem_indexes[: batch_size // 2]
130135
_0_b_req_idx = b_req_idx[: batch_size // 2]
136+
_0_b_mtp_index = b_mtp_index[: batch_size // 2]
131137
_0_b_seq_len = b_seq_len[: batch_size // 2]
132138
micro_batch1 = ModelInput(
133139
_0_batch_size,
134140
_0_total_token_num,
135141
_0_max_len_in_batch,
136142
_0_input_ids,
137-
_0_mem_indexes,
138143
_0_b_req_idx,
144+
_0_b_mtp_index,
139145
_0_b_seq_len,
146+
_0_mem_indexes,
140147
)
141148

142149
_1_batch_size = batch_size - batch_size // 2
@@ -145,16 +152,18 @@ def overlap_decode(
145152
_1_input_ids = input_ids[batch_size // 2 :]
146153
_1_mem_indexes = mem_indexes[batch_size // 2 :]
147154
_1_b_req_idx = b_req_idx[batch_size // 2 :]
155+
_1_b_mtp_index = b_mtp_index[batch_size // 2 :]
148156
_1_b_seq_len = b_seq_len[batch_size // 2 :]
149157

150158
micro_batch2 = ModelInput(
151159
_1_batch_size,
152160
_1_total_token_num,
153161
_1_max_len_in_batch,
154162
_1_input_ids,
155-
_1_mem_indexes,
156163
_1_b_req_idx,
164+
_1_b_mtp_index,
157165
_1_b_seq_len,
166+
_1_mem_indexes,
158167
)
159168

160169
output, output1 = model_part.microbatch_overlap_decode(micro_batch1, micro_batch2)
@@ -170,6 +179,7 @@ def prefill(
170179
input_ids,
171180
mem_indexes,
172181
b_req_idx,
182+
b_mtp_index,
173183
b_seq_len,
174184
total_token_num,
175185
b_ready_cache_len,
@@ -179,25 +189,30 @@ def prefill(
179189
total_token_num,
180190
max_len_in_batch,
181191
input_ids,
182-
mem_indexes,
183192
b_req_idx,
193+
b_mtp_index,
184194
b_seq_len,
195+
mem_indexes,
185196
is_prefill=True,
186197
b_ready_cache_len=b_ready_cache_len,
187198
)
199+
188200
model_output = model_part.forward(model_input)
189201
return model_output.logits
190202

191203

192-
def decode(model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_seq_len, total_token_num):
204+
def decode(
205+
model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_mtp_index, b_seq_len, total_token_num
206+
):
193207
model_input = ModelInput(
194208
batch_size,
195209
total_token_num,
196210
max_len_in_batch,
197211
input_ids,
198-
mem_indexes,
199212
b_req_idx,
213+
b_mtp_index,
200214
b_seq_len,
215+
mem_indexes,
201216
is_prefill=False,
202217
)
203218
model_output = model_part.forward(model_input)
@@ -236,6 +251,7 @@ def run_forward_once(
236251
b_req_idx = torch.tensor(
237252
[model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
238253
)
254+
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
239255
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
240256
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
241257
for i in range(batch_size):
@@ -260,6 +276,7 @@ def run_forward_once(
260276
test_data,
261277
mem_indexes,
262278
b_req_idx,
279+
b_mtp_index,
263280
b_seq_len,
264281
total_token_num,
265282
b_ready_cache_len, # b_ready_cache_len
@@ -288,6 +305,7 @@ def run_forward_once(
288305
test_data,
289306
mem_indexes,
290307
b_req_idx,
308+
b_mtp_index,
291309
b_seq_len,
292310
total_token_num,
293311
b_ready_cache_len, # b_ready_cache_len
@@ -302,6 +320,7 @@ def run_forward_once(
302320
torch.cuda.synchronize()
303321
step_start = time.time()
304322
total_token_num += batch_size
323+
b_mtp_index += 1
305324
b_seq_len += 1
306325
mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
307326
max_len_in_batch = input_len + i + 1
@@ -312,6 +331,7 @@ def run_forward_once(
312331
predict_ids.view(-1),
313332
mem_indexes,
314333
b_req_idx,
334+
b_mtp_index,
315335
b_seq_len,
316336
total_token_num,
317337
)
@@ -325,6 +345,7 @@ def run_forward_once(
325345
predict_ids.view(-1),
326346
mem_indexes,
327347
b_req_idx,
348+
b_mtp_index,
328349
b_seq_len,
329350
total_token_num,
330351
),

0 commit comments

Comments (0)