Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/perf/experimental/trace_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ class PerfettoTraceWriterTest(parameterized.TestCase):
{perf_constants.GROUP_ID: 5},
"rollout (group_id=5)",
),
(
"environment_with_group_id_and_pair_index",
perf_constants.ENVIRONMENT,
{
perf_constants.GROUP_ID: 5,
perf_constants.PAIR_INDEX: 3,
perf_constants.STEP: 100,
},
"environment (step=100, group_id=5, pair_index=3)",
),
(
"environment_with_missing_pair_index",
perf_constants.ENVIRONMENT,
{perf_constants.GROUP_ID: 5},
"environment (group_id=5)",
),
(
"unknown_span_with_extra_tags",
"unknown_span",
Expand Down
55 changes: 55 additions & 0 deletions tests/rl/agentic/trajectory/trajectory_collect_engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from absl.testing import absltest
import jax.numpy as jnp
import numpy as np
from tunix.perf.experimental import constants as perf_constants
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl.agentic import utils
from tunix.rl.agentic.agents import agent_types
from tunix.rl.agentic.agents import base_agent
Expand Down Expand Up @@ -115,6 +117,59 @@ def _mock_rollout_output(text, tokens):
async def _run_collect(self, engine, mode='Trajectory'):
return await engine.collect(mode=mode)

def test_get_perf_tags(self):
self.mock_env.extra_kwargs = {
'group_id': 'test_group',
'pair_index': 42,
}
self.mock_env.task = {
'policy_version': 'v1.0',
}
engine = trajectory_collect_engine.TrajectoryCollectEngine(
agent=self.mock_agent,
env=self.mock_env,
model_call=self.mock_model_call,
)
tags = engine._get_perf_tags()
expected_tags = {
perf_constants.GROUP_ID: 'test_group',
perf_constants.PAIR_INDEX: 42,
perf_constants.STEP: 'v1.0',
}
self.assertEqual(tags, expected_tags)

def test_get_perf_tags_missing_attributes(self):
del self.mock_env.extra_kwargs
del self.mock_env.task
engine = trajectory_collect_engine.TrajectoryCollectEngine(
agent=self.mock_agent,
env=self.mock_env,
model_call=self.mock_model_call,
)
tags = engine._get_perf_tags()
self.assertEqual(tags, {})

def test_perf_v2_and_noop_used_by_default(self):
self.mock_env.max_steps = 1
self.mock_env.step.return_value = ('obs1', 1.0, True, {})
self.mock_env.extra_kwargs = {'group_id': 'test_group'}

engine = trajectory_collect_engine.TrajectoryCollectEngine(
agent=self.mock_agent,
env=self.mock_env,
model_call=self.mock_model_call,
)
self.assertIsInstance(engine.perf_v2, perf_tracer_v2.NoopTracer)
with mock.patch.object(engine.perf_v2, 'span', autospec=True) as mock_span:
mock_span.return_value.__enter__.return_value = (
perf_tracer_v2.AsyncWaitlist()
)
asyncio.run(self._run_collect(engine, mode='Trajectory'))
mock_span.assert_called_once_with(
perf_constants.ENVIRONMENT,
tags={perf_constants.GROUP_ID: 'test_group'},
)

def test_collect_trajectory_mode(self):
self.mock_env.max_steps = 5
self.mock_env.reward_fn.return_value = 0.5
Expand Down
1 change: 1 addition & 0 deletions tunix/perf/experimental/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@
OLD_ACTOR_INFERENCE = "old_actor_inference"
ADVANTAGE_COMPUTATION = "advantage_computation"
PEFT_TRAIN = "peft_train"
ENVIRONMENT = "environment"
3 changes: 2 additions & 1 deletion tunix/perf/experimental/trace_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ def _create_span_name(name: str, tags: Mapping[str, Any]) -> str:
perf_constants.REFERENCE_INFERENCE,
perf_constants.OLD_ACTOR_INFERENCE,
perf_constants.ADVANTAGE_COMPUTATION,
perf_constants.ENVIRONMENT,
]:
if perf_constants.GROUP_ID in tags:
parts.append(f"group_id={tags[perf_constants.GROUP_ID]}")

if name == perf_constants.ROLLOUT:
if name in [perf_constants.ROLLOUT, perf_constants.ENVIRONMENT]:
if perf_constants.PAIR_INDEX in tags:
parts.append(f"pair_index={tags[perf_constants.PAIR_INDEX]}")

