Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions agentlightning/verl/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ agentlightning:
trajectory_max_response_length: 8192 # supported in trajectory level aggregation, suggest to set as maximum length for the cumulative agent responses in the full trajectory, i.e., n_turns * (max_response_length + max_prompt_length)
debug: False # supported in trajectory level aggregation, enable to diagnose trace merging failures
mismatch_log_dir: ./mismatch_cases # supported in trajectory level aggregation with debug=True, directory to store logs of mismatch cases
# =========================================================================
# Tool Call Filtering (Youtu-Agent style)
# When enabled, filters out "unexpected tool call" turns where the model
# continues generating after a tool call instead of stopping properly.
# This helps prevent entropy explosion during RL training.
# Reference: contrib/youtu-agent-lightning branch
# =========================================================================
filter_unexpected_tool_calls: False # set to True to enable filtering

data:
filter_overlong_prompts: false
Expand Down
176 changes: 176 additions & 0 deletions agentlightning/verl/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import logging
import os
import random
import socket
Expand All @@ -19,12 +20,27 @@
from tensordict import TensorDict
from verl import DataProto

# =============================================================================
# Tool Call Filtering Support (for filtering unexpected tool call turns)
# Reference: Youtu-Agent implementation in contrib/youtu-agent-lightning branch
# The ToolParser extracts tool calls from response tokens to detect cases where
# the model continues generating after a tool call (hallucinated tool responses)
# instead of properly stopping with </tool_call><|im_end|>
# =============================================================================
try:
from verl.experimental.agent_loop.tool_parser import ToolParser
TOOL_PARSER_AVAILABLE = True
except ImportError:
TOOL_PARSER_AVAILABLE = False

from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy
from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
from agentlightning.llm_proxy import LLMProxy, ModelConfig
from agentlightning.store.base import LightningStore
from agentlightning.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task

logger = logging.getLogger(__name__)

__all__ = [
"AgentModeDaemon",
"get_left_padded_ids_and_attention_mask",
Expand Down Expand Up @@ -283,6 +299,11 @@ def __init__(
self._proxy_thread: Optional[threading.Thread] = None
self.is_train = True

# Tool Call Filtering Setup (config key: trace_aggregator["filter_unexpected_tool_calls"])
self.tool_parser = None
self.toolcall_candidate_token_last2_list = []
self._setup_tool_call_filter(train_information, tokenizer)

def _internal_loop_runner(self):
"""Run the internal loop."""
loop = asyncio.new_event_loop()
Expand All @@ -291,6 +312,112 @@ def _internal_loop_runner(self):
loop.run_forever()
loop.close()

# =========================================================================
# Tool Call Filtering Methods
# Reference: Youtu-Agent implementation (contrib/youtu-agent-lightning)
# Purpose: Filter out "unexpected tool call turns" where the model continues
# generating text after a tool call instead of stopping properly.
# =========================================================================

def _setup_tool_call_filter(self, train_information: Dict[str, Any], tokenizer: Any) -> None:
    """Initialize the tool parser and the valid response-ending token patterns.

    Instead of hardcoding token IDs, this renders an example tool-call
    conversation through ``tokenizer.apply_chat_template`` and takes the last
    two tokens of the rendered text (e.g. ``</tool_call><|im_end|>``) as the
    canonical "valid ending" of a tool-call response. Variants that end in the
    tokenizer's eos/pad token are also accepted to prevent over-filtering.

    Args:
        train_information: Training config; key ``"format"`` selects the
            tool-call format (defaults to ``"hermes"``).
        tokenizer: Tokenizer used for encoding/decoding responses.
    """
    if not TOOL_PARSER_AVAILABLE:
        # Route through the module logger (consistent with the logger.info
        # call below) instead of a bare print().
        logger.warning("ToolParser not available, tool call filtering disabled.")
        self.tool_parser = None
        return

    toolcall_format = train_information.get("format", "hermes")
    self.tool_parser = ToolParser.get_tool_parser(toolcall_format, tokenizer)

    # Use the chat template to detect the actual tool call ending token
    # sequence. The example uses a calculator tool to match the calc-x
    # example for consistency.
    tools_examples = [{
        "type": "function",
        "name": "calculate",
        "description": "Evaluate a mathematical expression",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {"type": "string", "description": "Math expression, e.g., '2 + 3 * 4'"},
            },
            "required": ["expression"],
        },
    }]
    toolcall_message_examples = [
        {"role": "user", "content": "What is 15 + 27?"},
        {"role": "assistant", "content": "", "tool_calls": [{
            "id": "call_001",
            "type": "function",
            "function": {"name": "calculate", "arguments": '{"expression":"15 + 27"}'},
        }]},
    ]
    toolcall_example_chat_template = tokenizer.apply_chat_template(
        toolcall_message_examples, tools=tools_examples,
        add_generation_prompt=False, tokenize=False,
    )
    # Extract the last 2 tokens from the chat template output
    # (e.g., </tool_call><|im_end|>).
    toolcall_example_token_last2 = tokenizer.encode(toolcall_example_chat_template.strip())[-2:]

    eos_token_id = tokenizer.eos_token_id
    pad_token_id = tokenizer.pad_token_id

    # Build candidate list: the detected ending plus variants with eos/pad.
    # This allows various tool-call ending conditions to prevent
    # over-filtering. Skip variants whose special token is undefined (None)
    # for this tokenizer — a [tok, None] pair can never match real token IDs.
    toolcall_candidate_token_last2_list = [toolcall_example_token_last2]
    if eos_token_id is not None and toolcall_example_token_last2[-1] != eos_token_id:
        toolcall_candidate_token_last2_list.append([toolcall_example_token_last2[0], eos_token_id])
    if pad_token_id is not None and toolcall_example_token_last2[-1] != pad_token_id:
        toolcall_candidate_token_last2_list.append([toolcall_example_token_last2[0], pad_token_id])

    self.toolcall_candidate_token_last2_list = toolcall_candidate_token_last2_list
    logger.info(
        f"Tool call filter initialized: {eos_token_id=}, {pad_token_id=}, "
        f"candidates={self.toolcall_candidate_token_last2_list}"
    )


