Add proper token counting to code execution model #1184
Changes from all commits: eb7b23a, f7a35c0, 47ac316, 5161caf, 08430d4, c5eef20, 180db93, 5e81ca0, d4cd4ee
```diff
@@ -79,6 +79,8 @@ def __init__(
         # Directory paths for data and output
         data_dir: str = "",
         output_dir: str | None = None,
+        # Request tokenizer initialization independent of soft_fail
+        require_tokenizer: bool = False,
     ):
         self._tunnel = None
         self.model_name_or_path = model
```
```diff
@@ -126,7 +128,7 @@ def __init__(
         else:
             self.base_url = base_url

-        if enable_soft_fail:
+        if enable_soft_fail or require_tokenizer:
             self.tokenizer = self._get_tokenizer(tokenizer)
         else:
             self.tokenizer = None
```
```diff
@@ -202,7 +204,11 @@ def _initialize_tokenizer(self, tokenizer: str | None) -> WrapperAutoTokenizer | None:
         if tokenizer is None:
             return None
         if isinstance(tokenizer, str):
-            return WrapperAutoTokenizer(tokenizer)
+            try:
+                return WrapperAutoTokenizer(tokenizer)
+            except OSError:
+                LOG.warning(f"Tokenizer not found at '{tokenizer}', trying fallback to server /tokenize endpoint")
+                return None

     @abc.abstractmethod
     def _build_chat_request_params(self, **kwargs) -> dict:
```
|
Comment on lines
+207
to
+211
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching only Consider checking the flag here and only catching when fallback is acceptable (when the flag is |
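A minimal sketch of what that suggestion could look like, with `transformers.AutoTokenizer` standing in for the repository's `WrapperAutoTokenizer`; the helper name and the way the flag is threaded through are assumptions, not the merged code:

```python
import logging

from transformers import AutoTokenizer  # stand-in for WrapperAutoTokenizer

LOG = logging.getLogger(__name__)


def load_tokenizer(name_or_path: str, require_tokenizer: bool = False):
    """Hypothetical helper: fall back to the server /tokenize endpoint only
    when a tokenizer was not explicitly required."""
    try:
        return AutoTokenizer.from_pretrained(name_or_path)
    except OSError:
        if require_tokenizer:
            # the caller asked for a tokenizer, so surface the failure
            raise
        LOG.warning(
            "Tokenizer not found at '%s', trying fallback to server /tokenize endpoint",
            name_or_path,
        )
        return None
```

Failing loudly when the flag is set would mirror the `RuntimeError` guard added to `generate_async` further down.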
```diff
@@ -154,6 +154,8 @@ async def _generate_single(
             # if there's an unfinished code block
             if output.count(code_end) + 1 == output.count(code_begin):
                 output += code_end
+                # Count tokens for the manually added code_end
+                num_generated_tokens += len(self.model.tokenizer.encode(code_end))

             # Update the prompt based on format
             if is_openai_format:
```
Contributor commented:

> In non-streaming mode, …
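The added lines charge the manually appended `code_end` marker against the token budget via `len(tokenizer.encode(...))`. A standalone illustration of that primitive; the marker string and tokenizer below are placeholders, not values from this repository:

```python
from transformers import AutoTokenizer

# Example only: the actual code_end marker and tokenizer depend on the model config.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
code_end = "```\n"

num_tokens = len(tokenizer.encode(code_end))
print(num_tokens)  # the marker usually costs only a few tokens
```

The count only lines up with the server's own accounting when the local tokenizer corresponds to the served model.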
```diff
@@ -162,16 +164,14 @@ async def _generate_single(
             else:
                 request["prompt"] += output

-            # if it's the extra iteration, we don't execute the code block and just finish
-            if generation_index == effective_max_code_executions:
-                break
             # adjusting requested tokens to account for what has been generated already
             request["tokens_to_generate"] -= num_generated_tokens
             total_num_generated_tokens += num_generated_tokens
             generation_time += int(time.time() - generation_time_start)
-            # TODO: currently we don't account for tokens in the code output that we add to the prompt
-            # in most cases the output should be small though

+            # if it's the extra iteration, we don't execute the code block and just finish
+            if generation_index == effective_max_code_executions:
+                break
             if request["tokens_to_generate"] <= 0:
                 break
             # .rfind(code_end, 0, -1) searches for the second-to-last occurrence of code_end and checks
```
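The reshuffled block keeps a running token budget across code-execution rounds: each round's generation is charged against `tokens_to_generate`, the extra final iteration skips code execution, and the loop stops once the budget is exhausted. A simplified standalone sketch of that pattern (the loop structure and helper are illustrative, not the repository's actual code):

```python
def run_with_token_budget(generate_once, tokens_to_generate: int, max_code_executions: int) -> str:
    """Toy version of the bookkeeping: generate_once(budget) returns
    (text, num_tokens) for one round of generation."""
    output = ""
    for generation_index in range(max_code_executions + 1):
        text, num_tokens = generate_once(tokens_to_generate)
        output += text
        tokens_to_generate -= num_tokens  # charge this round against the budget
        if generation_index == max_code_executions:
            break  # extra iteration: don't execute code, just finish
        if tokens_to_generate <= 0:
            break  # budget exhausted
        # ... execute the detected code block and append its output here ...
    return output
```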
```diff
@@ -195,6 +195,12 @@ async def _generate_single(
             if "process_status" in execution_dict and execution_dict["process_status"] == "timeout":
                 num_code_timeouts += 1

+            # Account for tokens in the code output
+            code_output_tokens = len(self.model.tokenizer.encode(code_output))
+            request["tokens_to_generate"] -= code_output_tokens
+            total_num_generated_tokens += code_output_tokens
+            if request["tokens_to_generate"] <= 0:
+                break
             if is_openai_format:
                 request["prompt"][-2]["content"] += code_output
             else:
```

i-vainn marked this conversation as resolved.
```diff
@@ -270,6 +276,12 @@ async def generate_async(
         Not every server supports that, so make sure to override this method directly if that's not the case.
         """
+        if self.model.tokenizer is None:
+            raise RuntimeError(
+                "Tokenizer is required for CodeExecutionWrapper to correctly count tokens. "
+                "Please initialize the model with require_tokenizer=True or provide a valid tokenizer."
+            )
+
         if top_logprobs is not None:  # TODO: add this
             raise NotImplementedError("top_logprobs is not supported yet.")
```
```diff
@@ -363,24 +375,25 @@ async def _stream_single(
             model_token_iterator = await self.model.generate_async(prompt=current_full_prompt, **request)

             current_output_segment = ""
-            num_generated_tokens = 0
             async for chunk in model_token_iterator:
                 yield chunk
                 current_output_segment += chunk["generation"]
-                num_generated_tokens += 1

-            request["tokens_to_generate"] -= num_generated_tokens
-            if request["tokens_to_generate"] <= 0:
-                break
-            if not current_output_segment:
-                break

             # openai and trtllm don't show what stop word was triggered, so we assume that it was `code_end`
             # if there's an unfinished code block
             if current_output_segment.count(code_end) + 1 == current_output_segment.count(code_begin):
                 current_output_segment += code_end
                 yield {"generation": code_end}

+            # Calculate token count for this segment (after adding code_end if needed)
+            num_generated_tokens = len(self.model.tokenizer.encode(current_output_segment))
+
+            request["tokens_to_generate"] -= num_generated_tokens
+            if request["tokens_to_generate"] <= 0:
+                break
+            if not current_output_segment:
+                break

             # Update the prompt based on format
             if is_openai_format:
                 current_full_prompt.append({"role": "assistant", "content": current_output_segment})
```
Contributor commented:

> Token count uses …
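In the streaming path, the token count is now derived by re-encoding the accumulated segment rather than counting one token per streamed chunk. The two numbers are not interchangeable: a chunk need not be a single token, and subword tokenization is not additive across concatenation. A standalone illustration (tokenizer choice is arbitrary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

chunks = ["Hello", " wor", "ld!"]  # pretend these arrived from a stream
per_chunk = sum(len(tokenizer.encode(c)) for c in chunks)
whole = len(tokenizer.encode("".join(chunks)))

print(per_chunk, whole)  # the two counts can disagree
print(len(chunks))       # and neither has to equal the number of chunks
```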
```diff
@@ -417,6 +430,12 @@ async def _stream_single(
             )
             yield {"generation": formatted_code_output}  # Yield the entire formatted code output as one chunk

+            # Account for tokens in the code output
+            code_output_tokens = len(self.model.tokenizer.encode(formatted_code_output))
+            request["tokens_to_generate"] -= code_output_tokens
+            if request["tokens_to_generate"] <= 0:
+                break
+
             # Append executed code's output to the prompt
             if is_openai_format:
                 current_full_prompt[-2]["content"] += formatted_code_output
```
> Only catching `OSError` may miss other exceptions during tokenizer initialization (e.g., `ImportError`, `ValueError`). If the goal is to gracefully fall back to the server endpoint, consider catching broader exceptions.
>
> Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
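A minimal sketch of the broader catch described above; the exception tuple and helper are illustrative, and `AutoTokenizer` stands in for the repository's `WrapperAutoTokenizer`:

```python
from transformers import AutoTokenizer


def load_tokenizer_or_none(name_or_path: str):
    """Illustrative only: swallow common initialization failures and signal
    the caller to fall back to the server /tokenize endpoint."""
    try:
        return AutoTokenizer.from_pretrained(name_or_path)
    except (OSError, ImportError, ValueError):
        return None
```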