Commit 5978e8f

fixes

1 parent 74bebb1 commit 5978e8f

3 files changed: +93, -14 lines changed


optillm.py

Lines changed: 18 additions & 7 deletions
@@ -108,14 +108,24 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int:
         return 0
 
     # Extract all content within <think>...</think> tags
-    think_pattern = r'<think>(.*?)</think>'
-    matches = re.findall(think_pattern, text, re.DOTALL)
+    # Handle both complete and truncated think blocks
 
-    if not matches:
-        return 0
+    # First, find all complete <think>...</think> blocks
+    complete_pattern = r'<think>(.*?)</think>'
+    complete_matches = re.findall(complete_pattern, text, re.DOTALL)
+
+    # Then check for unclosed <think> tag (truncated response)
+    # This finds <think> that doesn't have a matching </think> after it
+    truncated_pattern = r'<think>(?!.*</think>)(.*)$'
+    truncated_match = re.search(truncated_pattern, text, re.DOTALL)
 
     # Combine all thinking content
-    thinking_content = ''.join(matches)
+    thinking_content = ''.join(complete_matches)
+    if truncated_match:
+        thinking_content += truncated_match.group(1)
+
+    if not thinking_content:
+        return 0
 
     if tokenizer and hasattr(tokenizer, 'encode'):
         # Use tokenizer for precise counting
@@ -125,8 +135,9 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int:
     except Exception as e:
         logger.warning(f"Failed to count tokens with tokenizer: {e}")
 
-    # Fallback: rough estimation (4 chars per token on average)
-    return max(0, len(thinking_content.strip()) // 4)
+    # Fallback: rough estimation (4 chars per token on average, minimum 1 token for non-empty content)
+    content_length = len(thinking_content.strip())
+    return max(1, content_length // 4) if content_length > 0 else 0
 
 # Server configuration
 server_config = {
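
To make the new matching logic concrete, here is a minimal standalone sketch of the two patterns from the diff above; extract_think_content is a hypothetical helper used only for illustration, not a function in the codebase:

import re

def extract_think_content(text: str) -> str:
    """Collect reasoning text from both complete and truncated <think> blocks."""
    # Complete blocks: non-greedy match between the opening and closing tags
    complete = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    # Truncated block: a <think> with no </think> anywhere after it; the
    # negative lookahead keeps it from re-matching inside complete blocks
    truncated = re.search(r'<think>(?!.*</think>)(.*)$', text, re.DOTALL)
    content = ''.join(complete)
    if truncated:
        content += truncated.group(1)
    return content

print(extract_think_content("<think>done</think> answer"))        # done
print(extract_think_content("<think>cut off mid-thought"))        # cut off mid-thought
print(extract_think_content("<think>a</think><think>b</think>"))  # ab (no false truncation)

The negative lookahead (?!.*</think>) is what prevents false positives: any <think> that is eventually followed by a closing tag is handled by the first pattern alone.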

optillm/inference.py

Lines changed: 18 additions & 7 deletions
@@ -45,14 +45,24 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int:
         return 0
 
     # Extract all content within <think>...</think> tags
-    think_pattern = r'<think>(.*?)</think>'
-    matches = re.findall(think_pattern, text, re.DOTALL)
+    # Handle both complete and truncated think blocks
 
-    if not matches:
-        return 0
+    # First, find all complete <think>...</think> blocks
+    complete_pattern = r'<think>(.*?)</think>'
+    complete_matches = re.findall(complete_pattern, text, re.DOTALL)
+
+    # Then check for unclosed <think> tag (truncated response)
+    # This finds <think> that doesn't have a matching </think> after it
+    truncated_pattern = r'<think>(?!.*</think>)(.*)$'
+    truncated_match = re.search(truncated_pattern, text, re.DOTALL)
 
     # Combine all thinking content
-    thinking_content = ''.join(matches)
+    thinking_content = ''.join(complete_matches)
+    if truncated_match:
+        thinking_content += truncated_match.group(1)
+
+    if not thinking_content:
+        return 0
 
     if tokenizer and hasattr(tokenizer, 'encode'):
         # Use tokenizer for precise counting
@@ -62,8 +72,9 @@ def count_reasoning_tokens(text: str, tokenizer=None) -> int:
     except Exception as e:
         logger.warning(f"Failed to count tokens with tokenizer: {e}")
 
-    # Fallback: rough estimation (4 chars per token on average)
-    return max(0, len(thinking_content.strip()) // 4)
+    # Fallback: rough estimation (4 chars per token on average, minimum 1 token for non-empty content)
+    content_length = len(thinking_content.strip())
+    return max(1, content_length // 4) if content_length > 0 else 0
 
 # MLX Support for Apple Silicon
 try:
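
The fallback change here is behavioral, not just cosmetic: under the old code, reasoning content shorter than four characters estimated to 0 tokens and was indistinguishable from no reasoning at all. A minimal sketch of the new arithmetic (estimate_fallback_tokens is an illustrative name, not the actual function):

def estimate_fallback_tokens(thinking_content: str) -> int:
    # ~4 characters per token on average, but never report 0 for non-empty content
    content_length = len(thinking_content.strip())
    return max(1, content_length // 4) if content_length > 0 else 0

assert estimate_fallback_tokens("") == 0         # empty stays 0
assert estimate_fallback_tokens("a") == 1        # old code: max(0, 1 // 4) == 0
assert estimate_fallback_tokens("x" * 40) == 10  # roughly 4 chars per token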

tests/test_reasoning_simple.py

Lines changed: 57 additions & 0 deletions
@@ -85,6 +85,63 @@ def encode(self, text):
 
         result = optillm_count(text, tokenizer)
         self.assertGreater(result, 0, "Should fallback to character estimation")
+
+    def test_count_reasoning_tokens_truncated_response(self):
+        """Test counting tokens when response is truncated (no closing </think> tag)"""
+        # Test truncated think tag
+        truncated_text = "<think>This reasoning was cut off due to max tokens"
+
+        result1 = optillm_count(truncated_text)
+        result2 = inference_count(truncated_text)
+
+        self.assertGreater(result1, 0, "Should count tokens from truncated think block")
+        self.assertEqual(result1, result2, "Both functions should return same result")
+
+    def test_count_reasoning_tokens_mixed_complete_and_truncated(self):
+        """Test with both complete and truncated think blocks"""
+        mixed_text = """
+        <think>First complete reasoning block</think>
+        Some output here
+        <think>This second block was truncated and never closed
+        """
+
+        result = optillm_count(mixed_text)
+        self.assertGreater(result, 0, "Should count tokens from both complete and truncated blocks")
+
+        # Should be more than just the first block alone
+        first_block_only = "<think>First complete reasoning block</think>"
+        first_result = optillm_count(first_block_only)
+        self.assertGreater(result, first_result, "Should include truncated content")
+
+    def test_count_reasoning_tokens_no_false_positives(self):
+        """Test that we don't count think-like content that isn't actually truncated"""
+        # This should NOT be counted as truncated since there's a </think> later
+        text_with_complete_blocks = "<think>First block</think>Output<think>Second complete block</think>"
+
+        result = optillm_count(text_with_complete_blocks)
+
+        # Count manually - should only be the content inside the two complete blocks
+        manual_count = optillm_count("<think>First blockSecond complete block</think>")
+        self.assertEqual(result, manual_count, "Should only count complete blocks, not detect false truncation")
+
+    def test_count_reasoning_tokens_edge_cases_truncated(self):
+        """Test edge cases with truncated responses"""
+        test_cases = [
+            ("<think>", 0),  # Just opening tag, no content
+            ("<think>a", 1),  # Minimal content
+            ("Some output <think>reasoning here", None),  # Truncated at end
+            ("<think>multi\nline\ntruncated", None),  # Multiline truncated
+        ]
+
+        for text, expected_min in test_cases:
+            result = optillm_count(text)
+            if expected_min is not None:
+                if expected_min == 0:
+                    self.assertEqual(result, expected_min, f"Should return {expected_min} for: {text}")
+                else:
+                    self.assertGreaterEqual(result, expected_min, f"Should be at least {expected_min} for: {text}")
+            else:
+                self.assertGreater(result, 0, f"Should count truncated content for: {text}")
 
 
 class TestInferenceStructures(unittest.TestCase):
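
Assuming a standard checkout (the module path below is inferred from the file location shown above, not stated in the commit), the new cases should run with the stock unittest runner from the repository root:

python -m unittest tests.test_reasoning_simple -v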
