Commit 108bcc7

fix reasoning tokens
1 parent c133b7b commit 108bcc7

File tree

4 files changed: +105 additions, -12 deletions

optillm.py

Lines changed: 46 additions & 0 deletions

@@ -93,6 +93,41 @@ def get_config():
     default_client = LiteLLMWrapper()
     return default_client, API_KEY
 
+def count_reasoning_tokens(text: str, tokenizer=None) -> int:
+    """
+    Count tokens within <think>...</think> tags in the given text.
+
+    Args:
+        text: The text to analyze
+        tokenizer: Optional tokenizer instance for precise counting
+
+    Returns:
+        Number of reasoning tokens (0 if no think tags found)
+    """
+    if not text or not isinstance(text, str):
+        return 0
+
+    # Extract all content within <think>...</think> tags
+    think_pattern = r'<think>(.*?)</think>'
+    matches = re.findall(think_pattern, text, re.DOTALL)
+
+    if not matches:
+        return 0
+
+    # Combine all thinking content
+    thinking_content = ''.join(matches)
+
+    if tokenizer and hasattr(tokenizer, 'encode'):
+        # Use tokenizer for precise counting
+        try:
+            tokens = tokenizer.encode(thinking_content)
+            return len(tokens)
+        except Exception as e:
+            logger.warning(f"Failed to count tokens with tokenizer: {e}")
+
+    # Fallback: rough estimation (4 chars per token on average)
+    return max(0, len(thinking_content.strip()) // 4)
+
 # Server configuration
 server_config = {
     'approach': 'none',

@@ -678,11 +713,22 @@ def proxy():
     if stream:
         return Response(generate_streaming_response(response, model), content_type='text/event-stream')
     else:
+        # Calculate reasoning tokens from the response
+        reasoning_tokens = 0
+        if isinstance(response, str):
+            reasoning_tokens = count_reasoning_tokens(response)
+        elif isinstance(response, list) and response:
+            # For multiple responses, sum up reasoning tokens from all
+            reasoning_tokens = sum(count_reasoning_tokens(resp) for resp in response if isinstance(resp, str))
+
         response_data = {
             'model': model,
             'choices': [],
             'usage': {
                 'completion_tokens': completion_tokens,
+                'completion_tokens_details': {
+                    'reasoning_tokens': reasoning_tokens
+                }
             }
         }
 
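
The net effect on the non-streaming proxy() path is that the JSON usage block now carries a nested reasoning-token count. A minimal sketch of the resulting shape (the numeric values are illustrative, and any other usage fields the server already emits are omitted here):

usage = {
    "completion_tokens": 512,          # all generated tokens
    "completion_tokens_details": {
        "reasoning_tokens": 387        # tokens found inside <think>...</think>
    }
}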

optillm/inference.py

Lines changed: 51 additions & 5 deletions

@@ -18,6 +18,7 @@
 import traceback
 import platform
 import sys
+import re
 
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode

@@ -29,6 +30,41 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+def count_reasoning_tokens(text: str, tokenizer=None) -> int:
+    """
+    Count tokens within <think>...</think> tags in the given text.
+
+    Args:
+        text: The text to analyze
+        tokenizer: Optional tokenizer instance for precise counting
+
+    Returns:
+        Number of reasoning tokens (0 if no think tags found)
+    """
+    if not text or not isinstance(text, str):
+        return 0
+
+    # Extract all content within <think>...</think> tags
+    think_pattern = r'<think>(.*?)</think>'
+    matches = re.findall(think_pattern, text, re.DOTALL)
+
+    if not matches:
+        return 0
+
+    # Combine all thinking content
+    thinking_content = ''.join(matches)
+
+    if tokenizer and hasattr(tokenizer, 'encode'):
+        # Use tokenizer for precise counting
+        try:
+            tokens = tokenizer.encode(thinking_content)
+            return len(tokens)
+        except Exception as e:
+            logger.warning(f"Failed to count tokens with tokenizer: {e}")
+
+    # Fallback: rough estimation (4 chars per token on average)
+    return max(0, len(thinking_content.strip()) // 4)
+
 # MLX Support for Apple Silicon
 try:
     import mlx.core as mx

@@ -1502,10 +1538,11 @@ def __init__(
         self.message.logprobs = logprobs
 
 class ChatCompletionUsage:
-    def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int):
+    def __init__(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, reasoning_tokens: int = 0):
         self.prompt_tokens = prompt_tokens
         self.completion_tokens = completion_tokens
         self.total_tokens = total_tokens
+        self.reasoning_tokens = reasoning_tokens
 
 class ChatCompletion:
     def __init__(self, response_dict: Dict):

@@ -1547,7 +1584,10 @@ def model_dump(self) -> Dict:
             "usage": {
                 "prompt_tokens": self.usage.prompt_tokens,
                 "completion_tokens": self.usage.completion_tokens,
-                "total_tokens": self.usage.total_tokens
+                "total_tokens": self.usage.total_tokens,
+                "completion_tokens_details": {
+                    "reasoning_tokens": getattr(self.usage, 'reasoning_tokens', 0)
+                }
             }
         }

@@ -1766,15 +1806,15 @@ def create(
             logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}")
 
-            result = thinkdeeper_decode_mlx(
+            result, reasoning_tokens = thinkdeeper_decode_mlx(
                 pipeline.model,
                 pipeline.tokenizer,
                 messages,
                 thinkdeeper_config_with_tokens
             )
         else:
             logger.info("Using PyTorch ThinkDeeper implementation")
-            result = thinkdeeper_decode(
+            result, reasoning_tokens = thinkdeeper_decode(
                 pipeline.current_model,
                 pipeline.tokenizer,
                 messages,

@@ -1850,6 +1890,11 @@ def create(
         prompt_tokens = len(pipeline.tokenizer.encode(prompt))
         completion_tokens = sum(token_counts)
 
+        # Calculate reasoning tokens from all responses
+        total_reasoning_tokens = 0
+        for response in responses:
+            total_reasoning_tokens += count_reasoning_tokens(response, pipeline.tokenizer)
+
         # Create OpenAI-compatible response format
         response_dict = {
             "id": f"chatcmpl-{int(time.time()*1000)}",

@@ -1871,7 +1916,8 @@ def create(
             "usage": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
-                "total_tokens": completion_tokens + prompt_tokens
+                "total_tokens": completion_tokens + prompt_tokens,
+                "reasoning_tokens": total_reasoning_tokens
             }
         }
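
A minimal usage sketch of the count_reasoning_tokens helper added to this module. The import path, sample text, and printed values are illustrative assumptions rather than repo code, and the exact-count branch assumes the optional transformers package is installed:

# Hedged sketch: import path, sample text, and expected values are illustrative.
from optillm.inference import count_reasoning_tokens

sample = "<think>Let me reason step by step about the question.</think>The answer is 4."

# Without a tokenizer the helper falls back to a rough estimate (~4 chars per token).
print(count_reasoning_tokens(sample))               # 11 (46 chars of thinking // 4)

# With any object exposing .encode() (e.g. a Hugging Face tokenizer) the count is
# exact for that tokenizer's vocabulary.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2")
print(count_reasoning_tokens(sample, tok))

# Text without <think> tags, or non-string input, yields 0.
print(count_reasoning_tokens("The answer is 4."))   # 0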

optillm/thinkdeeper.py

Lines changed: 4 additions & 4 deletions

@@ -168,8 +168,8 @@ def reasoning_effort(self, messages) -> str:
         response = "".join(response_chunks)
         full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}"
 
-        logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}")
-        return full_response
+        logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}, Thinking tokens: {n_thinking_tokens}")
+        return full_response, n_thinking_tokens
 
 def thinkdeeper_decode(
     model: PreTrainedModel,

@@ -192,8 +192,8 @@ def thinkdeeper_decode(
     try:
         processor = ThinkDeeperProcessor(config, tokenizer, model)
-        response = processor.reasoning_effort(messages)
-        return response
+        response, reasoning_tokens = processor.reasoning_effort(messages)
+        return response, reasoning_tokens
 
     except Exception as e:
         logger.error(f"Error in ThinkDeeper processing: {str(e)}")

optillm/thinkdeeper_mlx.py

Lines changed: 4 additions & 3 deletions

@@ -243,7 +243,8 @@ def reasoning_effort(self, messages) -> str:
         response_content = "".join(response_chunks)
         full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response_content}"
 
-        return full_response
+        logger.debug(f"MLX Final response length: {len(full_response)} chars, Thinking tokens: {n_thinking_tokens}")
+        return full_response, n_thinking_tokens
 
     def _generate_chunk(self, prompt: str, max_tokens: int, temperature: float) -> str:
         """Generate a small chunk of text using MLX with proper sampler"""

@@ -319,8 +320,8 @@ def thinkdeeper_decode_mlx(
     try:
         processor = MLXThinkDeeperProcessor(config, tokenizer, model)
-        response = processor.reasoning_effort(messages)
-        return response
+        response, reasoning_tokens = processor.reasoning_effort(messages)
+        return response, reasoning_tokens
 
     except Exception as e:
         logger.error(f"Error in MLX ThinkDeeper processing: {str(e)}")
