@@ -322,15 +322,27 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                else:
                    position_ids = None

-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                    position_ids, request.get("max_tokens", 2048)
                )

+            if request.get("enable_thinking", False):
+                # Enable thinking
+                req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                req_max_tokens = request.get("max_tokens")
+                final_reasoning_tokens = (
+                    req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                )
+
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
            if isinstance(request.prompt_token_ids, np.ndarray):
                prompt_token_ids = request.prompt_token_ids.tolist()
            else:
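Note on behavior: the removed lines defaulted `enable_thinking` to `True` and assigned the whole buffer (`[:]`), so one request could flip thinking for every slot; the replacement defaults to `False` and touches only `[idx : idx + 1]`. A minimal sketch of the new budget resolution, with a plain dict standing in for `Request` (the helper name is illustrative, not part of the patch):

```python
def resolve_thinking(request: dict):
    """Mirror of the inserted branch: returns (enable_thinking, need_think_end, reasoning_index)."""
    if request.get("enable_thinking", False):
        reasoning_max_tokens = request.get("reasoning_max_tokens")
        # Fall back to the request's max_tokens when no explicit reasoning budget is given.
        budget = reasoning_max_tokens if reasoning_max_tokens is not None else request.get("max_tokens")
        return True, 1, budget
    return False, 0, 0

assert resolve_thinking({"enable_thinking": True, "max_tokens": 1024}) == (True, 1, 1024)
assert resolve_thinking({"enable_thinking": True, "reasoning_max_tokens": 256, "max_tokens": 1024}) == (True, 1, 256)
assert resolve_thinking({}) == (False, 0, 0)
```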
@@ -549,16 +561,28 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
            self.share_inputs["prompt_lens"][idx : idx + 1] = length

            if self.enable_mm:
-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                    position_ids, request.get("max_tokens", 2048)
                )
            self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

+            if request.get("enable_thinking", False):
+                # Enable thinking
+                req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                req_max_tokens = request.get("max_tokens")
+                final_reasoning_tokens = (
+                    req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                )
+
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
        def get_attr_from_request(request, attr, default_value=None):
            res = request.get(attr, default_value)
            if res is not None:
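The same branch replaces the multimodal-only block in `insert_prefill_inputs`. Because it now sits outside `if self.enable_mm:`, text-only models get per-request thinking state too, and the reasoning budget falls back to the request's `max_tokens` rather than a hard-coded 2048.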
@@ -853,6 +877,11 @@ def _init_share_inputs(self, max_num_seqs: int):
853
877
# Initialize rotary position embedding
854
878
tmp_position_ids = paddle .arange (self .parallel_config .max_model_len ).reshape ((1 , - 1 ))
855
879
880
+ # Initialize thinking related buffers
881
+ self .share_inputs ["need_think_end" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
882
+ self .share_inputs ["enable_thinking" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = False , dtype = "bool" )
883
+ self .share_inputs ["reasoning_index" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
884
+
856
885
# TODO(gongshaotian): move to models
857
886
if not self .enable_mm :
858
887
self .share_inputs ["rope_emb" ] = get_rope (
@@ -952,11 +981,6 @@ def _init_share_inputs(self, max_num_seqs: int):
                dtype="float32",
            )
            self.share_inputs["image_features"] = None
-            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-            self.share_inputs["enable_thinking"] = paddle.full(
-                shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
-            )
-            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

    def _prepare_inputs(self) -> None:
        """Prepare the model inputs"""
@@ -1399,10 +1423,10 @@ def _dummy_run(
            ),
            accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"],
+            reasoning_index=self.share_inputs["reasoning_index"],
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
        )
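`think_end_id` is now read straight from `self.model_config` rather than through `getattr(..., -1)`, so the attribute is expected to exist on every config; models without a think-end token presumably need it defaulted to -1 when the config is built.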
@@ -1715,10 +1739,10 @@ class at the server level, which is too granular for ModelRunner.
            ),
            accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
+            reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
            stop_token_ids=self.share_inputs["stop_seqs"],
            stop_seqs_len=self.share_inputs["stop_seqs_len"],
        )
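Same change at the real decode path, with one difference from `_dummy_run`: `need_think_end` and `reasoning_index` are sliced to the active batch via `[:num_running_requests]`, while `enable_thinking` is passed whole.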