Commit 4515ad2

K11OntheBoat authored
Support limit thinking lengths (#4069)
Co-authored-by: K11OntheBoat <“[email protected]”>
1 parent 0c6f193 commit 4515ad2

9 files changed: +194 −28 lines changed

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -224,6 +224,7 @@ def __init__(
         self.vision_config = PretrainedConfig.from_dict(self.vision_config)

         self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
+        self.think_end_id = args.get("think_end_id", -1)

         architectures = self.architectures[0]

fastdeploy/engine/engine.py

Lines changed: 10 additions & 0 deletions
@@ -34,6 +34,7 @@
 import paddle
 from tqdm import tqdm

+from fastdeploy.config import ErnieArchitectures
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.common_engine import EngineService
 from fastdeploy.engine.expert_service import start_data_parallel_service
@@ -470,6 +471,14 @@ def _start_worker_service(self):
             else len(self.data_processor.tokenizer.vocab)
         )

+        is_ernie = ErnieArchitectures.contains_ernie_arch(self.cfg.model_config.architectures)
+        if is_ernie:
+            self.cfg.model_config.think_end_id = self.data_processor.tokenizer.get_vocab().get("</think>", -1)
+            if self.cfg.model_config.think_end_id != -1:
+                llm_logger.info(f"Get think_end_id {self.cfg.model_config.think_end_id} from vocab.")
+            else:
+                llm_logger.info("No </think> token found in vocabulary, the model can not do reasoning.")
+
         ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
         ips = None
         if self.cfg.ips is not None:
@@ -496,6 +505,7 @@ def _start_worker_service(self):
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
             f" --quantization '{json.dumps(self.cfg.model_config.quantization)}'"
             f" --ori_vocab_size {ori_vocab_size}"
+            f" --think_end_id {self.cfg.model_config.think_end_id}"
             f" --speculative_config '{self.cfg.speculative_config.to_json_string()}'"
             f" --graph_optimization_config '{self.cfg.graph_opt_config.to_json_string()}'"
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"

fastdeploy/entrypoints/engine_client.py

Lines changed: 0 additions & 2 deletions
@@ -155,8 +155,6 @@ async def add_requests(self, task):
         task["prompt_token_ids_len"] = len(task["prompt_token_ids"])
         input_ids_len = task["prompt_token_ids_len"]
         task["max_tokens"] = min(self.max_model_len - input_ids_len, task.get("max_tokens"))
-        if task.get("reasoning_max_tokens", None) is None:
-            task["reasoning_max_tokens"] = max(int(task["max_tokens"] * 0.8), 1)
         min_tokens = task.get("min_tokens", 1)
         if "messages" in task:
             del task["messages"]

fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py

Lines changed: 4 additions & 0 deletions
@@ -252,6 +252,10 @@ def process_request_dict(self, request, max_model_len=None):
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
             request["max_tokens"] = max(1, max_model_len - len(request["prompt_token_ids"]))
+        else:
+            request["max_tokens"] = min(max_model_len - len(request["prompt_token_ids"]), request["max_tokens"])
+        if request.get("reasoning_max_tokens") is None:
+            request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
         data_processor_logger.info(f"Processed request {request}")

         return request
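The processor change above takes over the budget defaults that the deleted engine_client.py lines used to apply: max_tokens is capped to the context left after the prompt, and reasoning_max_tokens, when unset, falls back to 80% of the final max_tokens (never below 1). A self-contained sketch of that arithmetic, using a plain dict in place of the request object (helper name is illustrative):

def apply_token_budgets(request: dict, max_model_len: int) -> dict:
    # Cap max_tokens by the remaining context window, then derive the reasoning budget.
    remaining = max_model_len - len(request["prompt_token_ids"])
    if request.get("max_tokens") is None:
        request["max_tokens"] = max(1, remaining)
    else:
        request["max_tokens"] = min(remaining, request["max_tokens"])
    if request.get("reasoning_max_tokens") is None:
        request["reasoning_max_tokens"] = max(int(request["max_tokens"] * 0.8), 1)
    return request

budgets = apply_token_budgets({"prompt_token_ids": list(range(100)), "max_tokens": 50000}, max_model_len=8192)
print(budgets["max_tokens"], budgets["reasoning_max_tokens"])  # 8092 6473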

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 7 additions & 3 deletions
@@ -195,8 +195,9 @@ def post_process_normal(
 ) -> ModelRunnerOutput:
     """Post-processing steps after completing a single token generation."""
     # handle vl:
-    if model_output.enable_thinking:
-        exists_think_end = sampler_output.sampled_token_ids == model_output.think_end_id
+    if model_output.think_end_id != -1:
+        thinking_mask = model_output.enable_thinking
+        exists_think_end = (sampler_output.sampled_token_ids == model_output.think_end_id) & thinking_mask
         paddle.assign(
             paddle.where(
                 exists_think_end,
@@ -206,9 +207,10 @@ def post_process_normal(
             model_output.need_think_end,
         )

+        reasoning_index_update_cond = model_output.need_think_end.cast("bool") & thinking_mask
         paddle.assign(
             paddle.where(
-                model_output.need_think_end.cast("bool"),
+                reasoning_index_update_cond,
                 model_output.reasoning_index - 1,
                 model_output.reasoning_index,
             ),
@@ -219,6 +221,8 @@
             (sampler_output.sampled_token_ids == model_output.eos_token_id.T).any(axis=1, keepdim=True)
             | (model_output.reasoning_index == 0)
         ) & (model_output.need_think_end > 0)
+
+        stop_wo_think = stop_wo_think & thinking_mask
         sampler_output.sampled_token_ids = paddle.where(
             stop_wo_think,
             model_output.think_end_id,
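The post-processing change gates every thinking update on a per-request mask instead of one global flag: a sampled </think> only counts for requests that run with thinking enabled, and the reasoning budget only ticks down for those same requests. A small NumPy sketch of that mask composition (NumPy stands in here for the Paddle tensors above; shapes are one row per request):

import numpy as np

# Three requests; the second has thinking disabled.
sampled_token_ids = np.array([[7], [7], [3]])
enable_thinking = np.array([[True], [False], [True]])
need_think_end = np.array([[1], [1], [1]], dtype=np.int32)
think_end_id = 7

# </think> only counts for requests that actually run in thinking mode.
exists_think_end = (sampled_token_ids == think_end_id) & enable_thinking
# The reasoning budget only ticks down while a request is still in thinking mode.
reasoning_index_update_cond = (need_think_end > 0) & enable_thinking

print(exists_think_end.ravel())             # [True, False, False]
print(reasoning_index_update_cond.ravel())  # [True, False, True]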

fastdeploy/worker/gpu_model_runner.py

Lines changed: 47 additions & 23 deletions
@@ -322,15 +322,27 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             else:
                 position_ids = None

-            enable_thinking = request.get("enable_thinking", True)
-            enable_thinking = enable_thinking if enable_thinking is not None else True
-            self.share_inputs["enable_thinking"][:] = enable_thinking
-            self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-            self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
             self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                 position_ids, request.get("max_tokens", 2048)
             )

+            if request.get("enable_thinking", False):
+                # Enable thinking
+                req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                req_max_tokens = request.get("max_tokens")
+                final_reasoning_tokens = (
+                    req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                )
+
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
             if isinstance(request.prompt_token_ids, np.ndarray):
                 prompt_token_ids = request.prompt_token_ids.tolist()
             else:
@@ -549,16 +561,28 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
             self.share_inputs["prompt_lens"][idx : idx + 1] = length

             if self.enable_mm:
-                enable_thinking = request.get("enable_thinking", True)
-                enable_thinking = enable_thinking if enable_thinking is not None else True
-                self.share_inputs["enable_thinking"][:] = enable_thinking
-                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1 if enable_thinking else 0
-                self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens", 2048)
                 self.share_inputs["rope_emb"][idx : idx + 1, :] = self.prepare_rope3d(
                     position_ids, request.get("max_tokens", 2048)
                 )
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0

+            if request.get("enable_thinking", False):
+                # Enable thinking
+                req_reasoning_max_tokens = request.get("reasoning_max_tokens")
+                req_max_tokens = request.get("max_tokens")
+                final_reasoning_tokens = (
+                    req_reasoning_max_tokens if req_reasoning_max_tokens is not None else req_max_tokens
+                )
+
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = final_reasoning_tokens
+            else:
+                # Disable thinking
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
+                self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
+                self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0
+
     def get_attr_from_request(request, attr, default_value=None):
         res = request.get(attr, default_value)
         if res is not None:
@@ -853,6 +877,11 @@ def _init_share_inputs(self, max_num_seqs: int):
         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

+        # Initialize thinking related buffers
+        self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+        self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=False, dtype="bool")
+        self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
+
         # TODO(gongshaotian): move to models
         if not self.enable_mm:
             self.share_inputs["rope_emb"] = get_rope(
@@ -952,11 +981,6 @@ def _init_share_inputs(self, max_num_seqs: int):
                 dtype="float32",
             )
             self.share_inputs["image_features"] = None
-            self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-            self.share_inputs["enable_thinking"] = paddle.full(
-                shape=[1], fill_value=("ernie" in self.model_config.model_type), dtype="bool"
-            )
-            self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

     def _prepare_inputs(self) -> None:
         """Prepare the model inputs"""
@@ -1399,10 +1423,10 @@ def _dummy_run(
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"],
+            reasoning_index=self.share_inputs["reasoning_index"],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
@@ -1715,10 +1739,10 @@ class at the server level, which is too granular for ModelRunner.
             ),
             accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
             accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            enable_thinking=(self.share_inputs["enable_thinking"] if self.enable_mm else None),
-            think_end_id=(getattr(self.model_config, "think_end_id", -1) if self.enable_mm else -1),
-            need_think_end=(self.share_inputs["need_think_end"][:num_running_requests] if self.enable_mm else None),
-            reasoning_index=(self.share_inputs["reasoning_index"][:num_running_requests] if self.enable_mm else None),
+            enable_thinking=self.share_inputs["enable_thinking"],
+            think_end_id=self.model_config.think_end_id,
+            need_think_end=self.share_inputs["need_think_end"][:num_running_requests],
+            reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
             stop_token_ids=self.share_inputs["stop_seqs"],
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
         )
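In the model runner, each incoming request now fills its own slot of the enable_thinking / need_think_end / reasoning_index buffers: when a request does not set reasoning_max_tokens, the whole max_tokens budget is allowed for thinking, and requests with thinking disabled get zeros. A dict-based sketch of that per-slot decision (illustrative helper, not part of the diff; in FastDeploy the request is a Request object with the same .get() interface):

def thinking_slot_values(request: dict) -> tuple:
    """Return (enable_thinking, need_think_end, reasoning_index) for one request slot."""
    if request.get("enable_thinking", False):
        budget = request.get("reasoning_max_tokens")
        if budget is None:
            # No explicit reasoning budget: allow thinking up to the full max_tokens.
            budget = request.get("max_tokens")
        return True, 1, budget
    return False, 0, 0

print(thinking_slot_values({"enable_thinking": True, "max_tokens": 512}))   # (True, 1, 512)
print(thinking_slot_values({"enable_thinking": False, "max_tokens": 512}))  # (False, 0, 0)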

fastdeploy/worker/worker_process.py

Lines changed: 1 addition & 0 deletions
@@ -587,6 +587,7 @@ def parse_args():
         help="enable expert parallel",
     )
     parser.add_argument("--ori_vocab_size", type=int, default=None)
+    parser.add_argument("--think_end_id", type=int, default=-1)

     parser.add_argument(
         "--quantization",

tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py

Lines changed: 62 additions & 0 deletions
@@ -516,6 +516,21 @@ def test_chat_with_thinking(openai_client, capsys):
     assert response.choices[0].message.reasoning_content is None
     assert "</think>" not in response.choices[0].message.content

+    # test logic
+    reasoning_max_tokens = None
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": reasoning_max_tokens,
+        },
+    )
+    assert response.choices[0].message.reasoning_content is not None
+
     # enable thinking, streaming
     reasoning_max_tokens = 3
     response = openai_client.chat.completions.create(
@@ -927,3 +942,50 @@ def test_profile_reset_block_num():
         f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
         f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
     )
+
+
+def test_thinking_logic_flag(openai_client, capsys):
+    """
+    Test the interaction between token calculation logic and conditional thinking.
+    This test covers:
+    1. Default max_tokens calculation when not provided.
+    2. Capping of max_tokens when it exceeds model limits.
+    3. Default reasoning_max_tokens calculation when not provided.
+    4. Activation of thinking based on the final state of reasoning_max_tokens.
+    """
+
+    response_case_1 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity briefly."}],
+        temperature=1,
+        stream=False,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+        },
+    )
+    assert response_case_1.choices[0].message.reasoning_content is not None
+
+    response_case_2 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": 5,
+        },
+    )
+    assert response_case_2.choices[0].message.reasoning_content is not None
+
+    response_case_3 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": False},
+        },
+    )
+    assert response_case_3.choices[0].message.reasoning_content is None

tests/e2e/test_EB_VL_Lite_serving.py

Lines changed: 62 additions & 0 deletions
@@ -535,6 +535,21 @@ def test_chat_with_thinking(openai_client, capsys):
     assert response.choices[0].message.reasoning_content is None
     assert "</think>" not in response.choices[0].message.content

+    # test logic
+    reasoning_max_tokens = None
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": reasoning_max_tokens,
+        },
+    )
+    assert response.choices[0].message.reasoning_content is not None
+
     # enable thinking, streaming
     reasoning_max_tokens = 3
     response = openai_client.chat.completions.create(
@@ -642,3 +657,50 @@ def test_profile_reset_block_num():
         f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
         f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
     )
+
+
+def test_thinking_logic_flag(openai_client, capsys):
+    """
+    Test the interaction between token calculation logic and conditional thinking.
+    This test covers:
+    1. Default max_tokens calculation when not provided.
+    2. Capping of max_tokens when it exceeds model limits.
+    3. Default reasoning_max_tokens calculation when not provided.
+    4. Activation of thinking based on the final state of reasoning_max_tokens.
+    """
+
+    response_case_1 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity briefly."}],
+        temperature=1,
+        stream=False,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+        },
+    )
+    assert response_case_1.choices[0].message.reasoning_content is not None
+
+    response_case_2 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": True},
+            "reasoning_max_tokens": 5,
+        },
+    )
+    assert response_case_2.choices[0].message.reasoning_content is not None
+
+    response_case_3 = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Explain gravity in a way that a five-year-old child can understand."}],
+        temperature=1,
+        stream=False,
+        max_tokens=20,
+        extra_body={
+            "chat_template_kwargs": {"enable_thinking": False},
+        },
+    )
+    assert response_case_3.choices[0].message.reasoning_content is None