def _is_valid_tool_call_response(self, response_ids: List[int]) -> Tuple[bool, bool]:
    """Classify a response's tool-call usage and whether it ends properly.

    A tool-call response counts as well-formed only when its final two
    tokens exactly match one of the precomputed candidate endings (e.g.
    ``</tool_call><|im_end|>`` or ``</tool_call><|endoftext|>``) — the same
    strict last-2-token rule used by the Youtu-Agent branch.

    Args:
        response_ids: Token IDs produced by the model for one turn.

    Returns:
        ``(has_tool_calls, has_valid_ending)``; ``has_valid_ending`` is True
        when the response contains no tool calls at all, or when a tool-call
        response terminates with an accepted token pair.
    """
    if self.tool_parser is None:
        # Filtering is disabled/uninitialized: treat everything as valid.
        return False, True

    _, parsed_calls = asyncio.run(self.tool_parser.extract_tool_calls(response_ids))
    if not parsed_calls:
        return False, True

    tail = list(response_ids[-2:])
    if len(tail) < 2:
        # Too short to carry a closing token pair at all.
        return True, False

    ends_properly = any(
        tail == list(candidate) for candidate in self.toolcall_candidate_token_last2_list
    )
    return True, ends_properly

# Multimodal utilities for M-RoPE position embeddings

def _is_mrope_model(self) -> bool:
Expand Down Expand Up @@ -821,6 +948,12 @@ def get_train_data_batch(
finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
finished_id_to_final_reward: Dict[str, float] = {}
sample_with_reward_count = 0

# Tool call filtering metrics
n_total_turns_before_filter = 0
n_unexpected_tool_calls = 0
n_skipped_rollouts_by_filter = 0

for rollout_id, rollout in self._completed_rollouts_v0.items():
original_sample = self._task_id_to_original_sample[rollout_id]
sample_with_reward_count += int(rollout.final_reward is not None)
Expand All @@ -842,6 +975,41 @@ def get_train_data_batch(
}
for t in rollout.triplets
]

# Filter void/unexpected tool call turns (Youtu-Agent style)
# When config is OFF: only count for metrics, no filtering
# When config is ON: apply both void and unexpected tool call filtering
if self.tool_parser is not None:
n_total_turns_before_filter += len(trace_list)

