Merged
2 changes: 1 addition & 1 deletion nemo_skills/inference/model/__init__.py
@@ -93,7 +93,7 @@ def get_model(server_type, tokenizer=None, model_class: str | None = None, **kwa

def get_code_execution_model(server_type, tokenizer=None, code_execution=None, sandbox=None, **kwargs):
"""A helper function to make it easier to set server through cmd."""
model = get_model(server_type=server_type, tokenizer=tokenizer, **kwargs)
model = get_model(server_type=server_type, tokenizer=tokenizer, require_tokenizer=True, **kwargs)
if code_execution is None:
code_execution = {}
code_execution_config = CodeExecutionConfig(**code_execution)
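For readers skimming the diff: with this change, get_code_execution_model always passes require_tokenizer=True down to get_model, so the code-execution wrapper can count tokens locally. A hedged usage sketch; the server type and paths below are placeholders, not values taken from this PR:

```python
from nemo_skills.inference.model import get_code_execution_model

# Placeholder values; only the parameter names come from the code above.
llm = get_code_execution_model(
    server_type="vllm",               # assumed to be one of the supported server types
    model="/path/to/checkpoint",      # placeholder model path
    tokenizer="/path/to/checkpoint",  # now always initialized (require_tokenizer=True)
)
```
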
10 changes: 8 additions & 2 deletions nemo_skills/inference/model/base.py
@@ -79,6 +79,8 @@ def __init__(
# Directory paths for data and output
data_dir: str = "",
output_dir: str | None = None,
# Request tokenizer initialization independent of soft_fail
require_tokenizer: bool = False,
):
self._tunnel = None
self.model_name_or_path = model
@@ -126,7 +128,7 @@ def __init__(
else:
self.base_url = base_url

if enable_soft_fail:
if enable_soft_fail or require_tokenizer:
self.tokenizer = self._get_tokenizer(tokenizer)
else:
self.tokenizer = None
@@ -202,7 +204,11 @@ def _initialize_tokenizer(self, tokenizer: str | None) -> WrapperAutoTokenizer |
if tokenizer is None:
return None
if isinstance(tokenizer, str):
return WrapperAutoTokenizer(tokenizer)
try:
return WrapperAutoTokenizer(tokenizer)
except OSError:
LOG.warning(f"Tokenizer not found at '{tokenizer}', trying fallback to server /tokenize endpoint")
return None
Comment on lines +207 to +211 (Contributor):
Only catching OSError may miss other exceptions during tokenizer initialization (e.g., ImportError, ValueError). If the goal is to gracefully fall back to the server endpoint, consider catching broader exceptions.

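A minimal sketch of the broader catch this comment suggests, reusing WrapperAutoTokenizer and LOG from the file above; the exception tuple is illustrative, not a vetted list:

```python
try:
    return WrapperAutoTokenizer(tokenizer)
except (OSError, ImportError, ValueError) as exc:  # illustrative set; broaden as needed
    LOG.warning(
        f"Failed to initialize tokenizer '{tokenizer}' ({exc}); "
        "falling back to the server /tokenize endpoint"
    )
    return None
```
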

Comment on lines +207 to +211 (Contributor):

Catching only OSError doesn't follow CONTRIBUTING.md guidelines about not being overly defensive. If require_tokenizer is True, the code should fail loudly when tokenizer initialization fails, not silently fall back. The runtime check on line 279 of code_execution.py will catch this later, but it happens during generation (after model setup), which could cause issues in production.

Consider checking the flag here and only catching when fallback is acceptable (when the flag is False).
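Putting both comments together, a sketch of a flag-aware version that fails loudly when the tokenizer is required and only falls back otherwise; threading require_tokenizer through this helper is an assumption about how the flag could be plumbed, not the PR's final shape:

```python
def _initialize_tokenizer(self, tokenizer: str | None, require_tokenizer: bool = False):
    if tokenizer is None:
        return None
    if isinstance(tokenizer, str):
        if require_tokenizer:
            # Fail at setup time rather than later during generation.
            return WrapperAutoTokenizer(tokenizer)
        try:
            return WrapperAutoTokenizer(tokenizer)
        except Exception:  # broad on purpose: any init failure triggers the fallback
            LOG.warning(
                f"Tokenizer '{tokenizer}' could not be initialized, "
                "falling back to the server /tokenize endpoint"
            )
            return None
    return tokenizer  # assumption: non-string tokenizers are passed through unchanged
```
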


@abc.abstractmethod
def _build_chat_request_params(self, **kwargs) -> dict:
Expand Down
47 changes: 33 additions & 14 deletions nemo_skills/inference/model/code_execution.py
@@ -154,6 +154,8 @@ async def _generate_single(
# if there's an unfinished code block
if output.count(code_end) + 1 == output.count(code_begin):
output += code_end
# Count tokens for the manually added code_end
num_generated_tokens += len(self.model.tokenizer.encode(code_end))
Contributor comment:
In non-streaming mode, code_end tokens are added to num_generated_tokens, which came from output_dict.get("num_generated_tokens", 0). If the server already counted tokens for code_end when it was manually added, this would double-count those tokens.
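If that double count is a real risk, one option is to make the client-side count conditional on whether the server's reported count already covers the stop sequence; a sketch, where server_counts_stop_tokens is a hypothetical capability flag, not something exposed by the code above:

```python
# Sketch only; `server_counts_stop_tokens` is a hypothetical flag describing server behavior.
if output.count(code_end) + 1 == output.count(code_begin):
    output += code_end
    if not server_counts_stop_tokens:
        # Count code_end only when the server's num_generated_tokens cannot include it.
        num_generated_tokens += len(self.model.tokenizer.encode(code_end))
```
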


# Update the prompt based on format
if is_openai_format:
@@ -162,16 +164,14 @@ async def _generate_single(
else:
request["prompt"] += output

# if it's the extra iteration, we don't execute the code block and just finish

if generation_index == effective_max_code_executions:
break
# adjusting requested tokens to account for what has been generated already
request["tokens_to_generate"] -= num_generated_tokens
total_num_generated_tokens += num_generated_tokens
generation_time += int(time.time() - generation_time_start)
# TODO: currently we don't account for tokens in the code output that we add to the prompt
# in most cases the output should be small though

# if it's the extra iteration, we don't execute the code block and just finish
if generation_index == effective_max_code_executions:
break
if request["tokens_to_generate"] <= 0:
break
# .rfind(code_end, 0, -1) searches for the second-to-last occurrence of code_end and checks
@@ -195,6 +195,12 @@
if "process_status" in execution_dict and execution_dict["process_status"] == "timeout":
num_code_timeouts += 1

# Account for tokens in the code output
code_output_tokens = len(self.model.tokenizer.encode(code_output))
request["tokens_to_generate"] -= code_output_tokens
total_num_generated_tokens += code_output_tokens
if request["tokens_to_generate"] <= 0:
break
if is_openai_format:
request["prompt"][-2]["content"] += code_output
else:
@@ -270,6 +276,12 @@ async def generate_async(

Not every server supports that, so make sure to override this method directly if that's not the case.
"""
if self.model.tokenizer is None:
raise RuntimeError(
"Tokenizer is required for CodeExecutionWrapper to correctly count tokens. "
"Please initialize the model with require_tokenizer=True or provide a valid tokenizer."
)

if top_logprobs is not None: # TODO: add this
raise NotImplementedError("top_logprobs is not supported yet.")

@@ -363,24 +375,25 @@ async def _stream_single(
model_token_iterator = await self.model.generate_async(prompt=current_full_prompt, **request)

current_output_segment = ""
num_generated_tokens = 0
async for chunk in model_token_iterator:
yield chunk
current_output_segment += chunk["generation"]
num_generated_tokens += 1

request["tokens_to_generate"] -= num_generated_tokens
if request["tokens_to_generate"] <= 0:
break
if not current_output_segment:
break

# openai and trtllm don't show what stop word was triggered, so we assume that it was `code_end`
# if there's an unfinished code block
if current_output_segment.count(code_end) + 1 == current_output_segment.count(code_begin):
current_output_segment += code_end
yield {"generation": code_end}

# Calculate token count for this segment (after adding code_end if needed)
num_generated_tokens = len(self.model.tokenizer.encode(current_output_segment))
Contributor comment:
Token count uses encode() on the segment string, which may differ from the actual token count if tokenization is context-dependent. Consider verifying this matches the LLM's actual token usage.
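One way to address this is to count by encoding the prompt with and without the new segment, which captures token merges across the boundary; a small sketch, assuming tokenizer wraps the same tokenizer the server uses:

```python
def count_new_tokens(tokenizer, prompt_so_far: str, segment: str) -> int:
    """Tokens added by `segment` when appended to `prompt_so_far` (context-aware)."""
    before = len(tokenizer.encode(prompt_so_far))
    after = len(tokenizer.encode(prompt_so_far + segment))
    return after - before
```
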


request["tokens_to_generate"] -= num_generated_tokens
if request["tokens_to_generate"] <= 0:
break
if not current_output_segment:
break

# Update the prompt based on format
if is_openai_format:
current_full_prompt.append({"role": "assistant", "content": current_output_segment})
@@ -417,6 +430,12 @@ async def _stream_single(
)
yield {"generation": formatted_code_output} # Yield the entire formatted code output as one chunk

# Account for tokens in the code output
code_output_tokens = len(self.model.tokenizer.encode(formatted_code_output))
request["tokens_to_generate"] -= code_output_tokens
if request["tokens_to_generate"] <= 0:
break

# Append executed code's output to the prompt
if is_openai_format:
current_full_prompt[-2]["content"] += formatted_code_output