@@ -57,14 +57,15 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
         self.split_connector = split_connector

         self.speculative_decoding = self.cfg.speculative_config.method is not None
+        self.use_logprobs = self.cfg.enable_logprob

         if self.speculative_decoding:
             self.output_tokens = paddle.full(
                 shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
                 fill_value=2,
                 dtype="int64",
             )
-        elif self.cfg.enable_logprob:
+        elif self.use_logprobs:
             self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
             self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
             self.output_ranks = paddle.full(shape=[MAX_BSZ], fill_value=0, dtype="int64")
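As a reading aid: the logprob buffers above are flat tensors that the batch-output path reshapes later in this diff. Below is a minimal NumPy sketch of that layout, assuming (per the indexing further down) that slot [0, 0] carries the ready sentinel, slot [1, 0] the batch size, and the remaining rows each request's top-(K + 1) candidates; the constant values and variable names here are illustrative, not part of the change.

import numpy as np

K = 20        # illustrative top-K width; the real value is the module-level constant K
MAX_BSZ = 4   # illustrative cap on batch size

# Same flat layout as self.output_tokens / self.output_scores above.
output_tokens = np.full([MAX_BSZ * (K + 1) + 2, 1], 2, dtype=np.int64)
output_scores = np.zeros([MAX_BSZ * (K + 1), 1], dtype=np.float32)

output_tokens[0, 0] = 0   # ready flag; -2 would mean "no new tokens yet"
output_tokens[1, 0] = 3   # number of valid requests in this step

batch = int(output_tokens[1, 0])
topk_ids = output_tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])
topk_logprobs = output_scores[: batch * (K + 1)].reshape([batch, K + 1])
# topk_ids[i, 0] is request i's sampled token; columns 1..K hold the remaining top-K candidates.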
@@ -125,53 +126,12 @@ def run(self):
         assert self.resource_manager is not None, "The resource manager is None, cannot run."
         if self.worker is not None:
             raise Exception("Worker is already running!")
-        use_logprobs = (
-            self.cfg.enable_logprob
-            and not self.speculative_decoding
-            and not self.cfg.parallel_config.enable_expert_parallel
-        )
-
-        target_func = self.process_sampling_with_logprob_results if use_logprobs else self.process_sampling_results

-        self.worker = threading.Thread(target=target_func)
+        self.worker = threading.Thread(target=self.process_sampling_results)

         self.worker.daemon = True
         self.worker.start()

-    def process_sampling_with_logprob_results(self):
-        """
-        read tokens from paddle inference engine and process logprob results
-        """
-        if current_platform.is_cuda():
-            from fastdeploy.model_executor.ops.gpu import get_output_topk
-        else:
-            raise NotImplementedError("Only CUDA platform supports logprob.")
-
-        rank_id = self.cfg.parallel_config.local_data_parallel_id
-
-        while True:
-            try:
-                is_blocking = True
-                get_output_topk(
-                    self.output_tokens,
-                    self.output_scores,
-                    self.output_ranks,
-                    K,
-                    rank_id,
-                    is_blocking,
-                )
-
-                if self.output_tokens[0, 0] == -2:
-                    continue
-                llm_logger.debug(
-                    f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}"
-                    f"rank_id {rank_id} self.output_scores[0, 0] {self.output_scores[0, 0]}"
-                )
-                self._process_prefill_metrics()
-                self._process_sampling_with_logprob_batch_output()
-            except Exception as e:
-                llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
-
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process
@@ -187,6 +147,7 @@ def process_sampling_results(self):
             from fastdeploy.model_executor.ops.gpu import (
                 get_output,
                 get_output_ep,
+                get_output_topk,
                 speculate_get_output,
             )
         rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -207,7 +168,17 @@ def process_sampling_results(self):
                     get_output_ep(self.output_tokens, rank_id, is_blocking)

                 else:
-                    get_output(self.output_tokens, rank_id, is_blocking)
+                    if self.use_logprobs:
+                        get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                    else:
+                        get_output(self.output_tokens, rank_id, is_blocking)

                 if self.output_tokens[0, 0] == -2:
                     continue
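This branch can live inside the existing reader loop because get_output_topk fills the same output_tokens header slots as get_output, so the "-2 means nothing new yet" check above applies to both paths. A rough sketch of the reader step under that assumption (the helper name read_one_step is illustrative, not part of the change):

def read_one_step(self, rank_id: int, is_blocking: bool = True) -> bool:
    """Fill the preallocated buffers for one step; return True if new tokens arrived."""
    if self.use_logprobs:
        # also fills self.output_scores (per-candidate logprobs) and
        # self.output_ranks (rank of the sampled token among the top-K candidates)
        get_output_topk(self.output_tokens, self.output_scores, self.output_ranks, K, rank_id, is_blocking)
    else:
        get_output(self.output_tokens, rank_id, is_blocking)
    return int(self.output_tokens[0, 0]) != -2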
@@ -305,129 +276,23 @@ def _compute_speculative_status(self):
             self.total_step = 0
         self.speculative_stats_step += 1

-    def _process_sampling_with_logprob_batch_output(self):
-        """
-        batch post-processing logprob output function
-        """
-
-        batch = self.output_tokens[1, 0]
-        tokens = self.output_tokens[2 : batch * (K + 1) + 2].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
-        ranks = self.output_ranks[:batch].numpy()
-        batch_result = list()
-        for i in range(batch):
-            if self.resource_manager.stop_flags[i]:
-                continue
-            task = self.resource_manager.tasks_list[i]
-            task_id = task.request_id
-            token_id = int(tokens[i, 0])
-            token_ids = [token_id]
-            recovery_stop = token_id == RECOVERY_STOP_SIGNAL
-            if recovery_stop:
-                llm_logger.info(f"recovery stop signal found at task {task_id}")
-            if not recovery_stop and token_id < 0:
-                continue
-
-            if task.get("prefill_chunk_info", None) is not None:
-                prefill_chunk_num = task.get("prefill_chunk_num", 0)
-                task.prefill_chunk_num = prefill_chunk_num + 1
-
-                if task.prefill_chunk_num < len(task.prefill_chunk_info):
-                    continue
-
-            self.total_step += 1
-            current_time = time.time()
-            if self.tokens_counter[task_id] == 0:
-                metrics = RequestMetrics(
-                    arrival_time=task.arrival_time,
-                    inference_start_time=task.inference_start_time,
-                    first_token_time=time.time() - task.inference_start_time,
-                    time_in_queue=task.schedule_start_time - task.preprocess_end_time,
-                    preprocess_cost_time=task.preprocess_end_time - task.preprocess_start_time,
-                    request_start_time=task.arrival_time,
-                )
-
-                self._record_first_token_metrics(task, current_time)
-
-            else:
-                metrics = RequestMetrics(
-                    arrival_time=time.time(),
-                    request_start_time=task.arrival_time,
-                )
-            self.number_of_output_tokens += len(token_ids)
-            self._record_metrics(task, current_time, token_ids)
-            result = RequestOutput(
-                request_id=task_id,
-                outputs=CompletionOutput(
-                    index=i,
-                    send_idx=self.tokens_counter[task_id],
-                    token_ids=[],
-                    logprob=None,
-                    draft_token_ids=[],
-                    top_logprobs=None,
-                ),
-                finished=False,
-                metrics=metrics,
-            )
-            if self.tokens_counter[task_id] == 0:
-                if task.messages is not None:
-                    result.prompt = task.messages
-                result.num_cached_tokens = task.num_cached_tokens
-
-            is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
-
-            if is_prefill and len(token_ids) > 1:
-                result.outputs.draft_token_ids = copy.deepcopy(token_ids)
-
-            for idx, token_id in enumerate(token_ids):
-                self.tokens_counter[task_id] += 1
-                if token_id != RECOVERY_STOP_SIGNAL:
-                    result.outputs.token_ids.append(token_id)
-                    task.output_token_ids.append(token_id)
-                result.outputs.logprob = float(scores[i, 0])
-                # Construct top_logprobs
-                topk_token_ids = tokens[i, :].tolist()
-                topk_logprobs = scores[i, :].tolist()
-                sampled_rank = ranks[i].item()
-
-                result.outputs.top_logprobs = LogprobsLists(
-                    logprob_token_ids=[topk_token_ids],
-                    logprobs=[topk_logprobs],
-                    sampled_token_ranks=[sampled_rank],
-                )
-
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
-                    result.finished = True
-                    if recovery_stop:
-                        result.error_msg = "Recover is not supported, the result is incomplete!"
-                    llm_logger.info(
-                        f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
-                    )
-                    llm_logger.info(
-                        f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
-                    )
-                    llm_logger.info(f"{self.resource_manager.info()}")
-                    if self.cfg.speculative_config.method:
-                        self._compute_speculative_status()
-                    if not is_prefill:
-                        self._record_completion_metrics(task, current_time)
-                    self._recycle_resources(task_id, i, task, result, is_prefill)
-                    break
-            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
-                batch_result.append(result)
-
-        self.postprocess(batch_result)
-
     def _process_batch_output(self):
         """
         batch post-processing function
         """

         tokens = self.output_tokens.numpy()
+        scores = None
+        ranks = None
         if self.cfg.speculative_config.method:
             batch = self.output_tokens[1]
             accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
+        elif self.use_logprobs:
+            batch = self.output_tokens[1, 0]
+            tokens = tokens[2 : batch * (K + 1) + 2].reshape([batch, K + 1])[:, : (K + 1)]
+            scores = self.output_scores[: batch * (K + 1)].numpy().reshape([batch, K + 1])[:, : (K + 1)]
+            ranks = self.output_ranks[:batch].numpy()
         else:
             batch = self.output_tokens[1, 0]
             tokens = tokens[2 : batch + 2]
@@ -522,6 +387,17 @@ def _process_batch_output(self):
                 if token_id != RECOVERY_STOP_SIGNAL:
                     result.outputs.token_ids.append(token_id)
                     task.output_token_ids.append(token_id)
+                if self.use_logprobs:
+                    result.outputs.logprob = float(scores[i, 0])
+                    # Construct top_logprobs
+                    topk_token_ids = tokens[i, :].tolist()
+                    topk_logprobs = scores[i, :].tolist()
+                    sampled_rank = ranks[i].item()
+                    result.outputs.top_logprobs = LogprobsLists(
+                        logprob_token_ids=[topk_token_ids],
+                        logprobs=[topk_logprobs],
+                        sampled_token_ranks=[sampled_rank],
+                    )
                 if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                     result.finished = True
                     if recovery_stop:
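For context, a small sketch of how the logprob payload assembled above might be read by a consumer of RequestOutput; the field names mirror the CompletionOutput/LogprobsLists construction in this diff, while the helper function itself is illustrative.

def summarize_logprobs(result) -> None:
    # result.outputs.logprob is the sampled token's logprob (scores[i, 0] above);
    # result.outputs.top_logprobs is a LogprobsLists with one row per emitted token.
    top = result.outputs.top_logprobs
    if top is None:
        return
    for ids, lps, rank in zip(top.logprob_token_ids, top.logprobs, top.sampled_token_ranks):
        print(f"sampled rank={rank}, top candidates={ids[:5]}, logprobs={lps[:5]}")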