# Count unexpected tool calls (always, for metrics)
for trace in trace_list:
if len(trace["prompt_ids"]) and len(trace["response_ids"]):
has_tool_calls, has_valid_ending = self._is_valid_tool_call_response(trace["response_ids"])
if has_tool_calls and not has_valid_ending:
n_unexpected_tool_calls += 1

# Apply filtering only when config is enabled
if self.trace_aggregator.get("filter_unexpected_tool_calls", False):
# 1. Filter void turns (empty prompt or response)
trace_list = [
t for t in trace_list
if len(t["prompt_ids"]) and len(t["response_ids"])
]
# 2. Filter unexpected tool call turns
trace_list_filtered = []
for trace in trace_list:
has_tool_calls, has_valid_ending = self._is_valid_tool_call_response(trace["response_ids"])
if has_tool_calls and not has_valid_ending:
continue # Skip invalid turns
trace_list_filtered.append(trace)
# 3. Skip rollout if only 1 or fewer valid turns remain
if len(trace_list_filtered) <= 1:
n_skipped_rollouts_by_filter += 1
finished_id_to_final_reward[rollout_id] = final_reward
continue
trace_list = trace_list_filtered

info = {
"reward": final_reward,
"trace_list": trace_list,
Expand Down Expand Up @@ -1123,6 +1291,14 @@ def get_train_data_batch(
and self.trace_aggregator.get("debug", False)
else {}
),
"training/n_unexpected_tool_calls": n_unexpected_tool_calls,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small comment: only emit these logging metrics when `self.tool_parser` is not None.

"training/n_total_turns_before_filter": n_total_turns_before_filter,
"training/unexpected_tool_call_ratio": (
n_unexpected_tool_calls / n_total_turns_before_filter if n_total_turns_before_filter > 0 else 0.0
),
"training/n_skipped_rollouts_by_filter": n_skipped_rollouts_by_filter,
"training/filter_enabled": float(self.trace_aggregator.get("filter_unexpected_tool_calls", False)),
"training/reward_std": np.std(list(finished_id_to_final_reward.values())),
}