Expand Down
1 change: 1 addition & 0 deletions tunix/rl/agentic/agentic_rl_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def _build_orchestrator(self) -> rollout_orchestrator.RolloutOrchestrator:
tokenizer=self.tokenizer,
chat_parser=self.chat_parser,
timeout=self.algo_config.episode_timeout,
perf_v2=self.rl_cluster.perf_v2,
)
return rollout_orchestrator.RolloutOrchestrator(
engine_cls=trajectory_collect_engine.TrajectoryCollectEngine,
Expand Down
48 changes: 40 additions & 8 deletions tunix/rl/agentic/trajectory/trajectory_collect_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from absl import logging
import numpy as np
from tunix.perf.experimental import constants as perf_constants
from tunix.perf.experimental import tracer as perf_tracer_v2
from tunix.rl.agentic import utils
from tunix.rl.agentic.agents import agent_types
from tunix.rl.agentic.agents import base_agent
Expand Down Expand Up @@ -67,6 +69,7 @@ def __init__(
tokenizer=None,
chat_parser=None,
valid_statuses: Optional[Set[agent_types.TrajectoryStatus]] = None,
perf_v2: Optional[perf_tracer_v2.Tracer] = None,
):
"""Initialize the trajectory collection engine.

Expand All @@ -93,6 +96,8 @@ def __init__(
chat_parser: Optional chat parser for formatting messages
valid_statuses (Set[TrajectoryStatus]): A set of statuses that are
considered not "penalized" for reward computation.
perf_v2 (Optional[perf_tracer_v2.Tracer]): Optional performance tracer
to use for performance measurements. Defaults to a no-op tracer.
"""
self.agent = agent
self.env = env
Expand All @@ -111,6 +116,7 @@ def __init__(
self.valid_statuses = valid_statuses or {
agent_types.TrajectoryStatus.SUCCEEDED
}
self.perf_v2 = perf_v2 or perf_tracer_v2.NoopTracer()
self.env_time: float = 0.0

if self.max_context_limit and not (self.tokenizer and self.chat_parser):
Expand Down Expand Up @@ -266,6 +272,7 @@ async def collect_multiple(
max_context_limit: Optional[int] = None,
timeout: float = 30.0,
mode: str = "Trajectory",
perf_v2: Optional[perf_tracer_v2.Tracer] = None,
) -> AsyncGenerator[Tuple[int, Any], None]:
"""Execute multiple agent-environment pairs concurrently.

Expand All @@ -281,6 +288,8 @@ async def collect_multiple(
max_context_limit (Optional[int]): Maximum context limit per episode
timeout (float): Per-episode timeout in seconds
mode (str): Output format. See `collect` method for options.
perf_v2 (Optional[perf_tracer_v2.Tracer]): Optional performance tracer
to use for performance measurements.

Yields:
Tuple[int, Any]: `(pair_index, result)`. The type of `result`
Expand All @@ -296,6 +305,7 @@ async def _run_one(i: int, agent: ConversationAgentBase, env: BaseTaskEnv):
gamma=gamma,
max_context_limit=max_context_limit,
timeout=timeout,
perf_v2=perf_v2,
)
traj = await engine.collect(mode=mode)
return i, traj
Expand Down Expand Up @@ -331,6 +341,22 @@ async def _reset(self):

self._start_ts = time.time()

def _get_perf_tags(self) -> Dict[str, Any]:
"""Extracts performance tracing tags from the environment."""
tags = {}
if hasattr(self.env, "extra_kwargs"):
group_id = self.env.extra_kwargs.get("group_id")
if group_id is not None:
tags[perf_constants.GROUP_ID] = group_id
pair_index = self.env.extra_kwargs.get("pair_index")
if pair_index is not None:
tags[perf_constants.PAIR_INDEX] = pair_index
if hasattr(self.env, "task"):
policy_version = self.env.task.get("policy_version")
if policy_version is not None:
tags[perf_constants.STEP] = policy_version
return tags

async def _one_step(self) -> bool:
"""Executes a single step and returns the Step object and Done status.

Expand Down Expand Up @@ -364,14 +390,20 @@ def clocked_env_step(action):
t_delta = time.thread_time() - t_start
return result, t_delta

(
obs,
rew,
done,
info,
), thread_delta = await asyncio.get_event_loop().run_in_executor(
None, clocked_env_step, action
)
tags = self._get_perf_tags()
with self.perf_v2.span(
perf_constants.ENVIRONMENT,
tags=tags,
):

(
obs,
rew,
done,
info,
), thread_delta = await asyncio.get_event_loop().run_in_executor(
None, clocked_env_step, action
)
self.env_time += thread_delta

self.agent.update_from_env(obs, rew, done, info)
Expand Down
Loading