diff --git a/tests/perf/experimental/trace_writer_test.py b/tests/perf/experimental/trace_writer_test.py index 1b071f00d..7cacdc1f8 100644 --- a/tests/perf/experimental/trace_writer_test.py +++ b/tests/perf/experimental/trace_writer_test.py @@ -59,6 +59,22 @@ class PerfettoTraceWriterTest(parameterized.TestCase): {perf_constants.GROUP_ID: 5}, "rollout (group_id=5)", ), + ( + "environment_with_group_id_and_pair_index", + perf_constants.ENVIRONMENT, + { + perf_constants.GROUP_ID: 5, + perf_constants.PAIR_INDEX: 3, + perf_constants.STEP: 100, + }, + "environment (step=100, group_id=5, pair_index=3)", + ), + ( + "environment_with_missing_pair_index", + perf_constants.ENVIRONMENT, + {perf_constants.GROUP_ID: 5}, + "environment (group_id=5)", + ), ( "unknown_span_with_extra_tags", "unknown_span", diff --git a/tests/rl/agentic/trajectory/trajectory_collect_engine_test.py b/tests/rl/agentic/trajectory/trajectory_collect_engine_test.py index dbbe7a3d0..d5ea31c53 100644 --- a/tests/rl/agentic/trajectory/trajectory_collect_engine_test.py +++ b/tests/rl/agentic/trajectory/trajectory_collect_engine_test.py @@ -19,6 +19,8 @@ from absl.testing import absltest import jax.numpy as jnp import numpy as np +from tunix.perf.experimental import constants as perf_constants +from tunix.perf.experimental import tracer as perf_tracer_v2 from tunix.rl.agentic import utils from tunix.rl.agentic.agents import agent_types from tunix.rl.agentic.agents import base_agent @@ -115,6 +117,59 @@ def _mock_rollout_output(text, tokens): async def _run_collect(self, engine, mode='Trajectory'): return await engine.collect(mode=mode) + def test_get_perf_tags(self): + self.mock_env.extra_kwargs = { + 'group_id': 'test_group', + 'pair_index': 42, + } + self.mock_env.task = { + 'policy_version': 'v1.0', + } + engine = trajectory_collect_engine.TrajectoryCollectEngine( + agent=self.mock_agent, + env=self.mock_env, + model_call=self.mock_model_call, + ) + tags = engine._get_perf_tags() + expected_tags = { + perf_constants.GROUP_ID: 'test_group', + perf_constants.PAIR_INDEX: 42, + perf_constants.STEP: 'v1.0', + } + self.assertEqual(tags, expected_tags) + + def test_get_perf_tags_missing_attributes(self): + del self.mock_env.extra_kwargs + del self.mock_env.task + engine = trajectory_collect_engine.TrajectoryCollectEngine( + agent=self.mock_agent, + env=self.mock_env, + model_call=self.mock_model_call, + ) + tags = engine._get_perf_tags() + self.assertEqual(tags, {}) + + def test_perf_v2_and_noop_used_by_default(self): + self.mock_env.max_steps = 1 + self.mock_env.step.return_value = ('obs1', 1.0, True, {}) + self.mock_env.extra_kwargs = {'group_id': 'test_group'} + + engine = trajectory_collect_engine.TrajectoryCollectEngine( + agent=self.mock_agent, + env=self.mock_env, + model_call=self.mock_model_call, + ) + self.assertIsInstance(engine.perf_v2, perf_tracer_v2.NoopTracer) + with mock.patch.object(engine.perf_v2, 'span', autospec=True) as mock_span: + mock_span.return_value.__enter__.return_value = ( + perf_tracer_v2.AsyncWaitlist() + ) + asyncio.run(self._run_collect(engine, mode='Trajectory')) + mock_span.assert_called_once_with( + perf_constants.ENVIRONMENT, + tags={perf_constants.GROUP_ID: 'test_group'}, + ) + def test_collect_trajectory_mode(self): self.mock_env.max_steps = 5 self.mock_env.reward_fn.return_value = 0.5 diff --git a/tunix/perf/experimental/constants.py b/tunix/perf/experimental/constants.py index c313d621e..081443484 100644 --- a/tunix/perf/experimental/constants.py +++ b/tunix/perf/experimental/constants.py @@ -33,3 +33,4 @@ OLD_ACTOR_INFERENCE = "old_actor_inference" ADVANTAGE_COMPUTATION = "advantage_computation" PEFT_TRAIN = "peft_train" +ENVIRONMENT = "environment" diff --git a/tunix/perf/experimental/trace_writer.py b/tunix/perf/experimental/trace_writer.py index cbc6c80d7..f0dddf024 100644 --- a/tunix/perf/experimental/trace_writer.py +++ b/tunix/perf/experimental/trace_writer.py @@ -57,11 +57,12 @@ def _create_span_name(name: str, tags: Mapping[str, Any]) -> str: perf_constants.REFERENCE_INFERENCE, perf_constants.OLD_ACTOR_INFERENCE, perf_constants.ADVANTAGE_COMPUTATION, + perf_constants.ENVIRONMENT, ]: if perf_constants.GROUP_ID in tags: parts.append(f"group_id={tags[perf_constants.GROUP_ID]}") - if name == perf_constants.ROLLOUT: + if name in [perf_constants.ROLLOUT, perf_constants.ENVIRONMENT]: if perf_constants.PAIR_INDEX in tags: parts.append(f"pair_index={tags[perf_constants.PAIR_INDEX]}") diff --git a/tunix/rl/agentic/agentic_rl_learner.py b/tunix/rl/agentic/agentic_rl_learner.py index 295dad76f..093c07051 100644 --- a/tunix/rl/agentic/agentic_rl_learner.py +++ b/tunix/rl/agentic/agentic_rl_learner.py @@ -416,6 +416,7 @@ def _build_orchestrator(self) -> rollout_orchestrator.RolloutOrchestrator: tokenizer=self.tokenizer, chat_parser=self.chat_parser, timeout=self.algo_config.episode_timeout, + perf_v2=self.rl_cluster.perf_v2, ) return rollout_orchestrator.RolloutOrchestrator( engine_cls=trajectory_collect_engine.TrajectoryCollectEngine, diff --git a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py index 1b8c7b69e..9c72e37bb 100644 --- a/tunix/rl/agentic/trajectory/trajectory_collect_engine.py +++ b/tunix/rl/agentic/trajectory/trajectory_collect_engine.py @@ -26,6 +26,8 @@ from absl import logging import numpy as np +from tunix.perf.experimental import constants as perf_constants +from tunix.perf.experimental import tracer as perf_tracer_v2 from tunix.rl.agentic import utils from tunix.rl.agentic.agents import agent_types from tunix.rl.agentic.agents import base_agent @@ -67,6 +69,7 @@ def __init__( tokenizer=None, chat_parser=None, valid_statuses: Optional[Set[agent_types.TrajectoryStatus]] = None, + perf_v2: Optional[perf_tracer_v2.Tracer] = None, ): """Initialize the trajectory collection engine. @@ -93,6 +96,8 @@ def __init__( chat_parser: Optional chat parser for formatting messages valid_statuses (Set[TrajectoryStatus]): A set of statuses that are considered not "penalized" for reward computation. + perf_v2 (Optional[perf_tracer_v2.Tracer]): Optional performance tracer + to use for performance measurements. Defaults to a no-op tracer. """ self.agent = agent self.env = env @@ -111,6 +116,7 @@ def __init__( self.valid_statuses = valid_statuses or { agent_types.TrajectoryStatus.SUCCEEDED } + self.perf_v2 = perf_v2 or perf_tracer_v2.NoopTracer() self.env_time: float = 0.0 if self.max_context_limit and not (self.tokenizer and self.chat_parser): @@ -266,6 +272,7 @@ async def collect_multiple( max_context_limit: Optional[int] = None, timeout: float = 30.0, mode: str = "Trajectory", + perf_v2: Optional[perf_tracer_v2.Tracer] = None, ) -> AsyncGenerator[Tuple[int, Any], None]: """Execute multiple agent-environment pairs concurrently. @@ -281,6 +288,8 @@ async def collect_multiple( max_context_limit (Optional[int]): Maximum context limit per episode timeout (float): Per-episode timeout in seconds mode (str): Output format. See `collect` method for options. + perf_v2 (Optional[perf_tracer_v2.Tracer]): Optional performance tracer + to use for performance measurements. Yields: Tuple[int, Any]: `(pair_index, result)`. The type of `result` @@ -296,6 +305,7 @@ async def _run_one(i: int, agent: ConversationAgentBase, env: BaseTaskEnv): gamma=gamma, max_context_limit=max_context_limit, timeout=timeout, + perf_v2=perf_v2, ) traj = await engine.collect(mode=mode) return i, traj @@ -331,6 +341,22 @@ async def _reset(self): self._start_ts = time.time() + def _get_perf_tags(self) -> Dict[str, Any]: + """Extracts performance tracing tags from the environment.""" + tags = {} + if hasattr(self.env, "extra_kwargs"): + group_id = self.env.extra_kwargs.get("group_id") + if group_id is not None: + tags[perf_constants.GROUP_ID] = group_id + pair_index = self.env.extra_kwargs.get("pair_index") + if pair_index is not None: + tags[perf_constants.PAIR_INDEX] = pair_index + if hasattr(self.env, "task"): + policy_version = self.env.task.get("policy_version") + if policy_version is not None: + tags[perf_constants.STEP] = policy_version + return tags + async def _one_step(self) -> bool: """Executes a single step and returns the Step object and Done status. @@ -364,14 +390,20 @@ def clocked_env_step(action): t_delta = time.thread_time() - t_start return result, t_delta - ( - obs, - rew, - done, - info, - ), thread_delta = await asyncio.get_event_loop().run_in_executor( - None, clocked_env_step, action - ) + tags = self._get_perf_tags() + with self.perf_v2.span( + perf_constants.ENVIRONMENT, + tags=tags, + ): + + ( + obs, + rew, + done, + info, + ), thread_delta = await asyncio.get_event_loop().run_in_executor( + None, clocked_env_step, action + ) self.env_time += thread_delta self.agent.update_from_env(obs, rew, done, info)