# Add non-tensor data for advantage calculation and logging
Expand Down
1 change: 1 addition & 0 deletions agentlightning/verl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def _train_step(self, batch_dict: dict) -> dict:
self._dump_generations(
inputs=inputs,
outputs=outputs,
gts=[""] * len(inputs),
scores=scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
Expand Down
79 changes: 79 additions & 0 deletions examples/calc_x/train_calc_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
import agentlightning as agl
from agentlightning.env_var import LightningEnvVar, resolve_bool_env_var, resolve_str_env_var

# Ensure venv bin is in PATH (needed for uvx/mcp-server-calculator in Ray workers)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some unnecessary changes to this file. Only related config should be included here I think.

_script_dir = os.path.dirname(os.path.abspath(__file__))
_venv_bin = os.path.join(_script_dir, "..", "..", ".venv", "bin")
if os.path.isdir(_venv_bin):
os.environ["PATH"] = os.path.abspath(_venv_bin) + ":" + os.environ.get("PATH", "")


def verl_default_config() -> Dict[str, Any]:
config = {
Expand Down Expand Up @@ -123,6 +129,11 @@ def train(
trajectory_level: bool = False,
weave: bool,
mongo_uri: Optional[str],
filter_unexpected_tool_calls: bool = False,
experiment_name: Optional[str] = None,
n_gpus: int = 1,
checkpoint_dir: str = "/home/jovyan/msra/experiments/checkpoints",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain about this line? It seems that this path belongs to someone else?

resume: bool = False,
):
"""The training entrypoint function for Calc-X agent with VERL algorithm.

Expand All @@ -141,6 +152,7 @@ def train(
trajectory_level: Whether to enable trajectory level in trace aggregator.
weave: Whether to enable Weave tracing.
mongo_uri: MongoDB URI to use for the store.
experiment_name: Custom experiment name for W&B logging.
"""
# Load datasets (respect CLI file paths)
train_dataset = cast(agl.Dataset[MathProblem], HuggingFaceDataset.from_parquet(train_file).to_list()) # type: ignore
Expand All @@ -156,6 +168,26 @@ def train(
if model:
config["actor_rollout_ref"]["model"]["path"] = model

# Override experiment name if provided (for W&B logging)
if experiment_name:
config["trainer"]["experiment_name"] = experiment_name
print(f"Using custom experiment name: {experiment_name}")

# Override n_gpus_per_node for multi-GPU training
if n_gpus > 1:
config["trainer"]["n_gpus_per_node"] = n_gpus
print(f"Multi-GPU training enabled: n_gpus_per_node={n_gpus}")

# Set checkpoint directory and conversation dump directory
config["trainer"]["default_local_dir"] = checkpoint_dir
config["trainer"]["resume_mode"] = "auto" if resume else "disable"
conversations_dir = checkpoint_dir.replace("checkpoints", "conversations")
config["trainer"]["rollout_data_dir"] = conversations_dir
os.makedirs(conversations_dir, exist_ok=True)
print(f"Checkpoint directory: {checkpoint_dir}")
print(f"Conversations directory: {conversations_dir}")
print(f"Resume mode: {config['trainer']['resume_mode']}")

# Enable LoRA configuration if requested
if lora:
config["actor_rollout_ref"]["model"]["lora_rank"] = lora_rank
Expand All @@ -175,6 +207,19 @@ def train(
}
print("Trajectory level enabled in trace aggregator.")

# =========================================================================
# Tool Call Filtering (Youtu-Agent style)
# Filters out turns where the model generates unexpected content after
# a tool call (hallucinated tool responses). Helps prevent entropy explosion.
# =========================================================================
if filter_unexpected_tool_calls:
if "agentlightning" not in config:
config["agentlightning"] = {"trace_aggregator": {}}
if "trace_aggregator" not in config["agentlightning"]:
config["agentlightning"]["trace_aggregator"] = {}
config["agentlightning"]["trace_aggregator"]["filter_unexpected_tool_calls"] = True
print("Tool call filtering enabled (Youtu-Agent style).")

# CI toggle keeps everything else the same but you can tweak the lightweight bits here if desired
if ci or ci_fast:
# Config the experiment name and project name so that they are available to CI
Expand Down Expand Up @@ -290,6 +335,35 @@ def main():
default=None,
help="MongoDB URI to use for the store.",
)
parser.add_argument(
"--filter-unexpected-tool-calls",
action="store_true",
help="Enable Youtu-Agent style tool call filtering. "
"Filters out turns where the model generates unexpected content after a tool call.",
)
parser.add_argument(
"--experiment-name",
type=str,
default=None,
help="Custom experiment name for W&B logging (default: calc_x or auto-generated for CI)",
)
parser.add_argument(
"--n-gpus",
type=int,
default=1,
help="Number of GPUs per node for distributed training (default: 1)",
)
parser.add_argument(
"--checkpoint-dir",
type=str,
default="/home/jovyan/msra/experiments/checkpoints",
help="Directory to save checkpoints (default: /home/jovyan/msra/experiments/checkpoints)",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your careful review and for raising this question.

To clarify, /home/jovyan is not a specific person's directory—it is the default home directory name on the OpenHPC server provided by my university (GIST). The msra folder is my personal working directory that I created specifically for this project, which is also linked to my GitHub repository.

I have attached screenshots of my university's HPC-AI Service Portal as evidence. As you can see, /home/jovyan is the default home directory automatically assigned when a workspace is created on this server.
I attached the training code without modification because I wanted to transparently show exactly how the experiments were conducted. However, I realize now that I should have cleaned up these internal file paths before submission. I apologize for any confusion this may have caused—this is my first time collaborating with an industry partner, and I was not aware this could raise concerns.

스크린샷 2026-01-27 오후 5 16 46

)
parser.add_argument(
"--resume",
action="store_true",
help="Resume training from the latest checkpoint in checkpoint-dir",
)

args = parser.parse_args()

Expand Down Expand Up @@ -321,6 +395,11 @@ def main():
trajectory_level=args.trajectory_level,
weave=args.weave,
mongo_uri=args.mongo_uri,
filter_unexpected_tool_calls=args.filter_unexpected_tool_calls,
experiment_name=args.experiment_name,
n_gpus=args.n_gpus,
checkpoint_dir=args.checkpoint_dir,
resume=args.resume,
)


Expand Down