diff --git a/README.md b/README.md index cd49db0a8..5bf3644fe 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,7 @@ Log File: /home/user/Code/aiperf/artifacts/granite4:350m-openai-chat-concurrency ### Workloads and Data - [Trace Benchmarking](docs/benchmark_modes/trace_replay.md) - Deterministic workload replay - [Bailian Traces](docs/tutorials/bailian-trace.md) - Bailian production trace replay +- [Conflux Traces](docs/tutorials/conflux-trace.md) - Replay AI coding assistant sessions (agents + subagents) - [Custom Prompt Benchmarking](docs/tutorials/custom-prompt-benchmarking.md) - Send exact prompts as-is - [Custom Dataset](docs/tutorials/custom-dataset.md) - Custom dataset formats - [ShareGPT Dataset](docs/tutorials/sharegpt.md) - Profile with ShareGPT dataset diff --git a/docs/cli-options.md b/docs/cli-options.md index de1f8ec1f..3ed63d0e2 100644 --- a/docs/cli-options.md +++ b/docs/cli-options.md @@ -14,7 +14,7 @@ Install shell completion for this application. ### [`analyze-trace`](#aiperf-analyze-trace) -Analyze a mooncake trace file for ISL/OSL distributions and cache hit rates. +Analyze a trace file or directory for distributions and statistics. ### [`profile`](#aiperf-profile) @@ -54,15 +54,17 @@ Output path for the completion script. If not specified, uses shell-specific def ## `aiperf analyze-trace` -Analyze a mooncake trace file for ISL/OSL distributions and cache hit rates. +Analyze a trace file or directory for distributions and statistics. + +Auto-detects the format: - Conflux JSON (file or directory of files): conversation structure, token distributions, timing - JSONL traces (Mooncake/Bailian): ISL/OSL distributions, prefix cache hit rates #### `--input-file` `` _(Required)_ -Path to input mooncake trace JSONL file. +Path to trace file or directory. #### `--block-size` `` -KV cache block size for analysis (default: 512). +KV cache block size for JSONL prefix analysis (default: 512).
_Default: `512`_ #### `--output-file` `` @@ -223,6 +225,11 @@ Start offset in milliseconds for fixed schedule replay. Skips all requests befor End offset in milliseconds for fixed schedule replay. Stops issuing requests after this timestamp, allowing benchmark of specific trace subsets. Requests at exactly the end offset are included. Defaults to last timestamp in dataset. Must be ≥ `--fixed-schedule-start-offset` if both specified.
_Constraints: ≥ 0_ +#### `--fixed-schedule-speedup` `` + +Scaling factor for fixed schedule timestamps. A value of 2.0 replays the schedule twice as fast (halving inter-request delays), while 0.5 replays at half speed (doubling delays). Applied at the timing layer to any dataset using `--fixed-schedule`. +
_Constraints: > 0_ + #### `--public-dataset` `` Pre-configured public dataset to download and use for benchmarking (e.g., `sharegpt`). AIPerf automatically downloads and parses these datasets. Mutually exclusive with `--custom-dataset-type`. Run `aiperf plugins public_dataset_loader` to list available datasets. @@ -230,8 +237,8 @@ Pre-configured public dataset to download and use for benchmarking (e.g., `share #### `--custom-dataset-type` `` -Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts; when using `random_pool`, `--conversation-num` defaults to 100 if not specified; batch sizes > 1 sample each modality independently from a flat pool and do not preserve per-entry associations — use `single_turn` if paired modalities must stay together). Requires `--input-file`. Mutually exclusive with `--public-dataset`. -
_Choices: [`bailian_trace`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ +Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace`/`bailian_trace` (timestamped trace files), `conflux` (Conflux proxy capture with agent threading and subagent orchestration), `random_pool` (directory of reusable prompts; when using `random_pool`, `--conversation-num` defaults to 100 if not specified; batch sizes > 1 sample each modality independently from a flat pool and do not preserve per-entry associations — use `single_turn` if paired modalities must stay together). If omitted, the format is auto-detected from file contents. Requires `--input-file`. Mutually exclusive with `--public-dataset`. +
_Choices: [`bailian_trace`, `conflux`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ #### `--dataset-sampling-strategy` `` @@ -246,6 +253,11 @@ Random seed for deterministic data generation. When set, makes synthetic prompts Specify service level objectives (SLOs) for goodput as space-separated 'KEY:VALUE' pairs, where KEY is a metric tag and VALUE is a number in the metric's display unit (falls back to its base unit if no display unit is defined). Examples: 'request_latency:250' (ms), 'inter_token_latency:10' (ms), `output_token_throughput_per_user:600` (tokens/s). Only metrics applicable to the current endpoint/config are considered. For more context on the definition of goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 and the blog: https://hao-ai-lab.github.io/blogs/distserve. +#### `--conflux-include-utility-calls` + +Include unattributed utility calls when loading Conflux proxy captures. These are lightweight model calls made by the client for housekeeping tasks (topic detection, title generation) that lack an agent_id and may fall outside the main session timeline. Applies when Conflux format is specified via --custom-dataset-type conflux or auto-detected from file contents. +
_Flag (no value required)_ + ### Audio Input #### `--audio-batch-size`, `--batch-size-audio` `` diff --git a/docs/tutorials/conflux-trace.md b/docs/tutorials/conflux-trace.md new file mode 100644 index 000000000..5b2b04365 --- /dev/null +++ b/docs/tutorials/conflux-trace.md @@ -0,0 +1,291 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +sidebar-title: Replaying AI Coding Sessions +--- + +# Replaying AI Coding Sessions with Conflux Traces + +Benchmark your LLM inference server using real-world traffic captured from AI coding assistants like Claude Code and OpenAI Codex. Conflux trace files record every API call made during a coding session, including the full conversation history, tool definitions, model parameters, and timing data. AIPerf replays these traces against your server to measure how it handles realistic agentic workloads. + +--- + +## What Is a Conflux Trace? + +When you use an AI coding assistant, it makes many API calls behind the scenes. A single user request ("fix this bug") can trigger a chain of calls: the main agent reasons about the problem, spawns subagents to search code or run tests, and each of those subagents makes its own API calls. A Conflux trace captures all of this activity. + +Each trace file is a JSON array where every element represents one API call. The key fields are: + +| Field | What it tells AIPerf | +|---|---| +| `messages` | The full conversation sent to the model (system prompt, user messages, tool results) | +| `tools` | Tool definitions available to the model (file search, code execution, etc.) | +| `model` | Which model was used (e.g. 
`claude-opus-4-6`, `gpt-4o`) | +| `agent_id` | Which agent thread made this call | +| `is_subagent` | Whether this was a background subagent or the main agent | +| `timestamp` | When the call was made (used for replay timing) | +| `tokens` | How many tokens were used (input, output, cached, reasoning) | +| `hyperparameters` | Generation settings (temperature, top_p, max_tokens, etc.) | + +--- + +## How Agents and Subagents Work + +A coding session typically has a tree-shaped structure: + +```mermaid +graph TD + P[Parent Agent] --> T1[Turn 1: User asks 'refactor the auth module'] + T1 --> T2[Turn 2: Agent reads files, plans approach] + T2 --> T3[Turn 3: Agent writes new code] + T3 --> T4[Turn 4: Agent summarizes changes] + T2 --> SA[Subagent A: search auth references] + T2 --> SB[Subagent B: run existing tests] + T3 --> SC[Subagent C: run updated tests] +``` + +The **parent agent** is the main conversation thread that the user interacts with. It has multiple turns, each representing one API call in sequence. + +**Subagents** are background tasks spawned by the parent. When the coding assistant needs to search files or run a command, it often launches a separate agent to handle that work in parallel. Each subagent has its own conversation thread with its own API calls. + +AIPerf preserves this structure during replay: + +1. Each agent thread becomes a separate **conversation** with its own sequence of turns +2. Subagent conversations are linked to the parent via **spawn points** and, for blocking work, a later **join point** where the parent consumes the child result +3. Turns within each conversation replay in order with the original timing gaps between them + +There are also **utility calls** -- lightweight API calls for housekeeping tasks like generating conversation titles or detecting topics. These lack an `agent_id` and are excluded by default since they happen outside the main coding workflow. 
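The thread-grouping rule described above can be sketched in a few lines. This is an illustrative simplification, not AIPerf's actual loader; it assumes records shaped like the fields in the table above (`agent_id`, `timestamp`):

```python
from collections import defaultdict


def group_agent_threads(
    records: list[dict],
) -> tuple[dict[str, list[dict]], list[dict]]:
    """Group flat trace records into per-agent conversation threads.

    Records sharing an agent_id form one thread (one replayed
    conversation); records without an agent_id are utility calls,
    which are excluded from the main workload by default.
    """
    threads: dict[str, list[dict]] = defaultdict(list)
    utility_calls: list[dict] = []
    for record in records:
        if record.get("agent_id") is None:
            utility_calls.append(record)  # e.g. title generation, topic detection
        else:
            threads[record["agent_id"]].append(record)
    for turns in threads.values():
        # Turns replay in their original order within each thread
        turns.sort(key=lambda r: r["timestamp"])
    return dict(threads), utility_calls
```

Each resulting thread then becomes one conversation, with spawn/join links computed between parent and child threads.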
+ +--- + +## Trace File Format + +A minimal Conflux trace looks like this: + +```json +[ + { + "id": "req_abc123", + "session_id": "sess_001", + "agent_id": "agent_main", + "is_subagent": false, + "timestamp": "2026-03-15T10:00:00Z", + "model": "claude-opus-4-6", + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Fix the login bug in auth.py"} + ], + "tools": [ + {"type": "function", "function": {"name": "read_file", "parameters": {}}} + ], + "tokens": {"input": 1500, "output": 800, "input_cached": 200}, + "duration_ms": 3200, + "hyperparameters": {"temperature": 1.0, "max_tokens": 4096} + }, + { + "id": "req_def456", + "session_id": "sess_001", + "agent_id": "agent_search_1", + "is_subagent": true, + "timestamp": "2026-03-15T10:00:01.500Z", + "model": "claude-opus-4-6", + "messages": [ + {"role": "user", "content": "Search for all files importing auth module"} + ], + "tokens": {"input": 500, "output": 300}, + "duration_ms": 1100 + } +] +``` + +The first record is a parent agent turn. The second is a subagent that was spawned 1.5 seconds into the parent's first turn. AIPerf detects this overlap automatically and links them together. + +For highest fidelity, traces can include a `base64` field containing the raw request body. When present, AIPerf uses it instead of the top-level `messages` and `tools` fields, preserving every detail of the original API call exactly as it was sent. + +--- + +## Getting Started + +### Auto-Detection + +AIPerf automatically detects Conflux format when your JSON file contains records with `messages` and either `agent_id` + `is_subagent` fields or `source: "proxy"`. 
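That detection rule can be sketched as a small predicate. This is illustrative only: it reads the whole file for simplicity, whereas AIPerf probes a bounded prefix, and the real check may inspect more than the first record:

```python
import json


def looks_like_conflux(path: str) -> bool:
    """Heuristic sketch of Conflux auto-detection.

    A file qualifies if its first record has `messages` plus either
    both `agent_id` and `is_subagent`, or `source: "proxy"`.
    """
    with open(path, "rb") as f:
        data = json.load(f)
    records = data if isinstance(data, list) else [data]
    if not records or not isinstance(records[0], dict):
        return False
    first = records[0]
    if "messages" not in first:
        return False
    has_agent_fields = "agent_id" in first and "is_subagent" in first
    return has_agent_fields or first.get("source") == "proxy"
```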
No `--custom-dataset-type` flag is needed:
+
+```bash
+aiperf profile \
+  --url localhost:8000 \
+  --model your-model-name \
+  --endpoint-type chat \
+  --streaming \
+  --input-file my-session.json \
+  --fixed-schedule
+```
+
+You can also be explicit:
+
+```bash
+aiperf profile \
+  --url localhost:8000 \
+  --model your-model-name \
+  --endpoint-type chat \
+  --streaming \
+  --input-file my-session.json \
+  --custom-dataset-type conflux \
+  --fixed-schedule
+```
+
+Use `--fixed-schedule` to replay requests at their original timestamps. It is enabled automatically when you specify `--custom-dataset-type conflux`, but passing it explicitly is recommended when relying on auto-detection.
+
+### Speed Up or Slow Down Replay
+
+Real coding sessions can last minutes or hours. Use `--fixed-schedule-speedup` to compress or expand the timeline:
+
+```bash
+# Replay at 10x speed (a 10-minute session completes in 1 minute)
+aiperf profile \
+  --url localhost:8000 \
+  --model your-model-name \
+  --endpoint-type chat \
+  --streaming \
+  --input-file my-session.json \
+  --fixed-schedule \
+  --fixed-schedule-speedup 10.0
+```
+
+A value of `2.0` replays at double speed; `0.5` replays at half speed (useful for stress testing at lower rates). This works with any dataset that uses `--fixed-schedule`, not just Conflux traces.
+
+### Include Utility Calls
+
+By default, AIPerf skips utility calls (API requests without an `agent_id`, typically used for title generation or topic detection).
To include them: + +```bash +aiperf profile \ + --url localhost:8000 \ + --model your-model-name \ + --endpoint-type chat \ + --streaming \ + --input-file my-session.json \ + --fixed-schedule \ + --conflux-include-utility-calls +``` + +--- + +## What Gets Replayed + +For each API call in the trace, AIPerf sends: + +- **Messages**: The full conversation history (system prompt, user messages, assistant responses, tool results) normalized to OpenAI-compatible format +- **Tools**: All tool definitions that were available to the model +- **Model**: The model identifier from the trace (override with `--model` if targeting a different model) +- **Hyperparameters**: Per-turn settings like temperature, top_p, and reasoning effort, sent as request parameters +- **Timing**: Original inter-request delays preserved via `--fixed-schedule`, including later subagent joins when a child result is consumed several parent turns after it was spawned +- **Max tokens**: The generation limit from the original request + +AIPerf also records **ground truth** metadata from the trace (original token counts, TTFT, duration) so you can compare your server's performance against the captured baseline. + +--- + +## Spawn and Join Patterns + +Across real Claude Code sessions, two distinct subagent coordination patterns appear: + +### Blocking immediate (gap = 1 parent turn) + +The parent spawns agents via the `Agent` tool and the results come back inline — the tool_result containing the child's output is already present in the cumulative message history by the time the parent makes its next API call. AIPerf detects this by finding the first parent turn whose new messages contain a non-acknowledgement tool_result for the spawn's `Agent` tool_use ID. + +```text +Parent turn N: Agent tool_use ──► child runs + tool_result: "child output here" +Parent turn N+1: JOIN — reads child result, continues +``` + +This is the most common pattern. 
A child's wall-clock runtime can exceed the single-turn gap; when the join still lands at N+1, either the child was fast enough to finish in time or the parent read an early partial result. The join turn is N+1 regardless of the child's wall-clock completion time.
+
+### Blocking notification (gap = 1 to 183 parent turns)
+
+The parent spawns agents asynchronously (`"Async agent launched successfully"` acknowledgement) and continues doing other work. When each child finishes, Claude Code injects a notification user message into the parent's context:
+
+```xml
+
+  ae2c5aa3243b040ea
+  toolu_019BxppEYd3m7AaX5Hc8BLJq
+  /tmp/.../tasks/ae2c5aa3243b040ea.output
+
+```
+
+The `tool-use-id` matches the original `Agent` tool_use, allowing AIPerf to identify which spawn each notification belongs to. Each child gets its own join turn: the first parent turn where its notification appears.
+
+Two sub-cases observed in the wild:
+
+- **Parent idle** (e.g. pr-review): the parent had nothing else to do, so each child's completion immediately triggered a new parent turn (~135 ms gap). Four children → four separate join turns spaced 6–9 parent turns after spawn.
+- **Parent busy** (e.g. long-horizon): the parent continued active work (182 turns of code porting across files) while children ran in parallel. Notifications all arrived at the final summary turn once the parent finished its own work.
+
+### Join gate behavior during replay
+
+When replaying, AIPerf fires the join turn at:
+
+```text
+join_dispatch = max(last_child_end, ideal_join_timestamp) + ~few ms
+```
+
+Whether the gate or the original schedule is the binding constraint depends on how
+fast the server responds relative to the inter-turn gaps.
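As a concrete sketch of the dispatch rule (a hypothetical helper; the epsilon value is an assumption, since the formula only specifies "a few ms"):

```python
def join_dispatch_ms(
    last_child_end_ms: float,
    ideal_join_ms: float,
    epsilon_ms: float = 5.0,
) -> float:
    """Replay-clock time at which a gated join turn fires.

    If children outlive the original schedule (slow server), the gate
    binds; if they finish early (fast server), the schedule binds.
    """
    return max(last_child_end_ms, ideal_join_ms) + epsilon_ms
```

For example, with an ideal join at 10,000 ms, a child finishing at 12,000 ms makes the join gate-bound, while a child finishing at 8,000 ms leaves it schedule-bound.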
+ +```text +GATE-bound (slow server: response time > inter-turn gap, children can't keep up) + + PARENT ──[spawn]──────────────────────────────[gate: waiting]──[JOIN]──▶ + │ ↑ + CHILD └──[=======================================done]──────┘ + ↑ + ideal join timestamp ──────┘ (already passed, irrelevant) + + +SCHED-bound (fast server: response time < inter-turn gap, children finish early) + + PARENT ──[spawn]──────────────────[sched wait]──[JOIN]──▶ + │ ↑ + CHILD └──[================done]........│ + child early, │ + wait for schedule │ + │ + ideal join timestamp ─────────────┘ +``` + +With a slow server, children cannot finish within the inter-turn gaps and the gate +drives the join. With a fast server, children finish with time to spare and the join +waits for the original timestamp. The crossover point is when +`server_response_time ≈ original_inter_turn_gap / speedup`. + +### Why timestamps alone are insufficient + +For idle-parent notification joins, timestamp detection (first parent turn starting after child ends) gives the correct answer since the gap is ~135 ms. For busy-parent notification joins it gives the wrong turn — the parent was actively running and the notification arrived much later. AIPerf uses `` content matching as the authoritative signal. + +--- + +## Understanding the Output + +When AIPerf loads a Conflux trace, it reports the conversation structure: + +```text +Loaded 3 agent threads + 5 utility calls skipped (28 total records) +Converted 3 conversations (28 total turns, 2 subagent children incl. 
0 orphans) +``` + +This tells you: +- **3 agent threads**: one parent + two subagents +- **5 utility calls skipped**: housekeeping calls excluded (use `--conflux-include-utility-calls` to include them) +- **3 conversations**: the parent and its two children, each with their own turn sequence + +After the benchmark completes, the standard AIPerf metrics apply: throughput, latency percentiles, time-to-first-token, and inter-token latency, all measured under the realistic traffic pattern from the original session. + +--- + +## Related Tutorials + +- [Profile with Bailian Traces](bailian-trace.md) - Replay Alibaba production traces +- [Trace Replay with Mooncake Traces](../benchmark-modes/trace-replay.md) - Mooncake FAST'25 trace replay +- [Fixed Schedule](fixed-schedule.md) - Precise timestamp-based execution for any dataset +- [Multi-Turn Conversations](multi-turn.md) - Multi-turn conversation benchmarking diff --git a/fern/versions/dev.yml b/fern/versions/dev.yml index a83999381..8861eab9d 100644 --- a/fern/versions/dev.yml +++ b/fern/versions/dev.yml @@ -56,6 +56,8 @@ navigation: path: ../../docs/tutorials/sequence-distributions.md - page: Prefix Data Synthesis Tutorial path: ../../docs/tutorials/prefix-synthesis.md + - page: Replaying AI Coding Sessions with Conflux Traces + path: ../../docs/tutorials/conflux-trace.md - section: Load Patterns & Scheduling collapsed: true contents: diff --git a/src/aiperf/cli_commands/analyze_trace.py b/src/aiperf/cli_commands/analyze_trace.py index 70330d558..78562d375 100644 --- a/src/aiperf/cli_commands/analyze_trace.py +++ b/src/aiperf/cli_commands/analyze_trace.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -"""CLI command for analyzing mooncake traces.""" +"""CLI command for analyzing trace datasets.""" from __future__ import annotations @@ -17,11 +17,15 @@ def analyze_trace( block_size: int = 512, output_file: Path | None = None, ) -> None: - """Analyze a mooncake trace file for ISL/OSL distributions and cache hit rates. + """Analyze a trace file or directory for distributions and statistics. + + Auto-detects the format: + - Conflux JSON (file or directory of files): conversation structure, token distributions, timing + - JSONL traces (Mooncake/Bailian): ISL/OSL distributions, prefix cache hit rates Args: - input_file: Path to input mooncake trace JSONL file - block_size: KV cache block size for analysis (default: 512) + input_file: Path to trace file or directory + block_size: KV cache block size for JSONL prefix analysis (default: 512) output_file: Optional output path for analysis report (JSON) """ from aiperf.dataset.synthesis.cli import analyze_trace as _analyze_trace diff --git a/src/aiperf/common/config/config_defaults.py b/src/aiperf/common/config/config_defaults.py index 6eb042c7a..6363e12f5 100644 --- a/src/aiperf/common/config/config_defaults.py +++ b/src/aiperf/common/config/config_defaults.py @@ -59,6 +59,7 @@ class InputDefaults: FIXED_SCHEDULE_AUTO_OFFSET = False FIXED_SCHEDULE_START_OFFSET = None FIXED_SCHEDULE_END_OFFSET = None + FIXED_SCHEDULE_SPEEDUP = None GOODPUT = None PUBLIC_DATASET = None CUSTOM_DATASET_TYPE = None diff --git a/src/aiperf/common/config/input_config.py b/src/aiperf/common/config/input_config.py index 9fdf38bf9..550f1d27c 100644 --- a/src/aiperf/common/config/input_config.py +++ b/src/aiperf/common/config/input_config.py @@ -88,6 +88,19 @@ def validate_fixed_schedule_start_and_end_offset(self) -> Self: ) return self + @model_validator(mode="after") + def validate_no_double_speedup(self) -> Self: + """Reject combining synthesis speedup with fixed-schedule speedup.""" + if ( + 
self.synthesis.speedup_ratio != 1.0 + and self.fixed_schedule_speedup is not None + ): + raise ValueError( + "--synthesis-speedup-ratio and --fixed-schedule-speedup cannot be used together. " + "Use --fixed-schedule-speedup to control replay speed at the timing layer." + ) + return self + @model_validator(mode="after") def validate_dataset_type(self) -> Self: """Validate the different dataset type configuration.""" @@ -260,6 +273,20 @@ def validate_goodput(self) -> Self: ), ] = InputDefaults.FIXED_SCHEDULE_END_OFFSET + fixed_schedule_speedup: Annotated[ + float | None, + Field( + gt=0, + description="Scaling factor for fixed schedule timestamps. A value of 2.0 replays the schedule twice as fast " + "(halving inter-request delays), while 0.5 replays at half speed (doubling delays). " + "Applied at the timing layer to any dataset using `--fixed-schedule`.", + ), + CLIParameter( + name=("--fixed-schedule-speedup",), + group=_CLI_GROUP, + ), + ] = InputDefaults.FIXED_SCHEDULE_SPEEDUP + public_dataset: Annotated[ PublicDatasetType | None, Field( @@ -278,10 +305,12 @@ def validate_goodput(self) -> Self: Field( description="Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. " "Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), " - "`mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts; " + "`mooncake_trace`/`bailian_trace` (timestamped trace files), `conflux` (Conflux proxy capture with " + "agent threading and subagent orchestration), `random_pool` (directory of reusable prompts; " "when using `random_pool`, `--conversation-num` defaults to 100 if not specified; " "batch sizes > 1 sample each modality independently from a flat pool and do not preserve " "per-entry associations — use `single_turn` if paired modalities must stay together). 
" + "If omitted, the format is auto-detected from file contents. " "Requires `--input-file`. Mutually exclusive with `--public-dataset`.", ), CLIParameter( @@ -341,6 +370,21 @@ def validate_goodput(self) -> Self: ), ] = InputDefaults.GOODPUT + conflux_include_utility_calls: Annotated[ + bool, + Field( + description="Include unattributed utility calls when loading Conflux proxy captures. " + "These are lightweight model calls made by the client for housekeeping tasks " + "(topic detection, title generation) that lack an agent_id and may fall outside " + "the main session timeline. Applies when Conflux format is specified via " + "--custom-dataset-type conflux or auto-detected from file contents.", + ), + CLIParameter( + name=("--conflux-include-utility-calls",), + group=_CLI_GROUP, + ), + ] = False + audio: AudioConfig = AudioConfig() image: ImageConfig = ImageConfig() video: VideoConfig = VideoConfig() diff --git a/src/aiperf/common/config/user_config.py b/src/aiperf/common/config/user_config.py index 040e0b548..b0906813a 100644 --- a/src/aiperf/common/config/user_config.py +++ b/src/aiperf/common/config/user_config.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from aiperf.plugin.schema.schemas import EndpointMetadata +import orjson from orjson import JSONDecodeError from pydantic import BeforeValidator, Field, model_validator from typing_extensions import Self @@ -28,7 +29,6 @@ from aiperf.common.config.output_config import OutputConfig from aiperf.common.config.tokenizer_config import TokenizerConfig from aiperf.common.enums import GPUTelemetryMode, ServerMetricsFormat -from aiperf.common.utils import load_json_str from aiperf.plugin import plugins from aiperf.plugin.enums import ( ArrivalPattern, @@ -118,7 +118,7 @@ def validate_timing_mode(self) -> Self: _logger.info( f"No request count value provided for fixed schedule mode, setting to dataset entry count: {self.loadgen.request_count}" ) - elif self._should_use_fixed_schedule_for_trace_dataset(): + elif 
self._should_auto_enable_fixed_schedule(): self._timing_mode = TimingMode.FIXED_SCHEDULE _logger.info( f"Automatically enabling fixed schedule mode for {self.input.custom_dataset_type} dataset with timestamps" @@ -348,13 +348,17 @@ def validate_unused_options(self) -> Self: return self - def _should_use_fixed_schedule_for_trace_dataset(self) -> bool: - """Check if a trace dataset has timestamps and should use fixed schedule. + def _should_auto_enable_fixed_schedule(self) -> bool: + """Check if a dataset has timestamps and should use fixed schedule. + + Returns True for any custom dataset loader whose plugin metadata + declares ``supports_timing: true``, provided the actual file + contains timestamp data. Returns: - True if fixed schedule should be enabled for this trace dataset. + True if fixed schedule should be enabled for this dataset. """ - if self.input.custom_dataset_type is None or not plugins.is_trace_dataset( + if self.input.custom_dataset_type is None or not plugins.supports_timing( self.input.custom_dataset_type ): return False @@ -363,21 +367,58 @@ def _should_use_fixed_schedule_for_trace_dataset(self) -> bool: return False try: - with open(self.input.file) as f: - for line in f: - if not (line := line.strip()): - continue - try: - data = load_json_str(line) - return "timestamp" in data and data["timestamp"] is not None - except (JSONDecodeError, KeyError): - continue + first_record = self._read_first_record(self.input.file) + return ( + first_record is not None and first_record.get("timestamp") is not None + ) except (OSError, FileNotFoundError): _logger.warning( f"Could not read dataset file {self.input.file} to check for timestamps" ) + return False + + @staticmethod + def _read_first_record(file_path: str) -> dict[str, Any] | None: + """Read the first JSON record from a dataset file. + + Uses file extension to pick the parsing strategy: + ``.json`` files read a bounded probe (1 MB) to avoid loading + multi-GB files into memory. 
``.jsonl`` files are parsed line-by-line. + + For directories, reads the first .json file found (sorted by name). + """ + _PROBE_BYTES = 1 << 20 # 1 MB + path = Path(file_path) + if path.is_dir(): + json_files = sorted(path.glob("*.json")) + if not json_files: + return None + path = json_files[0] + with open(path, "rb") as f: + if path.suffix == ".json": + probe = f.read(_PROBE_BYTES) + try: + data = orjson.loads(probe) + except JSONDecodeError: + return None + if isinstance(data, dict): + return data + if isinstance(data, list) and data and isinstance(data[0], dict): + return data[0] + return None + + # JSONL: first non-empty line + for line in f: + if not (line := line.strip()): + continue + try: + record = orjson.loads(line) + if isinstance(record, dict): + return record + except JSONDecodeError: + continue - return False + return None def _count_dataset_entries(self) -> int: """Count the number of valid entries in a custom dataset file. diff --git a/src/aiperf/common/enums/__init__.py b/src/aiperf/common/enums/__init__.py index a1305b0a9..b5383463d 100644 --- a/src/aiperf/common/enums/__init__.py +++ b/src/aiperf/common/enums/__init__.py @@ -25,6 +25,7 @@ MediaType, MessageType, ModelSelectionStrategy, + PrerequisiteKind, PrometheusMetricType, PromptSource, ServerMetricsFormat, @@ -110,6 +111,7 @@ "MetricValueTypeT", "MetricValueTypeVarT", "ModelSelectionStrategy", + "PrerequisiteKind", "PlotMetricDirection", "PowerMetricUnit", "PowerMetricUnitInfo", diff --git a/src/aiperf/common/enums/enums.py b/src/aiperf/common/enums/enums.py index bbe0fe4cc..1b92eb81d 100644 --- a/src/aiperf/common/enums/enums.py +++ b/src/aiperf/common/enums/enums.py @@ -96,6 +96,13 @@ class CommandResponseStatus(CaseInsensitiveStrEnum): UNHANDLED = "unhandled" # The command was received but not handled by any hook +class PrerequisiteKind(CaseInsensitiveStrEnum): + """Type of prerequisite that gates a turn from dispatching.""" + + SPAWN_JOIN = "spawn_join" + """All blocking children from 
a spawn must complete before the gated turn dispatches.""" + + class ConversationContextMode(CaseInsensitiveStrEnum): """Controls how prior turns are accumulated in multi-turn conversations. diff --git a/src/aiperf/common/models/__init__.py b/src/aiperf/common/models/__init__.py index 7da097643..849578c9a 100644 --- a/src/aiperf/common/models/__init__.py +++ b/src/aiperf/common/models/__init__.py @@ -13,6 +13,7 @@ Audio, Conversation, ConversationMetadata, + ConversationOrigin, DatasetClientMetadata, DatasetMetadata, Image, @@ -20,9 +21,12 @@ Media, MemoryMapClientMetadata, SessionPayloads, + SubagentSpawnInfo, Text, Turn, + TurnGroundTruth, TurnMetadata, + TurnPrerequisite, Video, ) from aiperf.common.models.error_models import ( @@ -151,6 +155,7 @@ "CPUTimes", "Conversation", "ConversationMetadata", + "ConversationOrigin", "CounterMetricData", "CounterSeries", "CounterStats", @@ -227,6 +232,7 @@ "ServiceRunInfo", "SessionPayloads", "SlimRecord", + "SubagentSpawnInfo", "TelemetryExportData", "TelemetryHierarchy", "TelemetryMetrics", @@ -241,7 +247,9 @@ "TokenCounts", "TraceDataExport", "Turn", + "TurnGroundTruth", "TurnMetadata", + "TurnPrerequisite", "Usage", "Video", "VideoResponseData", diff --git a/src/aiperf/common/models/dataset_models.py b/src/aiperf/common/models/dataset_models.py index 42b6ad4ca..94802a6c7 100644 --- a/src/aiperf/common/models/dataset_models.py +++ b/src/aiperf/common/models/dataset_models.py @@ -7,7 +7,7 @@ from pydantic import Field, field_validator -from aiperf.common.enums import ConversationContextMode, MediaType +from aiperf.common.enums import ConversationContextMode, MediaType, PrerequisiteKind from aiperf.common.models.base_models import AIPerfBaseModel from aiperf.common.types import MediaTypeT from aiperf.plugin.enums import DatasetClientStoreType, DatasetSamplingStrategy @@ -103,6 +103,23 @@ class Video(Media): media_type: ClassVar[MediaTypeT] = MediaType.VIDEO +class TurnPrerequisite(AIPerfBaseModel): + """A condition that 
must be satisfied before a turn dispatches. + + Used by the SubagentOrchestrator to gate turn dispatch on prerequisite + completion. Currently supports 'spawn_join' (all blocking children from + a spawn must complete). Extensible to other gate types. + """ + + kind: PrerequisiteKind = Field( + description="Prerequisite type.", + ) + spawn_id: str | None = Field( + default=None, + description="For spawn_join: which spawn's children must complete.", + ) + + class TurnMetadata(AIPerfBaseModel): """Metadata of a turn.""" @@ -114,6 +131,49 @@ class TurnMetadata(AIPerfBaseModel): default=None, description="The delay of the turn in the conversation (in milliseconds).", ) + input_tokens: int | None = Field( + default=None, + description="Expected input token count for this turn (from trace data). " + "Can be used for per-period token budget enforcement.", + ) + subagent_spawn_ids: list[str] = Field( + default_factory=list, + description="Spawn IDs if this turn is blocked by subagent spawns.", + ) + prerequisites: list[TurnPrerequisite] = Field( + default_factory=list, + description="Conditions that must be met before this turn dispatches.", + ) + + +class TurnGroundTruth(AIPerfBaseModel): + """Original capture metadata for observability and comparison. + + Stores token breakdowns, timing, and output content from the original + API response. Never sent to the inference API — informational only. + """ + + input_cached_tokens: int | None = Field( + default=None, description="Input tokens served from provider cache." + ) + input_cache_write_tokens: int | None = Field( + default=None, description="Input tokens written to provider cache." + ) + output_tokens: int | None = Field( + default=None, description="Output tokens generated in the original response." + ) + output_reasoning_tokens: int | None = Field( + default=None, description="Output tokens used for reasoning/thinking." 
+ ) + ttft_ms: float | None = Field( + default=None, description="Time to first token in milliseconds." + ) + duration_ms: float | None = Field( + default=None, description="Total request duration in milliseconds." + ) + is_streaming: bool | None = Field( + default=None, description="Whether the original request used streaming." + ) class Turn(AIPerfBaseModel): @@ -158,12 +218,37 @@ class Turn(AIPerfBaseModel): videos: list[Video] = Field( default=[], description="Collection of video data in each turn." ) + input_tokens: int | None = Field( + default=None, + description="Expected input token count for this turn (from trace data).", + ) + subagent_spawn_ids: list[str] = Field( + default_factory=list, + description="Spawn IDs if this turn is blocked by subagent spawns.", + ) + prerequisites: list[TurnPrerequisite] = Field( + default_factory=list, + description="Conditions that must be met before this turn dispatches.", + ) + extra_params: dict[str, Any] | None = Field( + default=None, + description="Per-turn hyperparameter overrides merged into the API payload " + "after format_payload(). Populated from dataset capture metadata.", + ) + ground_truth: TurnGroundTruth | None = Field( + default=None, + description="Original capture metadata (token breakdown, timing, output) " + "for observability. 
Never sent to the inference API.", + ) def metadata(self) -> TurnMetadata: """Get the metadata of the turn.""" return TurnMetadata( timestamp_ms=self.timestamp, delay_ms=self.delay, + input_tokens=self.input_tokens, + subagent_spawn_ids=list(self.subagent_spawn_ids), + prerequisites=list(self.prerequisites), ) def copy_with_stripped_media(self) -> "Turn": @@ -209,9 +294,34 @@ def copy_with_stripped_media(self) -> "Turn": ) for vid in self.videos ], + input_tokens=self.input_tokens, + subagent_spawn_ids=list(self.subagent_spawn_ids), + prerequisites=list(self.prerequisites), + extra_params=dict(self.extra_params) if self.extra_params else None, + ground_truth=self.ground_truth, ) +class SubagentSpawnInfo(AIPerfBaseModel): + """Describes a subagent spawn point linking parent to child conversations. + + When a parent conversation spawns subagents, blocking children must + complete before the gated turn (declared via TurnPrerequisite) dispatches. + Children are separate Conversations with independent hash_ids and sessions. + """ + + spawn_id: str = Field( + description="Subagent spawn identifier, e.g. 's0'.", + ) + child_conversation_ids: list[str] = Field( + description="Conversation IDs of child subagent sessions to start.", + ) + is_background: bool = Field( + default=False, + description="If true, parent continues without waiting for children.", + ) + + class ConversationMetadata(AIPerfBaseModel): """Metadata of a conversation.""" @@ -223,6 +333,18 @@ class ConversationMetadata(AIPerfBaseModel): default_factory=list, description="The metadata of the turns in the conversation.", ) + subagent_spawns: list[SubagentSpawnInfo] = Field( + default_factory=list, + description="Subagent spawn points linking to child conversations.", + ) + agent_depth: int = Field( + default=0, + description="Nesting depth of this conversation. 
0=root, 1=child, 2=grandchild, etc.", + ) + parent_conversation_id: str | None = Field( + default=None, + description="Template conversation_id of the parent conversation. None for root conversations.", + ) class DatasetMetadata(AIPerfBaseModel): @@ -278,6 +400,31 @@ def average_turn_count(self) -> float: return self.total_turn_count / len(self.conversations) +class ConversationOrigin(AIPerfBaseModel): + """Source traceability back to the original capture. + + Stores origin metadata so benchmark results can be linked to + the specific capture session and client that produced the data. + """ + + source: str | None = Field( + default=None, description="Record source: proxy, claude, codex." + ) + client: str | None = Field( + default=None, description="Client that produced this record: claude, codex." + ) + client_version: str | None = Field( + default=None, description="Client version string." + ) + original_session_id: str | None = Field( + default=None, description="Original session identifier from the capture." + ) + original_request_ids: list[str] = Field( + default_factory=list, + description="Provider request identifiers, one per turn.", + ) + + class Conversation(AIPerfBaseModel): """A dataset representation of a full conversation. @@ -319,12 +466,33 @@ def _reject_unimplemented_context_mode( description="Optional per-conversation user context prepended to the first turn. " "Unique for each conversation when using --user-context-prompt-length.", ) + agent_depth: int = Field( + default=0, + description="Nesting depth of this conversation. 0=root, 1=child, 2=grandchild, etc.", + ) + parent_conversation_id: str | None = Field( + default=None, + description="Template session_id of the parent conversation. 
None for root conversations.", + ) + subagent_spawns: list[SubagentSpawnInfo] = Field( + default_factory=list, + description="Subagent spawn points linking to child conversations.", + ) + origin: ConversationOrigin | None = Field( + default=None, + description="Source traceability back to the original capture. " + "Populated by loaders that have origin metadata (e.g. Conflux).", + ) def metadata(self) -> ConversationMetadata: """Get the metadata of the conversation.""" + turn_metas = [turn.metadata() for turn in self.turns] return ConversationMetadata( conversation_id=self.session_id, - turns=[turn.metadata() for turn in self.turns], + turns=turn_metas, + subagent_spawns=self.subagent_spawns, + agent_depth=self.agent_depth, + parent_conversation_id=self.parent_conversation_id, ) diff --git a/src/aiperf/credit/callback_handler.py b/src/aiperf/credit/callback_handler.py index b5d81c974..be53e792c 100644 --- a/src/aiperf/credit/callback_handler.py +++ b/src/aiperf/credit/callback_handler.py @@ -4,11 +4,15 @@ Handles ALL credit lifecycle callbacks (returns + TTFT) directly from CreditRouter. -Key responsibilities: -- Track credit returns (increment_returned, release slots) -- Handle TTFT events (increment_prefill_released, release prefill slot) -- Dispatch next turn to timing strategy (handle_credit_return) -- Cleanup in-flight sessions on phase end +Processing order for credit returns (see SubagentOrchestrator module docstring +for why this ordering is load-bearing for subagent correctness):: + + 1. Atomic counting (increment_returned) + 2. Track prefill release if TTFT never arrived + 3. Release concurrency slots + 4. on_failed_credit for errored/cancelled child gate cleanup + 5. Signal all_credits_returned_event if final return + 6. handle_credit_return → strategy dispatch (with child bypass) """ from __future__ import annotations @@ -54,24 +58,10 @@ class PhaseCallbackContext: class CreditCallbackHandler: """Handles credit lifecycle callbacks from CreditRouter. 
- Unified callback handler for all phases. - Callback flow: - Worker → CreditRouter → CreditCallbackHandler → [count, release slots, dispatch] - - Processing order for credit returns: - 1. Atomic counting (increment_returned) - 2. Track prefill release if TTFT never arrived - 3. Release concurrency slots - 4. Dispatch next turn via timing strategy (if applicable) + Worker -> CreditRouter -> CreditCallbackHandler -> [count, release, dispatch] - Processing order for TTFT: - 1. Track prefill release (increment_prefill_released) - 2. Release prefill slot - - Phase Registration: - PhaseRunner calls register_phase() BEFORE any credits are sent. - This ensures callbacks work from the first credit. + PhaseRunner calls register_phase() BEFORE any credits are sent. """ def __init__(self, concurrency_manager: ConcurrencyManager) -> None: @@ -128,18 +118,7 @@ def unregister_phase(self, phase: CreditPhase) -> None: async def on_credit_return( self, worker_id: str, credit_return: CreditReturn ) -> None: - """Handle credit return from worker. - - Processing order: - 1. Atomic counting (increment_returned) - 2. Track prefill release if TTFT never arrived - 3. Release concurrency slots - 4. Dispatch next turn via strategy (if applicable) - - Args: - worker_id: ID of the worker returning the credit. - credit_return: Return details including credit and status. - """ + """Handle credit return from worker. See module docstring for step ordering.""" credit = credit_return.credit phase = credit.phase @@ -164,6 +143,7 @@ async def on_credit_return( is_final_returned = handler.progress.increment_returned( credit.is_final_turn, credit_return.cancelled, + agent_depth=credit.agent_depth, ) # 2. Track prefill release if TTFT never arrived @@ -175,14 +155,25 @@ async def on_credit_return( phase, credit, credit_return, is_final_returned, handler ) - # 4. Signal completion if this was the final return + # 4. on_failed_credit for errored/cancelled child gate cleanup. 
+ # ORDER MATTERS: Must run BEFORE handle_credit_return (step 6) so + # terminate_child marks the child before _handle_child_credit checks + # _is_terminated. Reversing steps 4 and 6 causes zombie child dispatch. + # Not gated by can_send_any_turn -- this is bookkeeping, not new work. + if credit_return.error or credit_return.cancelled: + handler.strategy.on_failed_credit(credit_return) + + # 5. Signal all_credits_returned_event if final return if is_final_returned: handler.progress.all_credits_returned_event.set() - # 5. Notify timing strategy for subsequent turns when phase can still send - # Timing strategy queues subsequent turns for rate-limited issuance. - # Skipped when phase can't send - if handler.stop_checker.can_send_any_turn(): + # 6. handle_credit_return with child bypass. + # Child returns (depth > 0) MUST always reach the orchestrator for gate + # accounting, even after stop fires. Without this bypass, child final + # returns are silently dropped, leaving parent gates permanently stuck. + # The orchestrator has its own guards against post-stop dispatch + # (see SubagentOrchestrator module docstring "Stop Condition Interaction"). + if handler.stop_checker.can_send_any_turn() or credit.agent_depth > 0: await handler.strategy.handle_credit_return(credit) def _release_slots_for_return( @@ -209,8 +200,9 @@ def _release_slots_for_return( """ concurrency = handler.concurrency_manager - # Release session slot when conversation ends (final turn, whether completed or cancelled) - if credit.is_final_turn: + # Release session slot when root conversation ends (final turn, whether completed or cancelled). + # Child sessions (depth > 0) never acquired a session slot, so skip release. + if credit.is_final_turn and credit.agent_depth == 0: concurrency.release_session_slot(phase) # On phase end, release slots for sessions still in flight. 
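The step-4-before-step-6 constraint above can be sketched in isolation. This is a hypothetical mini-model, not the repo's API: `SpawnGate`, `terminate_child`, and `on_child_return` are invented names standing in for the orchestrator's gate accounting. It shows why cleanup for an errored/cancelled child must land before the child's return reaches the dispatch path, since the dispatch path's guard relies on seeing the termination mark:

```python
class SpawnGate:
    """Toy gate: a parent turn is blocked until all blocking children finish."""

    def __init__(self, child_ids):
        self.pending = set(child_ids)
        self.terminated = set()

    def terminate_child(self, child_id):
        # Step 4 analogue: on_failed_credit marks the child terminated.
        self.terminated.add(child_id)
        self.pending.discard(child_id)

    def on_child_return(self, child_id):
        # Step 6 analogue: handle_credit_return checks the termination mark.
        if child_id in self.terminated:
            return "skip"  # guard fires only if step 4 already ran
        self.pending.discard(child_id)
        return "dispatch-next" if not self.pending else "wait"


gate = SpawnGate(["c1", "c2"])
gate.terminate_child("c1")            # errored child cleaned up first (step 4)...
result = gate.on_child_return("c1")   # ...so its return is a no-op (step 6)
```

Running the steps in the reverse order would let `on_child_return("c1")` count the failed child as normal progress, which is the "zombie child dispatch" the comment warns about.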
diff --git a/src/aiperf/credit/issuer.py b/src/aiperf/credit/issuer.py index 811695d3f..56be21a52 100644 --- a/src/aiperf/credit/issuer.py +++ b/src/aiperf/credit/issuer.py @@ -114,19 +114,22 @@ async def issue_credit(self, turn: TurnToSend) -> bool: 6. If final credit: freeze counts + set event """ is_first_turn = turn.turn_index == 0 + is_child = turn.agent_depth > 0 - # Select appropriate check function based on turn type - # - First turns need can_start_new_session (more restrictive - checks session quota) - # - Subsequent turns use can_send_any_turn (less restrictive - allows finishing existing sessions) + # Select appropriate check function based on turn type. + # - Root first turns need can_start_new_session (checks session quota) + # - Child turns use can_send_any_turn (continuation work, not new user sessions) + # - Root subsequent turns use can_send_any_turn (finishing existing sessions) can_proceed_fn = ( self._stop_checker.can_start_new_session - if is_first_turn + if is_first_turn and not is_child else self._stop_checker.can_send_any_turn ) - # Session concurrency: one slot per conversation, acquired on first turn only. - # Controls how many multi-turn conversations can be active simultaneously. - if is_first_turn: + # Session concurrency: one slot per root conversation, acquired on first turn only. + # Children skip session slot acquisition to avoid contention with root sessions. 
+ needs_session_slot = is_first_turn and not is_child + if needs_session_slot: acquired = await self._concurrency_manager.acquire_session_slot( self._phase, self._stop_checker.can_start_new_session ) @@ -140,7 +143,7 @@ async def issue_credit(self, turn: TurnToSend) -> bool: ) if not acquired: # CRITICAL: Release session slot if we acquired it to maintain symmetry - if is_first_turn: + if needs_session_slot: self._concurrency_manager.release_session_slot(self._phase) return False @@ -162,11 +165,12 @@ async def try_issue_credit(self, turn: TurnToSend) -> bool | None: None: No slots available, credit NOT issued. Retry later. """ is_first_turn = turn.turn_index == 0 + is_child = turn.agent_depth > 0 # Select appropriate check function based on turn type can_proceed_fn = ( self._stop_checker.can_start_new_session - if is_first_turn + if is_first_turn and not is_child else self._stop_checker.can_send_any_turn ) @@ -174,7 +178,8 @@ async def try_issue_credit(self, turn: TurnToSend) -> bool | None: if not can_proceed_fn(): return False - if is_first_turn: + needs_session_slot = is_first_turn and not is_child + if needs_session_slot: acquired = self._concurrency_manager.try_acquire_session_slot( self._phase, can_proceed_fn ) @@ -186,7 +191,7 @@ async def try_issue_credit(self, turn: TurnToSend) -> bool | None: ) if not acquired: # CRITICAL: Release session slot if we acquired it to maintain symmetry - if is_first_turn: + if needs_session_slot: self._concurrency_manager.release_session_slot(self._phase) return None # No slot - credit not issued @@ -227,6 +232,8 @@ async def _issue_credit_internal(self, turn: TurnToSend) -> bool: issued_at_ns=issued_at_ns, cancel_after_ns=cancel_after_ns, url_index=url_index, + agent_depth=turn.agent_depth, + parent_correlation_id=turn.parent_correlation_id, ) await self._credit_router.send_credit(credit=credit) diff --git a/src/aiperf/credit/structs.py b/src/aiperf/credit/structs.py index b5f646351..7f050eff1 100644 --- 
a/src/aiperf/credit/structs.py +++ b/src/aiperf/credit/structs.py @@ -35,6 +35,8 @@ class Credit( Note: this is NOT the same as the credit being cancelled! url_index: Index of the URL to use when multiple --url values are configured (optional). None means use the default (first) URL. + agent_depth: Nesting depth of this agent in the subagent hierarchy (0 = top-level). + parent_correlation_id: Correlation ID of the parent conversation for subagent tracking. """ id: int @@ -46,6 +48,8 @@ class Credit( issued_at_ns: int cancel_after_ns: int | None = None url_index: int | None = None + agent_depth: int = 0 + parent_correlation_id: str | None = None @property def is_final_turn(self) -> bool: @@ -93,6 +97,8 @@ class TurnToSend(Struct, frozen=True): x_correlation_id: str turn_index: int num_turns: int + agent_depth: int = 0 + parent_correlation_id: str | None = None @property def is_final_turn(self) -> bool: @@ -106,4 +112,6 @@ def from_previous_credit(cls, credit: Credit) -> Self: x_correlation_id=credit.x_correlation_id, turn_index=credit.turn_index + 1, num_turns=credit.num_turns, + agent_depth=credit.agent_depth, + parent_correlation_id=credit.parent_correlation_id, ) diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 3dd25ec25..0890bceef 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -119,7 +119,7 @@ def _infer_type( # Check for explicit type field first (most efficient). # Skip values that aren't known dataset types (e.g. Bailian's "type": "text" # is a request type, not a dataset type) and fall through to structural detection. 
- if data is not None and data.get("type") in CustomDatasetType: + if isinstance(data, dict) and data.get("type") in CustomDatasetType: explicit_type = CustomDatasetType(data["type"]) LoaderClass = plugins.get_class( PluginType.CUSTOM_DATASET_LOADER, explicit_type diff --git a/src/aiperf/dataset/dataset_manager.py b/src/aiperf/dataset/dataset_manager.py index 2148c6d7a..2ed03991b 100644 --- a/src/aiperf/dataset/dataset_manager.py +++ b/src/aiperf/dataset/dataset_manager.py @@ -362,10 +362,16 @@ async def _configure_dataset(self) -> None: "from WorkerPodManager before accessing dataset" ) + has_timing = any( + turn.timestamp is not None or turn.delay is not None + for conv in conversations + for turn in conv.turns + ) self.dataset_metadata = DatasetMetadata( conversations=[conversation.metadata() for conversation in conversations], sampling_strategy=self.user_config.input.dataset_sampling_strategy, default_context_mode=self._default_context_mode, + has_timing_data=has_timing, ) self.info( f"sampling strategy: {self.dataset_metadata.sampling_strategy}, " diff --git a/src/aiperf/dataset/loader/__init__.py b/src/aiperf/dataset/loader/__init__.py index b637b4c69..fc892bfb1 100644 --- a/src/aiperf/dataset/loader/__init__.py +++ b/src/aiperf/dataset/loader/__init__.py @@ -6,9 +6,11 @@ from aiperf.dataset.loader.base_loader import BaseFileLoader, BaseLoader from aiperf.dataset.loader.base_public_dataset import BasePublicDatasetLoader from aiperf.dataset.loader.base_trace_loader import BaseTraceDatasetLoader +from aiperf.dataset.loader.conflux import ConfluxLoader from aiperf.dataset.loader.mixins import MediaConversionMixin from aiperf.dataset.loader.models import ( BailianTrace, + ConfluxRecord, MooncakeTrace, MultiTurn, RandomPool, @@ -27,6 +29,8 @@ "BaseLoader", "BasePublicDatasetLoader", "BaseTraceDatasetLoader", + "ConfluxLoader", + "ConfluxRecord", "MediaConversionMixin", "MooncakeTrace", "MooncakeTraceDatasetLoader", diff --git a/src/aiperf/dataset/loader/conflux.py 
b/src/aiperf/dataset/loader/conflux.py new file mode 100644 index 000000000..d7db4129a --- /dev/null +++ b/src/aiperf/dataset/loader/conflux.py @@ -0,0 +1,1078 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Conflux dataset loader for verbatim replay of Claude Code proxy captures. + +Loads JSON files produced by the Conflux proxy containing an array of API +request/response records with explicit agent_id/is_subagent fields for +thread grouping. Supports parent + subagent hierarchies with timestamp-based +inter-turn delays. +""" + +from __future__ import annotations + +import base64 +import functools +import gzip +import re +from datetime import datetime +from pathlib import Path +from typing import Any + +import orjson + +from aiperf.common.config.user_config import UserConfig +from aiperf.common.enums import ConversationContextMode, PrerequisiteKind +from aiperf.common.models import Conversation, Turn +from aiperf.common.models.dataset_models import ( + ConversationOrigin, + SubagentSpawnInfo, + TurnGroundTruth, + TurnPrerequisite, +) +from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.loader.models import ConfluxRecord +from aiperf.dataset.message_normalizer import normalize_messages +from aiperf.plugin.enums import DatasetSamplingStrategy + + +@functools.lru_cache(maxsize=4096) +def _parse_timestamp_s(iso_str: str) -> float: + """Parse an ISO timestamp string to seconds since epoch (cached).""" + return datetime.fromisoformat(iso_str.replace("Z", "+00:00")).timestamp() + + +def _parse_timestamp_ms(iso_str: str) -> float: + """Parse an ISO timestamp string to milliseconds since epoch.""" + return _parse_timestamp_s(iso_str) * 1000 + + +def _decode_base64_payload(encoded: str) -> bytes: + """Decode a base64 string, auto-detecting gzip/zstd compression. + + Conflux base64 payloads may be gzip or zstd compressed before encoding. 
+ Detects compression via magic bytes and decompresses transparently. + Uncompressed payloads (the common case) are returned as-is. + """ + raw = base64.b64decode(encoded) + # gzip: magic bytes 1f 8b + if len(raw) >= 2 and raw[0] == 0x1F and raw[1] == 0x8B: + return gzip.decompress(raw) + # zstd: magic bytes 28 b5 2f fd + if len(raw) >= 4 and raw[:4] == b"\x28\xb5\x2f\xfd": + try: + import zstandard + + return zstandard.ZstdDecompressor().decompress(raw) + except ImportError as exc: + raise ImportError( + "zstandard package required to decompress zstd-compressed " + "Conflux base64 payloads. Install with: uv add zstandard" + ) from exc + return raw + + +def _messages_from_payload(payload: dict[str, Any]) -> list[dict[str, Any]]: + """Build messages list from a decoded request payload, prepending system when present.""" + msgs = list(payload.get("messages", [])) + system = payload.get("system") + if system is not None: + msgs.insert(0, {"role": "system", "content": system}) + return msgs + + +def _decode_request_payload(record: ConfluxRecord) -> dict[str, Any] | None: + """Decode the base64 request body from a ConfluxRecord, or None if absent.""" + if record.base64 and record.base64.get("request_body"): + return orjson.loads(_decode_base64_payload(record.base64["request_body"])) + return None + + +def _decode_messages(record: ConfluxRecord) -> list[dict[str, Any]]: + """Decode the messages array from a ConfluxRecord, preferring base64 payload.""" + payload = _decode_request_payload(record) + if payload is not None: + return _messages_from_payload(payload) + return list(record.messages) + + +def _record_end_ms(record: ConfluxRecord) -> float: + """Best-effort completion timestamp for a Conflux record.""" + start_ms = _parse_timestamp_ms(record.timestamp) + if record.completed_at: + return _parse_timestamp_ms(record.completed_at) + if record.duration_ms > 0: + return start_ms + record.duration_ms + return start_ms + + +def _new_messages( + previous_messages: 
list[dict[str, Any]], + current_messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Return messages appended between two consecutive parent turns. + + Assumes append-only growth. Uses length-based slicing (O(1)) since + Conflux conversations grow by appending. Falls back to sequential + comparison only when the prefix assumption fails. + """ + prev_len = len(previous_messages) + if prev_len <= len(current_messages) and ( + prev_len == 0 or previous_messages[-1] == current_messages[prev_len - 1] + ): + return current_messages[prev_len:] + common_prefix_len = 0 + for previous, current in zip(previous_messages, current_messages, strict=False): + if previous != current: + break + common_prefix_len += 1 + return current_messages[common_prefix_len:] + + +def _iter_message_blocks(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Extract structured content blocks from appended messages.""" + blocks: list[dict[str, Any]] = [] + for message in messages: + content = message.get("content", "") + if isinstance(content, list): + blocks.extend(block for block in content if isinstance(block, dict)) + return blocks + + +def _collect_agent_tool_use_ids(messages: list[dict[str, Any]]) -> set[str]: + """Extract all Agent tool_use block IDs from a message list.""" + ids: set[str] = set() + for block in _iter_message_blocks(messages): + if block.get("type") == "tool_use" and block.get("name") == "Agent": + bid = block.get("id") + if isinstance(bid, str): + ids.add(bid) + return ids + + +# Protocol strings from Claude Code's Conflux format for background agent detection +_QUEUED_SIGNAL = "queued for running" +_ASYNC_LAUNCHED_SIGNAL = "Async agent launched" + + +def _stringify_block_content(content: Any) -> str: + """Flatten nested tool-result payloads for heuristic inspection.""" + if isinstance(content, str): + return content + if isinstance(content, list): + return " ".join(_stringify_block_content(item) for item in content) + if isinstance(content, dict): + if 
isinstance(content.get("text"), str): + return content["text"] + if "content" in content: + return _stringify_block_content(content["content"]) + return orjson.dumps(content).decode() + return str(content) + + +def _detect_join_turn_from_content( + parent_records: list[ConfluxRecord], + spawn_turn_index: int, +) -> tuple[int | None, bool]: + """Infer a later join turn from appended Agent tool_result messages. + + Returns: + (join_turn_index, saw_background_signal) + - join_turn_index: first parent turn whose new messages contain the + result for the Agent tool_use emitted by ``spawn_turn_index``. + - saw_background_signal: True when only queued/async acknowledgements + were observed for that spawn and no later full result was found. + """ + if spawn_turn_index + 1 >= len(parent_records): + return None, False + + spawn_msgs = _decode_messages(parent_records[spawn_turn_index]) + existing_agent_ids = _collect_agent_tool_use_ids(spawn_msgs) + + spawn_tool_use_ids: set[str] = set() + saw_queued_result = False + + prev_msgs = spawn_msgs + for turn_index in range(spawn_turn_index + 1, len(parent_records)): + current_messages = _decode_messages(parent_records[turn_index]) + appended_messages = _new_messages(prev_msgs, current_messages) + prev_msgs = current_messages + + for block in _iter_message_blocks(appended_messages): + if ( + turn_index == spawn_turn_index + 1 + and block.get("type") == "tool_use" + and block.get("name") == "Agent" + ): + block_id = block.get("id") + if not isinstance(block_id, str): + continue + if block_id in existing_agent_ids: + continue + spawn_tool_use_ids.add(block_id) + continue + + if block.get("type") != "tool_result": + continue + + tool_use_id = block.get("tool_use_id") + matches_spawn = ( + isinstance(tool_use_id, str) and tool_use_id in spawn_tool_use_ids + ) + if not matches_spawn: + continue + + result_text = _stringify_block_content(block.get("content", "")) + if _QUEUED_SIGNAL in result_text or _ASYNC_LAUNCHED_SIGNAL in result_text: 
+ saw_queued_result = True + continue + + return turn_index, False + + return None, saw_queued_result + + +def _find_join_turn_index( + parent_records: list[ConfluxRecord], + spawn_turn_index: int, + children: list[tuple[str, list[ConfluxRecord], Conversation]], +) -> int | None: + """Infer which parent turn waits for/consumes a spawn's child results. + + Preference order: + 1. Explicit tool_result content matched to the Agent tool_use. + 2. First later parent turn that starts after all children have completed. + 3. None if the children outlive the parent thread or content explicitly + marks the spawn as background-only. + """ + join_turn_from_content, saw_background_signal = _detect_join_turn_from_content( + parent_records, spawn_turn_index + ) + if join_turn_from_content is not None: + return join_turn_from_content + if saw_background_signal: + return None + + latest_child_end_ms = max( + _record_end_ms(child_records[-1]) for _, child_records, _ in children + ) + + for turn_index in range(spawn_turn_index + 1, len(parent_records)): + parent_turn_start_ms = _parse_timestamp_ms(parent_records[turn_index].timestamp) + if parent_turn_start_ms >= latest_child_end_ms: + return turn_index + + return None + + +_TASK_NOTIFICATION_RE = re.compile( + r".*?(.*?)", + re.DOTALL, +) +_AGENT_ID_RE = re.compile(r"agentId:\s*(\S+)") + + +def _extract_notification_joins( + parent_records: list[ConfluxRecord], +) -> dict[str, int]: + """Scan parent turns for completion signals. + + Returns {Agent_tool_use_id: first_turn_index} for each child whose + completion was signalled via a user message injected + by Claude Code when the background agent finished. 
+ """ + joins: dict[str, int] = {} + prev_msgs: list[dict[str, Any]] = [] + for ti, record in enumerate(parent_records): + curr_msgs = _decode_messages(record) + new_msgs = _new_messages(prev_msgs, curr_msgs) + prev_msgs = curr_msgs + for msg in new_msgs: + content = msg.get("content", "") + texts: list[str] = [] + if isinstance(content, str): + texts.append(content) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + t = block.get("text", "") + if t: + texts.append(t) + for text in texts: + if "" not in text: + continue + for m in _TASK_NOTIFICATION_RE.finditer(text): + tuid = m.group(1).strip() + if tuid not in joins: + joins[tuid] = ti + return joins + + +def _build_spawn_tuid_to_agent_id( + parent_records: list[ConfluxRecord], + spawn_turn_index: int, +) -> dict[str, str]: + """Build tool_use_id -> agent_id for newly-spawned async agents. + + At spawn_turn_index+1, finds "Async agent launched" tool_results and + extracts the agentId for each new Agent tool_use (not previously in the + message history). 
+ """ + if spawn_turn_index + 1 >= len(parent_records): + return {} + + prev = _decode_messages(parent_records[spawn_turn_index]) + existing_ids = _collect_agent_tool_use_ids(prev) + + curr = _decode_messages(parent_records[spawn_turn_index + 1]) + new_msgs = _new_messages(prev, curr) + new_spawn_ids = _collect_agent_tool_use_ids(new_msgs) - existing_ids + + tuid_to_agent: dict[str, str] = {} + for msg in new_msgs: + content = msg.get("content", "") + if isinstance(content, list): + for block in content: + if not (isinstance(block, dict) and block.get("type") == "tool_result"): + continue + tuid = block.get("tool_use_id") + if not isinstance(tuid, str) or tuid not in new_spawn_ids: + continue + result_text = _stringify_block_content(block.get("content", "")) + if _ASYNC_LAUNCHED_SIGNAL not in result_text: + continue + m = _AGENT_ID_RE.search(result_text) + if m: + tuid_to_agent[tuid] = m.group(1) + return tuid_to_agent + + +_CAN_LOAD_PROBE_BYTES = 1 << 20 # 1 MB probe limit for format detection + + +class ConfluxLoader(BaseFileLoader): + """Dataset loader for Conflux proxy capture JSON files. + + Expects a JSON file containing an array of API request records, each with + agent_id, is_subagent, messages, tools, model, and timestamp fields. 
+ """ + + def __init__( + self, + *, + filename: str, + user_config: UserConfig, + **kwargs: Any, + ) -> None: + super().__init__(filename=filename, user_config=user_config, **kwargs) + self._orphan_ids: set[str] = set() + self._orphan_counter: int = 0 + self._file_boundaries: list[set[str]] = [] + + @classmethod + def can_load( + cls, data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> bool: + """Return True if filename is a Conflux JSON file or a directory containing one.""" + if filename is None: + return False + path = Path(filename) + if path.is_dir(): + return any(cls._probe_file(f) for f in sorted(path.glob("*.json"))[:5]) + return cls._probe_file(path) + + @classmethod + def _probe_file(cls, path: Path) -> bool: + """Return True if a single file looks like a Conflux JSON capture.""" + if not path.is_file() or path.suffix != ".json": + return False + try: + with open(path, "rb") as f: + probe = f.read(_CAN_LOAD_PROBE_BYTES) + try: + raw = orjson.loads(probe) + except orjson.JSONDecodeError: + # Truncated read of large file - fall back to byte-level detection + if not probe or probe[0:1] != b"[": + return False + has_messages = b'"messages"' in probe + has_agent = b'"agent_id"' in probe and b'"is_subagent"' in probe + has_proxy = b'"source"' in probe and b'"proxy"' in probe + return has_messages and (has_agent or has_proxy) + if not isinstance(raw, list) or len(raw) == 0: + return False + # Check any record for Conflux signature fields. Two detection + # paths: explicit agent threading (agent_id + is_subagent + messages) + # or the proxy source marker (source == "proxy" + messages). 
+ return any( + isinstance(r, dict) + and "messages" in r + and ( + ("agent_id" in r and "is_subagent" in r) + or r.get("source") == "proxy" + ) + for r in raw[:20] + ) + except Exception: + return False + + @classmethod + def get_default_context_mode(cls) -> ConversationContextMode: + return ConversationContextMode.MESSAGE_ARRAY_WITH_RESPONSES + + @classmethod + def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: + return DatasetSamplingStrategy.SEQUENTIAL + + def load_dataset(self) -> dict[str, list[ConfluxRecord]]: + """Load and group Conflux records by agent_id. + + Records with an agent_id are grouped into multi-turn agent threads. + Records without an agent_id (e.g. haiku tool-result processing calls) + become single-turn background subagent children of the parent agent, + each mapped to the closest parent turn by timestamp. + + For directory input, each file is loaded independently with a file + prefix on agent_ids to avoid cross-file collisions. File boundaries + are preserved for per-file zero-alignment in convert_to_conversations. 
+ """ + self._orphan_ids = set() + self._orphan_counter = 0 + self._file_boundaries = [] + + path = Path(self.filename) + if path.is_dir(): + return self._load_directory(path) + + groups = self._load_single_file(self.filename) + self._file_boundaries.append(set(groups.keys())) + return groups + + def _load_directory(self, path: Path) -> dict[str, list[ConfluxRecord]]: + """Load all JSON files in a directory as independent sessions.""" + json_files = sorted(path.glob("*.json")) + if not json_files: + raise FileNotFoundError( + f"No .json files found in directory: {self.filename}" + ) + + all_groups: dict[str, list[ConfluxRecord]] = {} + total_records = 0 + for file_idx, json_file in enumerate(json_files): + file_groups = self._load_single_file(str(json_file), prefix=f"f{file_idx}_") + file_keys: set[str] = set() + for key, records in file_groups.items(): + all_groups[key] = records + file_keys.add(key) + total_records += len(records) + self._file_boundaries.append(file_keys) + self.debug( + lambda fn=json_file.name, fk=file_keys: ( + f" {fn}: {len(fk)} agent groups -> {fk}" + ) + ) + + file_count = len(json_files) + threaded = len(all_groups) - len(self._orphan_ids) + self.info( + f"Loaded {threaded} agent threads from " + f"{file_count} files ({total_records} total records) in {path.name}/" + ) + return all_groups + + def _load_single_file( + self, filename: str, prefix: str = "" + ) -> dict[str, list[ConfluxRecord]]: + """Load and group records from a single JSON file. + + Args: + filename: Path to the JSON file. + prefix: String prefix for agent_id keys to avoid cross-file collisions + when loading multiple files from a directory. 
+ """ + with open(filename, "rb") as f: + raw_records: list[dict[str, Any]] = orjson.loads(f.read()) + + include_orphans = self.user_config.input.conflux_include_utility_calls + + groups: dict[str, list[ConfluxRecord]] = {} + orphan_count = 0 + + for raw in raw_records: + record = ConfluxRecord.model_validate(raw) + if record.agent_id is not None: + key = f"{prefix}{record.agent_id}" + groups.setdefault(key, []).append(record) + elif include_orphans: + orphan_id = f"{prefix}_orphan_{self._orphan_counter}" + self._orphan_counter += 1 + self._orphan_ids.add(orphan_id) + groups[orphan_id] = [record] + else: + orphan_count += 1 + + # Sort each group by timestamp + for records in groups.values(): + records.sort(key=lambda r: _parse_timestamp_ms(r.timestamp)) + + if not prefix: + orphans_in_file = sum(1 for k in groups if k in self._orphan_ids) + threaded = len(groups) - orphans_in_file + total_records = sum(len(recs) for recs in groups.values()) + orphan_label = ( + f"{orphans_in_file} utility calls included" + if include_orphans + else f"{orphan_count} utility calls skipped" + ) + self.info( + f"Loaded {threaded} agent threads + {orphan_label} " + f"({total_records} total records)" + ) + + return groups + + def convert_to_conversations( + self, data: dict[str, list[ConfluxRecord]] + ) -> list[Conversation]: + """Convert grouped Conflux records to Conversation objects. + + Each file boundary (from directory loading) is processed independently + with its own zero-alignment, so timestamps from separate captures + don't bleed into each other. + """ + if not data: + return [] + + # Without load_dataset, _file_boundaries is empty and all data is dropped. 
+ if not self._file_boundaries: + self._file_boundaries = [set(data.keys())] + + all_conversations: list[Conversation] = [] + for file_keys in self._file_boundaries: + file_data = {k: v for k, v in data.items() if k in file_keys} + convs = self._convert_file_groups(file_data) + self._zero_align_timestamps(convs) + all_conversations.extend(convs) + + _parse_timestamp_s.cache_clear() + return all_conversations + + @staticmethod + def _register_spawn( + parent_conv: Conversation, + spawn_counter: int, + child_session_ids: list[str], + spawn_turn_index: int, + *, + is_background: bool, + join_turn_index: int | None = None, + ) -> str: + """Wire up a SubagentSpawnInfo on the parent and return the spawn_id. + + Handles spawn info creation, spawn_id annotation on the spawn turn, + and optional prerequisite on the join turn. + """ + spawn_id = f"s{spawn_counter}" + parent_conv.subagent_spawns.append( + SubagentSpawnInfo( + spawn_id=spawn_id, + child_conversation_ids=child_session_ids, + is_background=is_background, + ) + ) + if spawn_turn_index < len(parent_conv.turns): + parent_conv.turns[spawn_turn_index].subagent_spawn_ids.append(spawn_id) + if ( + join_turn_index is not None + and not is_background + and join_turn_index < len(parent_conv.turns) + ): + parent_conv.turns[join_turn_index].prerequisites.append( + TurnPrerequisite( + kind=PrerequisiteKind.SPAWN_JOIN, + spawn_id=spawn_id, + ) + ) + return spawn_id + + def _convert_file_groups( + self, data: dict[str, list[ConfluxRecord]] + ) -> list[Conversation]: + """Convert one file's grouped records to Conversation objects. + + Three categories: + - The group with is_subagent=False is the parent agent thread. + - Groups with is_subagent=True are children linked via SubagentSpawnInfo. + Children spawned at the same parent turn are grouped into a single + SubagentSpawnInfo. 
Blocking joins are attached to the first later + parent turn that consumes the child result, inferred from Agent + tool_result content when available and child-completion timing + otherwise. + - Orphan groups (no agent_id) become single-turn background subagent + children, each spawned at the closest parent turn by timestamp. + """ + # Separate orphans from threaded groups + threaded_data: dict[str, list[ConfluxRecord]] = {} + orphan_data: dict[str, list[ConfluxRecord]] = {} + for agent_id, records in data.items(): + if agent_id in self._orphan_ids: + orphan_data[agent_id] = records + else: + threaded_data[agent_id] = records + + parent_id: str | None = None + child_ids: list[str] = [] + unclassified_ids: list[str] = [] + + for agent_id, records in threaded_data.items(): + # `is True`/`is False`: distinguish from None (un-enriched). + if records[0].is_subagent is True: + child_ids.append(agent_id) + elif records[0].is_subagent is False: + if parent_id is not None: + # Multiple groups marked is_subagent=False — keep the one with + # the most records as parent, demote the other to child. + if len(records) > len(threaded_data[parent_id]): + child_ids.append(parent_id) + parent_id = agent_id + else: + child_ids.append(agent_id) + else: + parent_id = agent_id + else: + # is_subagent=None: un-enriched record, classify heuristically + unclassified_ids.append(agent_id) + + # For un-enriched data: if no explicit parent, elect the group with + # the most records as parent, rest become children. 
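Reviewer aside: the parent-election fallback described in the comment above can be sanity-checked in isolation. A minimal standalone sketch with hypothetical group sizes (`threaded_data` maps agent_id to its record list; plain lists stand in for records here):

```python
# Hypothetical un-enriched groups: no is_subagent flags were captured, so
# elect the largest group as the parent thread; the rest become children.
threaded_data = {"a1": [1, 2, 3, 4, 5], "a2": [1, 2], "a3": [1]}

parent_id = None
child_ids: list[str] = []
unclassified_ids = list(threaded_data)

if parent_id is None and unclassified_ids:
    # Sort descending by record count, pop the largest as parent.
    unclassified_ids.sort(key=lambda aid: len(threaded_data[aid]), reverse=True)
    parent_id = unclassified_ids.pop(0)
    child_ids.extend(unclassified_ids)

print(parent_id, child_ids)  # a1 ['a2', 'a3']
```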
+ if parent_id is None and unclassified_ids: + unclassified_ids.sort(key=lambda aid: len(threaded_data[aid]), reverse=True) + parent_id = unclassified_ids.pop(0) + child_ids.extend(unclassified_ids) + else: + # With an explicit parent, unclassified groups become children + child_ids.extend(unclassified_ids) + + conversations: list[Conversation] = [] + child_conversations: list[Conversation] = [] + + if parent_id is not None: + parent_records = threaded_data[parent_id] + parent_conv = self._build_conversation( + parent_id, parent_records, is_child=False + ) + parent_session_id = parent_conv.session_id + self.debug( + lambda: ( + f"Parent {parent_session_id}: {len(parent_records)} turns, " + f"{len(child_ids)} children, {len(orphan_data)} orphans" + ) + ) + + # Group children by spawn turn index, then detect blocking vs background + spawn_turn_to_children: dict[ + int, list[tuple[str, list[ConfluxRecord], Conversation]] + ] = {} + for child_agent_id in child_ids: + child_records = threaded_data[child_agent_id] + child_conv = self._build_conversation( + child_agent_id, + child_records, + is_child=True, + parent_session_id=parent_session_id, + ) + spawn_turn_index = self._find_spawn_point(parent_records, child_records) + spawn_turn_to_children.setdefault(spawn_turn_index, []).append( + (child_agent_id, child_records, child_conv) + ) + self.debug( + lambda cid=child_conv.session_id, + recs=child_records, + si=spawn_turn_index: ( + f" Child {cid}: {len(recs)} turns, spawn_point=turn[{si}]" + ) + ) + + spawn_counter = 0 + blocking_count = 0 + background_count = 0 + notification_joins = _extract_notification_joins(parent_records) + if notification_joins: + self.debug( + lambda: f" Found {len(notification_joins)} task-notification join(s)" + ) + for spawn_turn_index in sorted(spawn_turn_to_children): + children_at_turn = spawn_turn_to_children[spawn_turn_index] + join_turn_index = _find_join_turn_index( + parent_records, spawn_turn_index, children_at_turn + ) + 
is_background = join_turn_index is None
+            self.debug(
+                lambda cat=children_at_turn,
+                si=spawn_turn_index,
+                ji=join_turn_index,
+                bg=is_background: (
+                    f"  Spawn at turn[{si}]: "
+                    f"{len(cat)} children {[aid for aid, _, _ in cat]}, "
+                    f"join=turn[{ji}], background={bg}"
+                )
+            )
+
+            # For background spawns, check whether each child reports back to
+            # the parent via a task notification. If so, split into per-child
+            # blocking spawns with per-child join turns derived from the
+            # notification.
+            if is_background and notification_joins:
+                tuid_to_agent = _build_spawn_tuid_to_agent_id(
+                    parent_records, spawn_turn_index
+                )
+                agent_to_join: dict[str, int] = {
+                    agent_id: notification_joins[tuid]
+                    for tuid, agent_id in tuid_to_agent.items()
+                    if tuid in notification_joins
+                }
+                if agent_to_join:
+                    self.debug(
+                        lambda atj=agent_to_join: (
+                            f"  Notification join split: {atj}"
+                        )
+                    )
+                    # Match via the record's agent_id (unprefixed) since
+                    # agent_to_join keys come from message content.
+                    notification_children = [
+                        (aid, recs, conv, agent_to_join[recs[0].agent_id])
+                        for aid, recs, conv in children_at_turn
+                        if recs[0].agent_id in agent_to_join
+                    ]
+                    bg_children = [
+                        (aid, recs, conv)
+                        for aid, recs, conv in children_at_turn
+                        if recs[0].agent_id not in agent_to_join
+                    ]
+                    for (
+                        _,
+                        _,
+                        child_conv,
+                        notification_turn,
+                    ) in notification_children:
+                        self._register_spawn(
+                            parent_conv,
+                            spawn_counter,
+                            [child_conv.session_id],
+                            spawn_turn_index,
+                            is_background=False,
+                            join_turn_index=notification_turn,
+                        )
+                        spawn_counter += 1
+                        child_conversations.append(child_conv)
+                        blocking_count += 1
+                    for _, _, child_conv in bg_children:
+                        self._register_spawn(
+                            parent_conv,
+                            spawn_counter,
+                            [child_conv.session_id],
+                            spawn_turn_index,
+                            is_background=True,
+                        )
+                        spawn_counter += 1
+                        child_conversations.append(child_conv)
+                        background_count += 1
+                    continue
+
+            child_conv_ids = [conv.session_id for _, _, conv in children_at_turn]
+            self._register_spawn(
+                parent_conv,
+                spawn_counter,
+                child_conv_ids,
+
spawn_turn_index, + is_background=is_background, + join_turn_index=join_turn_index, + ) + spawn_counter += 1 + + for _, _, child_conv in children_at_turn: + child_conversations.append(child_conv) + + if is_background: + background_count += len(children_at_turn) + else: + blocking_count += len(children_at_turn) + + # Attach orphan records as single-turn background subagent children + for orphan_id, orphan_records in orphan_data.items(): + child_conv = self._build_conversation( + orphan_id, + orphan_records, + is_child=True, + parent_session_id=parent_session_id, + ) + + spawn_turn_index = self._find_spawn_point( + parent_records, orphan_records + ) + self._register_spawn( + parent_conv, + spawn_counter, + [child_conv.session_id], + spawn_turn_index, + is_background=True, + ) + spawn_counter += 1 + child_conversations.append(child_conv) + background_count += 1 + + conversations.append(parent_conv) + conversations.extend(child_conversations) + else: + blocking_count = 0 + background_count = 0 + for agent_id, records in threaded_data.items(): + conversations.append( + self._build_conversation(agent_id, records, is_child=False) + ) + + total_turns = sum(len(c.turns) for c in conversations) + self.info( + f"Converted {len(conversations)} conversations " + f"({total_turns} total turns, " + f"{len(child_conversations)} subagent children: " + f"{blocking_count} blocking, {background_count} background, " + f"incl. 
{len(orphan_data)} orphans)" + ) + return conversations + + @staticmethod + def _zero_align_timestamps(conversations: list[Conversation]) -> None: + """Shift all turn timestamps so the earliest becomes 0.""" + min_ts = min( + ( + turn.timestamp + for conv in conversations + for turn in conv.turns + if turn.timestamp is not None + ), + default=None, + ) + if min_ts is None or min_ts == 0: + return + for conv in conversations: + for turn in conv.turns: + if turn.timestamp is not None: + turn.timestamp -= min_ts + + def _build_conversation( + self, + agent_id: str, + records: list[ConfluxRecord], + *, + is_child: bool, + parent_session_id: str | None = None, + ) -> Conversation: + """Build a Conversation from a list of ConfluxRecords for one agent.""" + first = records[0] + + origin = ConversationOrigin( + source=first.source, + client=first.client, + client_version=first.client_version, + original_session_id=first.session_id, + original_request_ids=[ + r.request_id for r in records if r.request_id is not None + ], + ) + + conversation = Conversation( + session_id=f"conflux_{agent_id}", + agent_depth=1 if is_child else 0, + parent_conversation_id=parent_session_id if is_child else None, + origin=origin, + ) + + for record in records: + ts_ms = _parse_timestamp_ms(record.timestamp) + input_tokens = record.tokens.input if record.tokens else None + + # Extract messages and tools from best available source + messages, tools, max_tokens = self._extract_record_fields(record) + + # Normalize to OpenAI canonical format (N+M architecture) + provider = self._detect_conflux_provider(record) + raw_messages, raw_tools = normalize_messages( + messages, tools, provider=provider + ) + + # Build extra_params from hyperparameters (excluding max_tokens and nulls) + extra_params = self._extract_extra_params(record) + + # Build ground truth from token breakdown, timing, and output + ground_truth = self._extract_ground_truth(record) + + turn = Turn( + role="user", + model=record.model, + 
timestamp=ts_ms, + max_tokens=max_tokens, + input_tokens=input_tokens, + raw_messages=raw_messages, + raw_tools=raw_tools, + extra_params=extra_params, + ground_truth=ground_truth, + ) + conversation.turns.append(turn) + + return conversation + + _EXTRA_PARAMS_SKIP = frozenset({"max_tokens", "max_output_tokens"}) + + @staticmethod + def _extract_extra_params(record: ConfluxRecord) -> dict[str, Any] | None: + """Extract per-turn hyperparameter overrides from a ConfluxRecord. + + Excludes max_tokens fields (already on Turn.max_tokens) and null values. + Returns None if no non-null hyperparameters remain. + """ + if not record.hyperparameters: + return None + hp = record.hyperparameters + params: dict[str, Any] = {} + for field_name in type(hp).model_fields: + if field_name in ConfluxLoader._EXTRA_PARAMS_SKIP: + continue + value = getattr(hp, field_name, None) + if value is not None: + params[field_name] = value + return params or None + + @staticmethod + def _extract_ground_truth(record: ConfluxRecord) -> TurnGroundTruth | None: + """Extract ground truth metadata from a ConfluxRecord. + + Returns None if no meaningful ground truth data is available. 
+ """ + tokens = record.tokens + has_token_detail = tokens is not None and ( + tokens.input_cached > 0 + or tokens.input_cache_write > 0 + or tokens.output > 0 + or tokens.output_reasoning > 0 + ) + has_timing = record.ttft_ms is not None or record.duration_ms > 0 + has_streaming = record.is_streaming is not None + if not (has_token_detail or has_timing or has_streaming): + return None + + return TurnGroundTruth( + input_cached_tokens=tokens.input_cached if tokens else None, + input_cache_write_tokens=tokens.input_cache_write if tokens else None, + output_tokens=tokens.output if tokens else None, + output_reasoning_tokens=tokens.output_reasoning if tokens else None, + ttft_ms=record.ttft_ms, + duration_ms=record.duration_ms if record.duration_ms > 0 else None, + is_streaming=record.is_streaming, + ) + + @staticmethod + def _extract_record_fields( + record: ConfluxRecord, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None, int | None]: + """Extract messages, tools, and max_tokens from the best available source. + + Base64 path: decode the verbatim request body for full fidelity + (includes system, tools, thinking config, etc.). + Fallback: use top-level Conflux fields. + + Returns: + (messages, tools, max_tokens) + """ + payload = _decode_request_payload(record) + if payload is not None: + payload.pop("metadata", None) + messages = _messages_from_payload(payload) + # Don't collapse [] to None -- tools=[] means "no tools" vs absent. + tools = payload.get("tools") + # Anthropic: max_tokens. OpenAI: max_completion_tokens. + max_tokens = payload.get("max_tokens") + if max_tokens is None: + max_tokens = payload.get("max_completion_tokens") + return messages, tools, max_tokens + + messages = list(record.messages) + # record.tools defaults to [] via default_factory, not from the original request. 
+ tools = record.tools or None + hp = record.hyperparameters + if hp is None: + max_tokens = None + else: + max_tokens = ( + hp.max_tokens if hp.max_tokens is not None else hp.max_output_tokens + ) + return messages, tools, max_tokens + + @staticmethod + def _detect_conflux_provider(record: ConfluxRecord) -> str | None: + """Detect the provider from Conflux record metadata. + + Uses the ``client`` field (set by Conflux enrichment) or the + ``provider`` field (set by the adapter pipeline) to map to a + provider hint for the normalizer. + """ + if record.provider: + provider = record.provider.lower() + if provider in ("anthropic", "openai"): + return provider + if record.client == "claude": + return "anthropic" + if record.client == "codex": + return "openai" + return None + + @staticmethod + def _find_spawn_point( + parent_records: list[ConfluxRecord], + child_records: list[ConfluxRecord], + ) -> int: + """Find the parent turn that spawned a child agent. + + Three-tier matching: + + 1. **In-flight overlap**: Child's first request falls within a parent + turn's execution window (start -> completion). This handles children + that start while the parent is still processing. + + 2. **Post-completion gap**: Child starts in the gap between parent turn + N completing and turn N+1 starting. This handles the common case + where children are spawned immediately after the parent processes the + response (typically within a few hundred milliseconds). + + 3. **Closest timestamp** (fallback): Assigns to the nearest parent turn + by start timestamp when neither overlap nor gap matching succeeds. 
+ """ + child_first_ts = _parse_timestamp_ms(child_records[0].timestamp) + + # Single pass: check in-flight overlap and post-completion gap together + best_idx = 0 + best_diff = float("inf") + for i, record in enumerate(parent_records): + start_ts = _parse_timestamp_ms(record.timestamp) + end_ts = _record_end_ms(record) + has_end = end_ts > start_ts + + # Tier 1: child started while parent turn was in-flight + if has_end and start_ts <= child_first_ts <= end_ts: + return i + + # Tier 2: child started in the gap after this turn completed + if has_end and child_first_ts > end_ts: + next_start_ts = ( + _parse_timestamp_ms(parent_records[i + 1].timestamp) + if i + 1 < len(parent_records) + else float("inf") + ) + if child_first_ts <= next_start_ts: + return i + + # Tier 3: track closest by start timestamp as fallback + diff = abs(start_ts - child_first_ts) + if diff < best_diff: + best_diff = diff + best_idx = i + + return best_idx diff --git a/src/aiperf/dataset/loader/conflux_analyzer.py b/src/aiperf/dataset/loader/conflux_analyzer.py new file mode 100644 index 000000000..5270bda6a --- /dev/null +++ b/src/aiperf/dataset/loader/conflux_analyzer.py @@ -0,0 +1,370 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Analyzer for Conflux trace files and directories. + +Computes structural, token, cache, concurrency, and timing statistics for +Conflux proxy captures without loading the full dataset pipeline. +Used by ``aiperf analyze-trace``. 
+""" + +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import orjson +from pydantic import Field + +from aiperf.common.models import AIPerfBaseModel +from aiperf.dataset.loader.conflux import _parse_timestamp_s +from aiperf.dataset.synthesis.models import MetricStats + +# ── Internal helpers ── + + +@dataclass(slots=True) +class _TokenAccum: + """Accumulates token totals for a group of records.""" + + input: int = 0 + output: int = 0 + cached: int = 0 + cache_write: int = 0 + models: Counter = field(default_factory=Counter) + + def add(self, record: dict[str, Any]) -> None: + tokens = record.get("tokens") + if isinstance(tokens, dict): + self.input += tokens.get("input", 0) or 0 + self.output += tokens.get("output", 0) or 0 + self.cached += tokens.get("input_cached", 0) or 0 + self.cache_write += tokens.get("input_cache_write", 0) or 0 + model = record.get("model") + if model: + self.models[model] += 1 + + @property + def uncached(self) -> int: + return self.input - self.cached + + @property + def cache_hit_pct(self) -> float: + return (self.cached / self.input * 100.0) if self.input > 0 else 0.0 + + @property + def cache_roi(self) -> float: + return (self.cached / self.cache_write) if self.cache_write > 0 else 0.0 + + @property + def primary_model(self) -> str: + return self.models.most_common(1)[0][0] if self.models else "unknown" + + +def _safe_pct(num: int | float, denom: int | float) -> float: + return (num / denom * 100.0) if denom > 0 else 0.0 + + +# ── Models ── + + +class AgentSummary(AIPerfBaseModel): + """Per-agent breakdown row.""" + + agent_id: str = Field(description="Agent identifier") + is_parent: bool = Field(description="Whether this is the parent agent") + model: str = Field(description="Primary model used") + requests: int = Field(description="Number of API calls") + input_tokens: int = Field(description="Total input tokens") + 
cached_tokens: int = Field(description="Total cached input tokens") + uncached_tokens: int = Field(description="Total uncached input tokens") + output_tokens: int = Field(description="Total output tokens") + cache_hit_pct: float = Field(description="Weighted cache hit rate (%)") + cache_writes: int = Field(description="Total cache write tokens") + cache_roi: float = Field( + description="Cache ROI (cached hits / cache writes). 0 if no writes." + ) + + +class ConfluxAnalysisStats(AIPerfBaseModel): + """Statistics extracted from Conflux trace analysis.""" + + # Structure + total_records: int = Field(description="Total API call records across all files") + total_files: int = Field(description="Number of JSON files loaded") + total_agents: int = Field(description="Distinct agent threads (with agent_id)") + parent_agents: int = Field(description="Agent threads with is_subagent=False") + child_agents: int = Field(description="Agent threads with is_subagent=True") + orphan_records: int = Field(description="Records without agent_id (utility calls)") + models_used: dict[str, int] = Field(description="Model name -> request count") + + # Session timeline + session_span_s: float = Field( + description="Wall clock span from first to last request (seconds)" + ) + active_time_s: float = Field(description="Sum of all request durations (seconds)") + active_pct: float = Field(description="Active time as % of span") + + # Concurrency + max_concurrency: int = Field(description="Peak in-flight concurrent requests") + avg_concurrency: float = Field( + description="Average concurrency (active_time / span)" + ) + + # Token totals + total_input_tokens: int = Field(description="Sum of all input tokens") + total_output_tokens: int = Field(description="Sum of all output tokens") + total_cached_tokens: int = Field(description="Sum of all cached input tokens") + total_uncached_tokens: int = Field( + description="Sum of uncached input tokens (input - cached)" + ) + total_cache_write_tokens: int 
= Field(description="Sum of cache write tokens") + input_share_pct: float = Field(description="Input tokens as % of total tokens") + + # Cache economics + weighted_cache_hit_pct: float = Field( + description="Weighted cache hit rate: cached / input (%)" + ) + cache_roi: float = Field(description="Cache ROI: cached hits / cache writes (x)") + effective_token_pct: float = Field( + description="Effective tokens (uncached + output) as % of total" + ) + + # Per-request distributions + input_tokens_stats: MetricStats | None = Field( + default=None, description="Input tokens per request" + ) + output_tokens_stats: MetricStats | None = Field( + default=None, description="Output tokens per request" + ) + cached_tokens_stats: MetricStats | None = Field( + default=None, description="Cached input tokens per request" + ) + cache_hit_pct_stats: MetricStats | None = Field( + default=None, description="Per-request cache hit % distribution" + ) + osl_isl_ratio_stats: MetricStats | None = Field( + default=None, description="Output/input token ratio per request" + ) + + # Timing + duration_ms_stats: MetricStats | None = Field( + default=None, description="Request duration in ms" + ) + ttft_ms_stats: MetricStats | None = Field( + default=None, description="Time to first token in ms" + ) + + # Per-agent turn counts + turns_per_agent_stats: MetricStats | None = Field( + default=None, description="Turns per agent thread" + ) + + # Request shape + tool_count_stats: MetricStats | None = Field( + default=None, description="Tool definitions per request" + ) + message_count_stats: MetricStats | None = Field( + default=None, description="Messages per request" + ) + + # Streaming + streaming_pct: float = Field(description="Percentage of requests using streaming") + + # Per-agent breakdown + agent_breakdown: list[AgentSummary] = Field( + default_factory=list, + description="Per-agent token and cache summary (sorted by input tokens desc)", + ) + + +# ── Analysis ── + + +def 
analyze_conflux(input_path: Path) -> ConfluxAnalysisStats: + """Analyze a Conflux JSON file or directory of files.""" + if input_path.is_dir(): + json_files = sorted(input_path.glob("*.json")) + if not json_files: + raise FileNotFoundError(f"No .json files found in {input_path}") + else: + json_files = [input_path] + + all_records: list[dict[str, Any]] = [] + for json_file in json_files: + with open(json_file, "rb") as f: + data = orjson.loads(f.read()) + if isinstance(data, list): + all_records.extend(data) + + # Group by agent_id and classify + agents: dict[str, list[dict]] = {} + orphan_count = 0 + parent_ids: set[str] = set() + child_ids: set[str] = set() + + for record in all_records: + agent_id = record.get("agent_id") + if agent_id is not None: + agents.setdefault(agent_id, []).append(record) + else: + orphan_count += 1 + + for agent_id, records in agents.items(): + if records[0].get("is_subagent") is True: + child_ids.add(agent_id) + elif records[0].get("is_subagent") is False: + parent_ids.add(agent_id) + + # Session timeline and concurrency + events: list[tuple[float, int]] = [] + first_ts = float("inf") + last_ts = float("-inf") + total_active_s = 0.0 + + for r in all_records: + ts_str = r.get("timestamp") + if not ts_str: + continue + start_s = _parse_timestamp_s(ts_str) + completed_str = r.get("completed_at") + dur_s = (r.get("duration_ms", 0) or 0) / 1000.0 + + if completed_str: + end_s = _parse_timestamp_s(completed_str) + elif dur_s > 0: + end_s = start_s + dur_s + else: + end_s = start_s + + first_ts = min(first_ts, start_s) + last_ts = max(last_ts, end_s) + total_active_s += end_s - start_s + events.append((start_s, 1)) + events.append((end_s, -1)) + + session_span_s = max(last_ts - first_ts, 0.001) + avg_concurrency = total_active_s / session_span_s if session_span_s > 0 else 0.0 + + events.sort() + max_conc = 0 + cur_conc = 0 + for _, delta in events: + cur_conc += delta + max_conc = max(max_conc, cur_conc) + + # Per-record distributions + 
global totals + totals = _TokenAccum() + input_tokens_list: list[int] = [] + output_tokens_list: list[int] = [] + cached_tokens_list: list[int] = [] + cache_hit_pcts: list[float] = [] + osl_isl_ratios: list[float] = [] + durations: list[float] = [] + ttfts: list[float] = [] + tool_counts: list[int] = [] + message_counts: list[int] = [] + streaming_count = 0 + streaming_known = 0 + + for r in all_records: + totals.add(r) + + tokens = r.get("tokens") + if isinstance(tokens, dict): + inp = tokens.get("input", 0) or 0 + out = tokens.get("output", 0) or 0 + cached = tokens.get("input_cached", 0) or 0 + if inp > 0: + input_tokens_list.append(inp) + cache_hit_pcts.append(cached / inp * 100.0) + if out > 0: + output_tokens_list.append(out) + if cached > 0: + cached_tokens_list.append(cached) + if inp > 0 and out > 0: + osl_isl_ratios.append(out / inp) + + dur = r.get("duration_ms", 0) + if dur and dur > 0: + durations.append(float(dur)) + ttft = r.get("ttft_ms") + if ttft is not None and ttft > 0: + ttfts.append(float(ttft)) + + tools = r.get("tools") + if isinstance(tools, list): + tool_counts.append(len(tools)) + msgs = r.get("messages") + if isinstance(msgs, list): + message_counts.append(len(msgs)) + + is_streaming = r.get("is_streaming") + if is_streaming is not None: + streaming_known += 1 + if is_streaming: + streaming_count += 1 + + total_tokens = totals.input + totals.output + + # Per-agent breakdown + agent_breakdown: list[AgentSummary] = [] + for agent_id, records in agents.items(): + acc = _TokenAccum() + for r in records: + acc.add(r) + agent_breakdown.append( + AgentSummary( + agent_id=agent_id, + is_parent=agent_id in parent_ids, + model=acc.primary_model, + requests=len(records), + input_tokens=acc.input, + cached_tokens=acc.cached, + uncached_tokens=acc.uncached, + output_tokens=acc.output, + cache_hit_pct=acc.cache_hit_pct, + cache_writes=acc.cache_write, + cache_roi=acc.cache_roi, + ) + ) + agent_breakdown.sort(key=lambda a: a.input_tokens, reverse=True) 
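Reviewer aside: the +1/−1 event sweep used above for `max_concurrency` can be checked standalone. A minimal sketch with hypothetical request intervals:

```python
# Sweep-line concurrency count: each request contributes a (start, +1)
# and an (end, -1) event; peak concurrency is the running-sum maximum.
intervals = [(0.0, 4.0), (1.0, 3.0), (2.5, 6.0), (7.0, 8.0)]  # hypothetical

events: list[tuple[float, int]] = []
for start, end in intervals:
    events.append((start, 1))
    events.append((end, -1))

# Tuple sort orders (t, -1) before (t, +1), so a request ending exactly
# when another starts is not counted as overlapping.
events.sort()

max_conc = cur = 0
for _, delta in events:
    cur += delta
    max_conc = max(max_conc, cur)

print(max_conc)  # 3 requests in flight between t=2.5 and t=3.0
```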
+ + return ConfluxAnalysisStats( + total_records=len(all_records), + total_files=len(json_files), + total_agents=len(agents), + parent_agents=len(parent_ids), + child_agents=len(child_ids), + orphan_records=orphan_count, + models_used=dict(totals.models.most_common()), + session_span_s=session_span_s, + active_time_s=total_active_s, + active_pct=_safe_pct(total_active_s, session_span_s), + max_concurrency=max_conc, + avg_concurrency=avg_concurrency, + total_input_tokens=totals.input, + total_output_tokens=totals.output, + total_cached_tokens=totals.cached, + total_uncached_tokens=totals.uncached, + total_cache_write_tokens=totals.cache_write, + input_share_pct=_safe_pct(totals.input, total_tokens), + weighted_cache_hit_pct=totals.cache_hit_pct, + cache_roi=totals.cache_roi, + effective_token_pct=_safe_pct(totals.uncached + totals.output, total_tokens), + input_tokens_stats=MetricStats.from_values(input_tokens_list), + output_tokens_stats=MetricStats.from_values(output_tokens_list), + cached_tokens_stats=MetricStats.from_values(cached_tokens_list), + cache_hit_pct_stats=MetricStats.from_values(cache_hit_pcts), + osl_isl_ratio_stats=MetricStats.from_values(osl_isl_ratios), + duration_ms_stats=MetricStats.from_values(durations), + ttft_ms_stats=MetricStats.from_values(ttfts), + turns_per_agent_stats=MetricStats.from_values( + [len(recs) for recs in agents.values()] + ), + tool_count_stats=MetricStats.from_values(tool_counts), + message_count_stats=MetricStats.from_values(message_counts), + streaming_pct=_safe_pct(streaming_count, streaming_known), + agent_breakdown=agent_breakdown, + ) diff --git a/src/aiperf/dataset/loader/models.py b/src/aiperf/dataset/loader/models.py index 675b3c597..9668dbb0a 100644 --- a/src/aiperf/dataset/loader/models.py +++ b/src/aiperf/dataset/loader/models.py @@ -336,8 +336,183 @@ class BailianTrace(AIPerfBaseModel): ) +class ConfluxTokens(AIPerfBaseModel): + """Normalized token counts across providers.""" + + input: int = Field( + 
+        default=0,
+        description="Total input tokens processed (all input the model saw). "
+        "For Anthropic: input_tokens + cache_creation_input_tokens + cache_read_input_tokens. "
+        "For OpenAI: prompt_tokens (already includes cached).",
+    )
+    input_cached: int = Field(
+        default=0,
+        description="Tokens read from cache. "
+        "For Anthropic: cache_read_input_tokens. For OpenAI: equivalent cached tokens.",
+    )
+    input_cache_write: int = Field(
+        default=0,
+        description="Tokens written to cache. For Anthropic: cache_creation_input_tokens.",
+    )
+    output: int = Field(default=0, description="Total output tokens generated.")
+    output_reasoning: int = Field(
+        default=0,
+        description="Output tokens used for reasoning/thinking, when available from the provider.",
+    )
+
+
+class ConfluxHyperparameters(AIPerfBaseModel):
+    """Normalized generation hyperparameters extracted from the request body.
+
+    Matches the canonical keys defined in Conflux's hyperparameters.rs
+    normalization layer. Unknown keys are dropped during normalization;
+    null values are omitted.
+    """
+
+    max_tokens: int | None = Field(
+        default=None, description="Maximum tokens to generate (Anthropic max_tokens)."
+    )
+    max_output_tokens: int | None = Field(
+        default=None,
+        description="Maximum output tokens (OpenAI max_completion_tokens).",
+    )
+    temperature: float | None = Field(default=None, description="Sampling temperature.")
+    top_p: float | None = Field(default=None, description="Nucleus sampling cutoff.")
+    top_k: int | None = Field(default=None, description="Top-k sampling cutoff.")
+    presence_penalty: float | None = Field(
+        default=None, description="Presence penalty for repeated tokens."
+    )
+    frequency_penalty: float | None = Field(
+        default=None, description="Frequency penalty for repeated tokens."
+    )
+    seed: int | None = Field(
+        default=None, description="Random seed for deterministic generation."
+    )
+    stop: Any = Field(default=None, description="Stop sequences or tokens.")
+    reasoning_effort: str | None = Field(
+        default=None, description="Reasoning effort level (e.g. low, medium, high)."
+    )
+    reasoning_summary: str | None = Field(
+        default=None, description="Reasoning summary mode (OpenAI-specific)."
+    )
+    text_verbosity: str | None = Field(
+        default=None, description="Text verbosity mode (OpenAI-specific)."
+    )
+
+
+class ConfluxRecord(AIPerfBaseModel):
+    """A single unified API call from a Conflux proxy capture.
+
+    Conforms to the Conflux unified canonical schema. Each record represents
+    one API request/response cycle captured via MITM proxy intercept, with
+    agent threading metadata and full request payload for verbatim replay.
+    """
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    type: Literal[CustomDatasetType.CONFLUX] = CustomDatasetType.CONFLUX
+
+    id: str = Field(
+        description="Unique identifier for this API call. "
+        "Typically the provider request ID (e.g. req_...) or a synthesized key.",
+    )
+    source: str | None = Field(
+        default=None,
+        description="How this record was captured (e.g. 'proxy' for MITM proxy intercept).",
+    )
+    client: str | None = Field(
+        default=None,
+        description="Which AI coding tool made this API call (claude, codex, unknown).",
+    )
+    request_id: str | None = Field(
+        default=None,
+        description="Provider-assigned request identifier "
+        "(e.g. Anthropic req_... or OpenAI chatcmpl-...).",
+    )
+    session_id: str = Field(
+        description="Session identifier grouping related API calls in a single coding session.",
+    )
+    agent_id: str | None = Field(
+        default=None,
+        description="Identifier of the agent/persona that made this call.",
+    )
+    is_subagent: bool | None = Field(
+        default=None,
+        description="Whether this call was made by a sub-agent (e.g. a tool-spawned background task). "
+        "None means un-enriched (not yet classified by the adapter pipeline).",
+    )
+    timestamp: str = Field(
+        description="ISO 8601 timestamp when the API request was sent.",
+    )
+    duration_ms: int | float = Field(
+        default=0,
+        description="Time in milliseconds from request start to response completion.",
+    )
+    completed_at: str | None = Field(
+        default=None,
+        description="ISO 8601 timestamp when the API response completed. "
+        "Derived from timestamp + duration_ms.",
+    )
+    provider: str | None = Field(
+        default=None,
+        description="The LLM provider that served this request (anthropic, openai, unknown).",
+    )
+    model: str | None = Field(
+        default=None,
+        description="Model identifier (e.g. claude-opus-4-6, gpt-4o).",
+    )
+    client_version: str | None = Field(
+        default=None,
+        description="Client CLI/runtime version for the calling tool (e.g. Claude Code 2.1.39).",
+    )
+    tokens: ConfluxTokens | None = Field(
+        default=None,
+        description="Normalized token counts across providers.",
+    )
+    tools: list[dict[str, Any]] = Field(
+        default_factory=list,
+        description="Full tool definitions available to the model for this API call. "
+        "Each element is the complete tool object from the provider request body.",
+    )
+    messages: list[dict[str, Any]] = Field(
+        default_factory=list,
+        description="Input messages from the request "
+        "(system, user, tool results, prior assistant turns).",
+    )
+    output: list[dict[str, Any]] = Field(
+        default_factory=list,
+        description="The assistant's output messages extracted from the API response "
+        "(text, tool calls, etc.).",
+    )
+    hyperparameters: ConfluxHyperparameters | None = Field(
+        default=None,
+        description="Normalized generation hyperparameters extracted from the request body.",
+    )
+    is_streaming: bool | None = Field(
+        default=None,
+        description="Whether this API call used SSE streaming. "
+        "Inferred from response Content-Type or response body format.",
+    )
+    ttft_ms: int | float | None = Field(
+        default=None,
+        description="Time to first token in milliseconds. Only present for streaming API calls. "
+        "Measured from request sent to first SSE data chunk received.",
+    )
+    base64: dict[str, str] | None = Field(
+        default=None,
+        description="Raw base64-encoded artifacts captured by the proxy. "
+        "Keys: request_body, response_body, provider_usage. "
+        "May be gzip or zstd compressed before encoding.",
+    )
+
+
 CustomDatasetT = TypeVar(
     "CustomDatasetT",
-    bound=SingleTurn | MultiTurn | RandomPool | MooncakeTrace | BailianTrace,
+    bound=SingleTurn
+    | MultiTurn
+    | RandomPool
+    | MooncakeTrace
+    | BailianTrace
+    | ConfluxRecord,
 )
 """A union type of all custom data types."""
diff --git a/src/aiperf/dataset/message_normalizer.py b/src/aiperf/dataset/message_normalizer.py
new file mode 100644
index 000000000..f99c89743
--- /dev/null
+++ b/src/aiperf/dataset/message_normalizer.py
@@ -0,0 +1,757 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Normalize provider-native messages and tools to OpenAI canonical format.
+
+This module converts Anthropic, OpenAI, and other provider message formats
+into OpenAI chat completions format -- the canonical internal representation
+used by AIPerf. Each loader calls ``normalize_messages()`` once at load time,
+and each endpoint emitter converts from canonical to its wire format.
+
+Architecture: N normalizers (provider -> canonical) + M emitters (canonical -> wire)
+instead of NxM direct converters.
+
+Provider-specific metadata is stored in a single ``_meta`` dict on each message,
+tool_call, or tool definition that needs it. This consolidates round-trip fidelity
+data into one key rather than scattering underscore-prefixed fields.
+
+Metadata keys stored in ``_meta``:
+- ``is_error``: on role:tool messages, from Anthropic tool_result ``is_error``
+- ``caller``: on tool_call dicts, from Anthropic tool_use ``caller``
+- ``citations``: on canonical messages, from Anthropic text block ``citations``
+- ``cache_control``: on canonical messages/tool dicts, from Anthropic ``cache_control``
+- ``server_tool``: on tool_call dicts, marks Anthropic ``server_tool_use`` origin
+- ``passthrough_blocks``: on canonical messages, Anthropic blocks with no OpenAI equivalent
+- ``block_order``: on canonical messages, original block ordering for interleaved
+  thinking + server tools (needed for Anthropic signature verification)
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+
+import orjson
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+# Anthropic block types preserved opaquely for round-trip fidelity
+_ANTHROPIC_PASSTHROUGH_BLOCK_TYPES = frozenset(
+    {"document", "search_result", "container_upload", "tool_reference"}
+)
+
+# Server tool result block types (returned inline in assistant messages)
+_SERVER_TOOL_RESULT_TYPES = frozenset(
+    {
+        "web_search_tool_result",
+        "web_fetch_tool_result",
+        "code_execution_tool_result",
+        "bash_code_execution_tool_result",
+        "text_editor_code_execution_tool_result",
+        "tool_search_tool_result",
+    }
+)
+
+# Union of both passthrough sets -- used in the normalizer where the distinction
+# doesn't matter (both go to passthrough_blocks).
+_ALL_PASSTHROUGH_BLOCK_TYPES = (
+    _ANTHROPIC_PASSTHROUGH_BLOCK_TYPES | _SERVER_TOOL_RESULT_TYPES
+)
+
+# Versioned server/computer-use tool type pattern
+_VERSIONED_TOOL_TYPE_RE = re.compile(
+    r"^(computer|bash|text_editor|web_search|web_fetch|code_execution|memory|"
+    r"tool_search_tool_bm25|tool_search_tool_regex)_\d{8}$"
+)
+
+# Tool ID must match ^[a-zA-Z0-9_-]+$
+_VALID_TOOL_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
+_INVALID_TOOL_ID_CHAR_RE = re.compile(r"[^a-zA-Z0-9_-]")
+
+# OpenAI content types that pass through without conversion
+_OPENAI_PASSTHROUGH_CONTENT_TYPES = frozenset(
+    {"input_audio", "audio_url", "guarded_text", "video_url", "file"}
+)
+
+_BILLING_PREFIX = "x-anthropic-billing-header:"
+
+# Anthropic-specific message block types used for provider detection
+_ANTHROPIC_MESSAGE_BLOCK_TYPES = frozenset(
+    {"tool_use", "tool_result", "thinking", "redacted_thinking", "server_tool_use"}
+)
+
+_META = "_meta"
+
+
+# ---------------------------------------------------------------------------
+# _meta accessors -- single place for all metadata read/write
+# ---------------------------------------------------------------------------
+
+
+def _get_meta(d: dict[str, Any], key: str, default: Any = None) -> Any:
+    """Read a metadata value from the ``_meta`` dict."""
+    meta = d.get(_META)
+    if meta is None:
+        return default
+    return meta.get(key, default)
+
+
+def _set_meta(d: dict[str, Any], key: str, value: Any) -> None:
+    """Write a metadata value into the ``_meta`` dict."""
+    meta = d.get(_META)
+    if meta is None:
+        meta = {}
+        d[_META] = meta
+    meta[key] = value
+
+
+def _has_meta(d: dict[str, Any], key: str) -> bool:
+    """Check whether a metadata key exists."""
+    meta = d.get(_META)
+    return meta is not None and key in meta
+
+
+# ---------------------------------------------------------------------------
+# Shared helpers
+# ---------------------------------------------------------------------------
+
+
+def _is_passthrough_tool(tool: dict[str, Any]) -> bool:
+    """Return True if *tool* should pass through without conversion.
+
+    Matches versioned server/computer-use tools and MCP server tools.
+    """
+    tool_type = tool.get("type", "")
+    if _VERSIONED_TOOL_TYPE_RE.match(tool_type):
+        return True
+    return tool_type == "url" and "url" in tool
+
+
+def _sanitize_tool_id(tool_id: str) -> str:
+    """Sanitize a tool ID to match Anthropic's pattern: ``^[a-zA-Z0-9_-]+$``."""
+    if not tool_id or _VALID_TOOL_ID_RE.match(tool_id):
+        return tool_id
+    return _INVALID_TOOL_ID_CHAR_RE.sub("_", tool_id)
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def normalize_messages(
+    messages: list[dict[str, Any]],
+    tools: list[dict[str, Any]] | None = None,
+    provider: str | None = None,
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
+    """Normalize provider-native messages and tools to OpenAI canonical format.
+
+    Args:
+        messages: Provider-native message dicts (from Conflux UnifiedMessage,
+            raw API capture, etc.)
+        tools: Provider-native tool definitions, or None.
+        provider: Provider hint -- ``"anthropic"``, ``"openai"``, or ``None``
+            for auto-detection.
+
+    Returns:
+        Tuple of (normalized_messages, normalized_tools) in OpenAI format.
+    """
+    if provider is None:
+        provider = _detect_provider(messages, tools)
+
+    # Strip Conflux metadata
+    messages = [
+        {k: v for k, v in m.items() if k != "tokens"} if "tokens" in m else m
+        for m in messages
+    ]
+
+    if provider == "anthropic":
+        messages = _normalize_anthropic_messages(messages)
+        tools = _normalize_anthropic_tools(tools) if tools else tools
+
+    return messages, tools
+
+
+def _detect_provider(
+    messages: list[dict[str, Any]],
+    tools: list[dict[str, Any]] | None = None,
+) -> str:
+    """Auto-detect provider from message and tool content shapes."""
+    if tools:
+        for tool in tools:
+            if "input_schema" in tool and "function" not in tool:
+                return "anthropic"
+            if _is_passthrough_tool(tool):
+                return "anthropic"
+            if "function" in tool or "parameters" in tool:
+                return "openai"
+
+    for msg in messages:
+        content = msg.get("content")
+        if isinstance(content, list):
+            for block in content:
+                if not isinstance(block, dict):
+                    continue
+                block_type = block.get("type")
+                if block_type in _ANTHROPIC_MESSAGE_BLOCK_TYPES:
+                    return "anthropic"
+                if block_type == "image" and isinstance(block.get("source"), dict):
+                    return "anthropic"
+                if block_type in _ALL_PASSTHROUGH_BLOCK_TYPES:
+                    return "anthropic"
+            # Assistant messages with list-of-blocks content are Anthropic-native;
+            # OpenAI assistant messages use plain string content.
+            if msg.get("role") == "assistant":
+                return "anthropic"
+        if msg.get("tool_calls") is not None:
+            return "openai"
+        if msg.get("role") == "tool" and "tool_call_id" in msg:
+            return "openai"
+
+    return "openai"
+
+
+# ---------------------------------------------------------------------------
+# Anthropic -> OpenAI canonical
+# ---------------------------------------------------------------------------
+
+
+def _normalize_anthropic_messages(
+    messages: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Convert Anthropic message format to OpenAI canonical format."""
+    result: list[dict[str, Any]] = []
+
+    for msg in messages:
+        role = msg.get("role", "")
+
+        if role == "system":
+            flattened = _flatten_text_content(
+                msg.get("content"), strip_billing_headers=True
+            )
+            if flattened:
+                result.append({"role": "system", "content": flattened})
+        elif role == "assistant":
+            result.extend(_normalize_anthropic_assistant(msg))
+        elif role == "user":
+            result.extend(_normalize_anthropic_user(msg))
+        else:
+            result.append(msg)
+
+    return result
+
+
+def _normalize_anthropic_assistant(msg: dict[str, Any]) -> list[dict[str, Any]]:
+    """Convert an Anthropic assistant message to OpenAI format.
+
+    Tracks original block ordering via ``block_order`` in _meta when thinking
+    blocks are interleaved with server tools (needed for Anthropic signature
+    verification).
+    """
+    content = msg.get("content")
+    if not isinstance(content, list):
+        return [msg]
+
+    text_parts: list[str] = []
+    tool_calls: list[dict[str, Any]] = []
+    thinking_blocks: list[dict[str, Any]] = []
+    passthrough_blocks: list[dict[str, Any]] = []
+    citations: list[dict[str, Any]] = []
+    block_order: list[tuple[str, int]] = []
+
+    for block in content:
+        if not isinstance(block, dict):
+            continue
+        block_type = block.get("type")
+
+        if block_type == "text":
+            text = block.get("text", "")
+            if text:
+                text_parts.append(text)
+                # Must be inside `if text` -- moving outside produces stale index.
+                block_order.append(("text", len(text_parts) - 1))
+            if "citations" in block:
+                citations.extend(block["citations"])
+
+        elif block_type in ("tool_use", "server_tool_use"):
+            tool_call = _anthropic_tool_use_to_call(block)
+            if block_type == "server_tool_use":
+                _set_meta(tool_call, "server_tool", True)
+            tool_calls.append(tool_call)
+            block_order.append(("tool_call", len(tool_calls) - 1))
+
+        elif block_type in ("thinking", "redacted_thinking"):
+            thinking_blocks.append(block)
+            block_order.append(("thinking", len(thinking_blocks) - 1))
+
+        elif block_type in _ALL_PASSTHROUGH_BLOCK_TYPES:
+            passthrough_blocks.append(block)
+            block_order.append(("passthrough", len(passthrough_blocks) - 1))
+
+    out: dict[str, Any] = {"role": "assistant"}
+
+    combined_text = "\n\n".join(text_parts) if text_parts else None
+    if tool_calls:
+        out["content"] = combined_text or ""
+        out["tool_calls"] = tool_calls
+    elif combined_text is not None:
+        out["content"] = combined_text
+    else:
+        out["content"] = ""
+
+    if thinking_blocks:
+        out["thinking_blocks"] = thinking_blocks
+    if passthrough_blocks:
+        _set_meta(out, "passthrough_blocks", passthrough_blocks)
+    if citations:
+        _set_meta(out, "citations", citations)
+
+    # Only store block order for interleaved thinking + server tool messages
+    if thinking_blocks and any(_get_meta(tc, "server_tool") for tc in tool_calls):
+        _set_meta(out, "block_order", block_order)
+
+    return [out]
+
+
+def _anthropic_tool_use_to_call(block: dict[str, Any]) -> dict[str, Any]:
+    """Convert a single Anthropic tool_use/server_tool_use block to OpenAI tool_call."""
+    tool_call: dict[str, Any] = {
+        "id": block.get("id", ""),
+        "type": "function",
+        "function": {
+            "name": block.get("name", ""),
+            "arguments": (
+                orjson.dumps(block["input"]).decode()
+                if isinstance(block.get("input"), dict)
+                else str(block.get("input", "{}"))
+            ),
+        },
+    }
+    if "caller" in block:
+        _set_meta(tool_call, "caller", block["caller"])
+    if "cache_control" in block:
+        _set_meta(tool_call, "cache_control", block["cache_control"])
+    return tool_call
+
+
+def _normalize_anthropic_user(msg: dict[str, Any]) -> list[dict[str, Any]]:
+    """Convert an Anthropic user message to OpenAI format."""
+    content = msg.get("content")
+    if not isinstance(content, list):
+        return [msg]
+
+    text_parts: list[str] = []
+    content_parts: list[dict[str, Any]] = []
+    has_unconvertible_blocks = False
+    result: list[dict[str, Any]] = []
+
+    # Build result preserving original interleaved order of user content and
+    # tool_results. Flush accumulated user text/content before each tool_result.
+    def _flush_user() -> None:
+        if text_parts and not content_parts and not has_unconvertible_blocks:
+            result.append({"role": "user", "content": "\n\n".join(text_parts)})
+        elif text_parts or content_parts:
+            combined = [{"type": "text", "text": t} for t in text_parts]
+            combined.extend(content_parts)
+            result.append({"role": "user", "content": combined})
+        text_parts.clear()
+        content_parts.clear()
+
+    for block in content:
+        if not isinstance(block, dict):
+            if isinstance(block, str):
+                text_parts.append(block)
+            continue
+        block_type = block.get("type")
+
+        if block_type == "tool_result":
+            _flush_user()
+            tool_msg: dict[str, Any] = {
+                "role": "tool",
+                "tool_call_id": block.get("tool_use_id", ""),
+                "content": _flatten_text_content(block.get("content")),
+            }
+            if block.get("is_error"):
+                _set_meta(tool_msg, "is_error", True)
+            if "cache_control" in block:
+                _set_meta(tool_msg, "cache_control", block["cache_control"])
+            result.append(tool_msg)
+        elif block_type == "text":
+            text = block.get("text", "")
+            if text:
+                text_parts.append(text)
+        elif block_type == "image":
+            content_parts.append(_anthropic_image_to_openai(block))
+        elif block_type in _OPENAI_PASSTHROUGH_CONTENT_TYPES:
+            content_parts.append(block)
+        else:
+            content_parts.append(block)
+            has_unconvertible_blocks = True
+
+    _flush_user()
+    return result if result else [msg]
+
+
+def _flatten_text_content(content: Any, *, strip_billing_headers: bool = False) -> str:
+    """Flatten Anthropic content (string, list-of-blocks, or nested) to a single string."""
+    if isinstance(content, str):
+        if strip_billing_headers:
+            return _strip_billing_headers(content)
+        return content
+    if isinstance(content, list):
+        parts: list[str] = []
+        for item in content:
+            if isinstance(item, str):
+                text = item
+            elif isinstance(item, dict):
+                text = item.get("text")
+            else:
+                continue
+            if not isinstance(text, str) or not text:
+                continue
+            if strip_billing_headers and text.startswith(_BILLING_PREFIX):
+                continue
+            parts.append(text)
+        return "\n\n".join(parts)
+    if content is None:
+        return ""
+    if isinstance(content, dict):
+        return orjson.dumps(content).decode()
+    return str(content)
+
+
+# ---------------------------------------------------------------------------
+# Image conversion helpers
+# ---------------------------------------------------------------------------
+
+
+def _anthropic_image_to_openai(block: dict[str, Any]) -> dict[str, Any]:
+    """Convert Anthropic image block to OpenAI image_url content part."""
+    source = block.get("source", {})
+    source_type = source.get("type")
+
+    if source_type == "base64":
+        media_type = source.get("media_type", "image/png")
+        data = source.get("data", "")
+        url = f"data:{media_type};base64,{data}"
+    elif source_type == "url":
+        url = source.get("url", "")
+    else:
+        return block
+
+    result: dict[str, Any] = {"type": "image_url", "image_url": {"url": url}}
+    if "cache_control" in block:
+        _set_meta(result, "cache_control", block["cache_control"])
+    return result
+
+
+def _openai_image_to_anthropic(part: dict[str, Any]) -> dict[str, Any]:
+    """Convert OpenAI image_url content part to Anthropic image block."""
+    image_url = part.get("image_url", {})
+    url = image_url.get("url", "")
+
+    if url.startswith("data:"):
+        header, _, data = url.partition(",")
+        media_type = "image/png"
+        if ":" in header and ";" in header:
+            media_type = header.split(":", 1)[1].split(";", 1)[0]
+        result: dict[str, Any] = {
+            "type": "image",
+            "source": {"type": "base64", "media_type": media_type, "data": data},
+        }
+    else:
+        result = {"type": "image", "source": {"type": "url", "url": url}}
+
+    if _has_meta(part, "cache_control"):
+        result["cache_control"] = _get_meta(part, "cache_control")
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Anthropic tools -> OpenAI tools
+# ---------------------------------------------------------------------------
+
+
+def _normalize_anthropic_tools(
+    tools: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Convert Anthropic tool definitions to OpenAI function-calling format."""
+    result: list[dict[str, Any]] = []
+
+    for tool in tools:
+        if "function" in tool or _is_passthrough_tool(tool):
+            result.append(tool)
+            continue
+
+        schema_key = (
+            "input_schema"
+            if "input_schema" in tool
+            else "parameters"
+            if "parameters" in tool and "name" in tool
+            else None
+        )
+        if schema_key:
+            converted: dict[str, Any] = {
+                "type": "function",
+                "function": {
+                    "name": tool.get("name", ""),
+                    "description": tool.get("description", "(no description)"),
+                    "parameters": tool.get(schema_key, {}),
+                },
+            }
+            if "cache_control" in tool:
+                _set_meta(converted, "cache_control", tool["cache_control"])
+            result.append(converted)
+            continue
+
+        result.append(tool)
+
+    return result
+
+
+def _strip_billing_headers(text: str) -> str:
+    """Remove x-anthropic-billing-header lines from system text."""
+    lines = text.split("\n")
+    filtered = [line for line in lines if not line.startswith(_BILLING_PREFIX)]
+    return "\n".join(filtered).strip()
+
+
+# ---------------------------------------------------------------------------
+# Canonical -> Anthropic (emitter)
+# ---------------------------------------------------------------------------
+
+
+def to_anthropic_messages(
+    messages: list[dict[str, Any]],
+) -> tuple[list[dict[str, Any]], str | list[dict[str, Any]] | None]:
+    """Convert canonical (OpenAI) messages to Anthropic format.
+
+    Returns:
+        Tuple of (messages, system) where system is extracted from the first
+        system message if present, or None.
+    """
+    raw: list[dict[str, Any]] = []
+    system_parts: list[str] = []
+
+    for msg in messages:
+        role = msg.get("role")
+
+        if role in ("system", "developer"):
+            content = msg.get("content")
+            if isinstance(content, str):
+                system_parts.append(content)
+            elif isinstance(content, list):
+                system_parts.append(_flatten_text_content(content))
+            continue
+
+        if role == "assistant":
+            raw.append(_emit_anthropic_assistant(msg))
+        elif role == "tool":
+            raw.append(_emit_anthropic_tool_result(msg))
+        elif role == "user":
+            raw.append(_emit_anthropic_user(msg))
+        else:
+            raw.append(msg)
+
+    system = "\n\n".join(system_parts) if system_parts else None
+    return _merge_consecutive_roles(raw), system
+
+
+def to_anthropic_tools(
+    tools: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Convert canonical OpenAI tool definitions to Anthropic format."""
+    result: list[dict[str, Any]] = []
+
+    for tool in tools:
+        if ("input_schema" in tool and "function" not in tool) or _is_passthrough_tool(
+            tool
+        ):
+            result.append(tool)
+            continue
+
+        func = tool.get("function")
+        if isinstance(func, dict):
+            converted: dict[str, Any] = {
+                "name": func.get("name", ""),
+                "input_schema": func.get("parameters", {}),
+            }
+            desc = func.get("description")
+            if desc:
+                converted["description"] = desc
+            if _has_meta(tool, "cache_control"):
+                converted["cache_control"] = _get_meta(tool, "cache_control")
+            result.append(converted)
+            continue
+
+        result.append(tool)
+
+    return result
+
+
+def _emit_anthropic_assistant(msg: dict[str, Any]) -> dict[str, Any]:
+    """Convert a canonical assistant message to Anthropic content blocks.
+
+    When ``block_order`` is present in _meta (interleaved thinking + server
+    tools), restores the original block ordering required for Anthropic
+    signature verification.
+    """
+    block_order = _get_meta(msg, "block_order")
+    if block_order:
+        return _emit_anthropic_assistant_ordered(msg, block_order)
+
+    content_blocks: list[dict[str, Any]] = []
+    content_blocks.extend(msg.get("thinking_blocks", []))
+
+    text = msg.get("content")
+    if isinstance(text, str) and text:
+        content_blocks.append({"type": "text", "text": text})
+    elif isinstance(text, list):
+        content_blocks.extend(text)
+
+    _append_refusal(msg, content_blocks)
+
+    for tc in msg.get("tool_calls", []):
+        content_blocks.append(_tool_call_to_anthropic_block(tc))
+
+    content_blocks.extend(_get_meta(msg, "passthrough_blocks", []))
+
+    result: dict[str, Any] = {
+        "role": "assistant",
+        "content": content_blocks or [{"type": "text", "text": ""}],
+    }
+    if _has_meta(msg, "cache_control"):
+        result["cache_control"] = _get_meta(msg, "cache_control")
+    return result
+
+
+def _emit_anthropic_assistant_ordered(
+    msg: dict[str, Any],
+    block_order: list[tuple[str, int]],
+) -> dict[str, Any]:
+    """Emit assistant message with original block ordering restored."""
+    thinking_blocks = msg.get("thinking_blocks", [])
+    tool_calls = msg.get("tool_calls", [])
+    passthrough_blocks = _get_meta(msg, "passthrough_blocks", [])
+
+    text = msg.get("content")
+    if isinstance(text, str):
+        text_parts = [text] if text else []
+    elif isinstance(text, list):
+        text_parts = [b.get("text", "") for b in text if isinstance(b, dict)]
+    else:
+        text_parts = []
+
+    content_blocks: list[dict[str, Any]] = []
+    for kind, idx in block_order:
+        if kind == "thinking" and idx < len(thinking_blocks):
+            content_blocks.append(thinking_blocks[idx])
+        elif kind == "text" and idx < len(text_parts):
+            content_blocks.append({"type": "text", "text": text_parts[idx]})
+        elif kind == "tool_call" and idx < len(tool_calls):
+            content_blocks.append(_tool_call_to_anthropic_block(tool_calls[idx]))
+        elif kind == "passthrough" and idx < len(passthrough_blocks):
+            content_blocks.append(passthrough_blocks[idx])
+
+    _append_refusal(msg, content_blocks)
+
+    return {
+        "role": "assistant",
+        "content": content_blocks or [{"type": "text", "text": ""}],
+    }
+
+
+def _append_refusal(msg: dict[str, Any], blocks: list[dict[str, Any]]) -> None:
+    """Append a refusal text block if present."""
+    refusal = msg.get("refusal")
+    if isinstance(refusal, str) and refusal:
+        blocks.append({"type": "text", "text": refusal})
+
+
+def _tool_call_to_anthropic_block(tc: dict[str, Any]) -> dict[str, Any]:
+    """Convert a single OpenAI tool_call dict to an Anthropic tool_use block."""
+    func = tc.get("function", {})
+    arguments = func.get("arguments", "{}")
+    if isinstance(arguments, str):
+        try:
+            input_val = orjson.loads(arguments)
+        except orjson.JSONDecodeError:
+            input_val = arguments
+    else:
+        input_val = arguments
+
+    block_type = "server_tool_use" if _get_meta(tc, "server_tool") else "tool_use"
+    tool_block: dict[str, Any] = {
+        "type": block_type,
+        "id": _sanitize_tool_id(tc.get("id", "")),
+        "name": func.get("name", ""),
+        "input": input_val,
+    }
+    if _has_meta(tc, "caller"):
+        tool_block["caller"] = _get_meta(tc, "caller")
+    if _has_meta(tc, "cache_control"):
+        tool_block["cache_control"] = _get_meta(tc, "cache_control")
+    return tool_block
+
+
+def _emit_anthropic_tool_result(msg: dict[str, Any]) -> dict[str, Any]:
+    """Convert a canonical role:tool message to Anthropic tool_result user message."""
+    block: dict[str, Any] = {
+        "type": "tool_result",
+        "tool_use_id": _sanitize_tool_id(msg.get("tool_call_id", "")),
+        "content": msg.get("content", ""),
+    }
+    if _get_meta(msg, "is_error"):
+        block["is_error"] = True
+    if _has_meta(msg, "cache_control"):
+        block["cache_control"] = _get_meta(msg, "cache_control")
+    return {"role": "user", "content": [block]}
+
+
+def _emit_anthropic_user(msg: dict[str, Any]) -> dict[str, Any]:
+    """Convert a canonical user message to Anthropic format."""
+    content = msg.get("content")
+    if isinstance(content, str):
+        return {"role": "user", "content": content}
+    if isinstance(content, list):
+        converted = [
+            _openai_image_to_anthropic(part)
+            if isinstance(part, dict) and part.get("type") == "image_url"
+            else part
+            for part in content
+        ]
+        return {"role": "user", "content": converted}
+    return msg
+
+
+def _merge_consecutive_roles(
+    messages: list[dict[str, Any]],
+) -> list[dict[str, Any]]:
+    """Merge consecutive messages with the same role.
+
+    Anthropic requires strict user/assistant alternation.
+    """
+    if not messages:
+        return []
+
+    merged: list[dict[str, Any]] = [messages[0]]
+
+    for msg in messages[1:]:
+        prev = merged[-1]
+        if msg.get("role") == prev.get("role"):
+            prev["content"] = _merge_content(prev.get("content"), msg.get("content"))
+        else:
+            merged.append(msg)
+
+    return merged
+
+
+def _merge_content(
+    a: str | list[dict[str, Any]] | None,
+    b: str | list[dict[str, Any]] | None,
+) -> list[dict[str, Any]]:
+    """Merge two content values into a list of content blocks."""
+    a_list = [{"type": "text", "text": a}] if isinstance(a, str) else (a or [])
+    b_list = [{"type": "text", "text": b}] if isinstance(b, str) else (b or [])
+    return a_list + b_list
diff --git a/src/aiperf/dataset/synthesis/cli.py b/src/aiperf/dataset/synthesis/cli.py
index 9d9efa137..ca063bd55 100644
--- a/src/aiperf/dataset/synthesis/cli.py
+++ b/src/aiperf/dataset/synthesis/cli.py
@@ -6,18 +6,22 @@
 
 from pathlib import Path
 
+from rich import box
 from rich.console import Console
 from rich.table import Table
 
+from aiperf.common.models import AIPerfBaseModel
 from aiperf.dataset.synthesis.models import MetricStats
 from aiperf.dataset.synthesis.prefix_analyzer import PrefixAnalyzer
 
 _STAT_COLUMNS = ["Mean", "Std Dev", "Min", "P25", "Median", "P75", "Max"]
 
 
-def _build_stats_table(metrics: dict[str, MetricStats | None]) -> Table:
+def _build_stats_table(
+    metrics: dict[str, MetricStats | None], *, title: str = "Trace Statistics"
+) -> Table:
     """Build a Rich table with metric statistics."""
-    table = Table(title="Trace Statistics")
+    table = Table(title=title, show_lines=False, box=box.SIMPLE_HEAVY, pad_edge=False)
     table.add_column("Metric", justify="right", style="cyan", no_wrap=True)
     for col in _STAT_COLUMNS:
         table.add_column(col, justify="right", style="green", no_wrap=True)
@@ -40,40 +44,262 @@ def _build_stats_table(metrics: dict[str, MetricStats | None]) -> Table:
     return table
 
 
+def _kv_table(title: str, rows: list[tuple[str, str]]) -> Table:
+    """Build a headerless key-value table."""
+    table = Table(title=title, show_header=False, box=box.SIMPLE_HEAVY, pad_edge=False)
+    table.add_column("Key", style="cyan", no_wrap=True)
+    table.add_column("Value", style="white")
+    for key, value in rows:
+        table.add_row(key, value)
+    return table
+
+
+def _save_report(
+    console: Console, stats: AIPerfBaseModel, output_file: Path | None
+) -> None:
+    """Write JSON report if output_file is set."""
+    if output_file:
+        output_file.parent.mkdir(parents=True, exist_ok=True)
+        output_file.write_text(stats.model_dump_json(indent=2))
+        console.print(f"Analysis report saved to {output_file}")
+
+
+def _is_conflux_format(input_path: Path) -> bool:
+    """Detect whether the input is a Conflux JSON file or directory."""
+    from aiperf.dataset.loader.conflux import ConfluxLoader
+
+    return ConfluxLoader.can_load(filename=str(input_path))
+
+
 def analyze_trace(
     input_file: Path,
     block_size: int = 512,
     output_file: Path | None = None,
 ) -> None:
-    """Analyze a mooncake trace file for ISL/OSL distributions and cache hit rates."""
+    """Analyze a trace file or directory for distributions and statistics.
+
+    Auto-detects Conflux JSON vs JSONL trace format.
+    """
     if not input_file.exists():
         print(f"Error: Input file not found: {input_file}")
         return
 
+    if _is_conflux_format(input_file):
+        _analyze_conflux(input_file, output_file)
+    elif input_file.is_dir():
+        print(
+            f"Error: Directory '{input_file}' does not contain Conflux JSON files. "
+            "For JSONL trace analysis, provide a file path instead."
+        )
+    else:
+        _analyze_prefix_trace(input_file, block_size, output_file)
+
+
+def _analyze_prefix_trace(
+    input_file: Path,
+    block_size: int,
+    output_file: Path | None,
+) -> None:
+    """Analyze a JSONL trace file for ISL/OSL distributions and cache hit rates."""
     analyzer = PrefixAnalyzer(block_size=block_size)
     stats = analyzer.analyze_file(input_file)
 
     console = Console(width=120)
-
     console.print()
     console.print("[bold]Trace Analysis Report[/bold]")
-    console.print(f"Total requests: {stats.total_requests:,}")
-    console.print(f"Unique prefixes: {stats.unique_prefixes:,}")
-    console.print(f"Prefix groups: {stats.num_prefix_groups:,}")
+    console.print(
+        _kv_table(
+            "Overview",
+            [
+                ("Total requests", f"{stats.total_requests:,}"),
+                ("Unique prefixes", f"{stats.unique_prefixes:,}"),
+                ("Prefix groups", f"{stats.num_prefix_groups:,}"),
+            ],
+        )
+    )
+    console.print()
+    console.print(
+        _build_stats_table(
+            {
+                "Input Length": stats.isl_stats,
+                "Context Length": stats.context_length_stats,
+                "Unique Prompt Length": stats.unique_prompt_length_stats,
+                "Output Length": stats.osl_stats,
+                "Theoretical Hit Rates": stats.hit_rate_stats,
+            }
+        )
+    )
     console.print()
+    _save_report(console, stats, output_file)
+
+
+def _fmt_tokens(n: int) -> str:
+    """Format token count with K/M suffix."""
+    if n >= 1_000_000:
+        return f"{n / 1_000_000:.1f}M"
+    if n >= 1_000:
+        return f"{n / 1_000:.1f}K"
+    return str(n)
+
-    metrics = {
-        "Input Length": stats.isl_stats,
-        "Context Length": stats.context_length_stats,
-        "Unique Prompt Length": stats.unique_prompt_length_stats,
-        "Output Length": stats.osl_stats,
-        "Theoretical Hit Rates": stats.hit_rate_stats,
-    }
+def _fmt_duration(s: float) -> str:
+    """Format seconds to human-readable duration."""
+    if s >= 3600:
+        return f"{s / 3600:.1f}h"
+    if s >= 60:
+        return f"{s / 60:.1f}m"
+    return f"{s:.1f}s"
 
-    console.print(_build_stats_table(metrics))
+
+def _analyze_conflux(
+    input_path: Path,
+    output_file: Path | None,
+) -> None:
+    """Analyze a Conflux JSON file or directory."""
+    from aiperf.dataset.loader.conflux_analyzer import (
+        ConfluxAnalysisStats,
+        analyze_conflux,
+    )
+
+    stats: ConfluxAnalysisStats = analyze_conflux(input_path)
+
+    console = Console(width=120)
     console.print()
+    console.print("[bold]Conflux Trace Analysis[/bold]")
 
-    if output_file:
-        output_file.parent.mkdir(parents=True, exist_ok=True)
-        output_file.write_text(stats.model_dump_json(indent=2))
-        console.print(f"Analysis report saved to {output_file}")
+    total_tok = stats.total_input_tokens + stats.total_output_tokens
+    console.print(
+        _kv_table(
+            "Overview",
+            [
+                ("Files", f"{stats.total_files:,}"),
+                ("Requests", f"{stats.total_records:,}"),
+                (
+                    "Agents",
+                    f"{stats.total_agents:,} ({stats.parent_agents} parent, "
+                    f"{stats.child_agents} child, {stats.orphan_records} orphan)",
+                ),
+                (
+                    "Session span",
+                    f"{_fmt_duration(stats.session_span_s)} "
+                    f"({stats.active_pct:.0f}% active)",
+                ),
+                (
+                    "Concurrency",
+                    f"max {stats.max_concurrency}, avg {stats.avg_concurrency:.1f}",
+                ),
+                ("Streaming", f"{stats.streaming_pct:.1f}%"),
+            ],
+        )
+    )
+    console.print()
+
+    console.print(
+        _kv_table(
+            "Token Economics",
+            [
+                ("Total tokens", f"{total_tok:,} ({stats.input_share_pct:.1f}% input)"),
+                ("Input tokens", f"{stats.total_input_tokens:,}"),
+                ("Output tokens", f"{stats.total_output_tokens:,}"),
+                (
+                    "Cached tokens",
+                    f"{stats.total_cached_tokens:,} "
+                    f"({stats.weighted_cache_hit_pct:.1f}% hit rate)",
+                ),
+                ("Uncached tokens", f"{stats.total_uncached_tokens:,}"),
+                ("Cache writes", f"{stats.total_cache_write_tokens:,}"),
+                ("Cache ROI", f"{stats.cache_roi:.1f}x
(hits / writes)"), + ( + "Effective tokens", + f"{stats.effective_token_pct:.1f}% " + "(only uncached + output need compute)", + ), + ], + ) + ) + console.print() + + # Models + if stats.models_used: + model_table = Table(title="Models", box=box.SIMPLE_HEAVY, pad_edge=False) + model_table.add_column("Model", style="cyan") + model_table.add_column("Requests", justify="right", style="green") + model_table.add_column("%", justify="right", style="green") + for model, count in stats.models_used.items(): + pct = count / stats.total_records * 100 + model_table.add_row(model, f"{count:,}", f"{pct:.1f}%") + console.print(model_table) + console.print() + + # Per-agent breakdown + if stats.agent_breakdown: + agent_table = Table( + title="Agent Breakdown (by input tokens)", + box=box.SIMPLE_HEAVY, + pad_edge=False, + ) + agent_table.add_column("Agent", style="cyan", no_wrap=True) + agent_table.add_column("Role", style="dim") + agent_table.add_column("Model", style="dim") + agent_table.add_column("Req", justify="right", style="green") + agent_table.add_column("Input", justify="right", style="green") + agent_table.add_column("Cached", justify="right", style="green") + agent_table.add_column("Output", justify="right", style="green") + agent_table.add_column("Hit%", justify="right", style="green") + agent_table.add_column("ROI", justify="right", style="green") + for a in stats.agent_breakdown: + # Shorten model names for display + model_short = a.model.replace("claude-", "").replace("-20251001", "") + agent_table.add_row( + a.agent_id[:20], + "parent" if a.is_parent else "child", + model_short, + str(a.requests), + _fmt_tokens(a.input_tokens), + _fmt_tokens(a.cached_tokens), + _fmt_tokens(a.output_tokens), + f"{a.cache_hit_pct:.1f}%", + f"{a.cache_roi:.1f}x" if a.cache_roi > 0 else "-", + ) + console.print(agent_table) + console.print() + + # Distribution tables + console.print( + _build_stats_table( + { + "Input Tokens": stats.input_tokens_stats, + "Output Tokens": 
stats.output_tokens_stats, + "Cached Tokens": stats.cached_tokens_stats, + "Cache Hit %": stats.cache_hit_pct_stats, + "OSL/ISL Ratio": stats.osl_isl_ratio_stats, + }, + title="Token Distributions", + ) + ) + console.print() + + console.print( + _build_stats_table( + { + "Duration (ms)": stats.duration_ms_stats, + "TTFT (ms)": stats.ttft_ms_stats, + }, + title="Timing", + ) + ) + console.print() + + console.print( + _build_stats_table( + { + "Turns per Agent": stats.turns_per_agent_stats, + "Tools per Request": stats.tool_count_stats, + "Messages per Request": stats.message_count_stats, + }, + title="Request Shape", + ) + ) + console.print() + + _save_report(console, stats, output_file) diff --git a/src/aiperf/dataset/synthesis/models.py b/src/aiperf/dataset/synthesis/models.py index e0a23794a..6bf2a1799 100644 --- a/src/aiperf/dataset/synthesis/models.py +++ b/src/aiperf/dataset/synthesis/models.py @@ -2,6 +2,11 @@ # SPDX-License-Identifier: Apache-2.0 """Pydantic models for synthesis and analysis data.""" +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np from pydantic import Field from typing_extensions import Self @@ -21,6 +26,22 @@ class MetricStats(AIPerfBaseModel): p75: float = Field(description="75th percentile") max: float = Field(description="Maximum value") + @classmethod + def from_values(cls, values: Sequence[float | int]) -> Self | None: + """Compute stats from a list of values. 
Returns None if empty.""" + if not values: + return None + arr = np.asarray(values, dtype=np.float64) + return cls( + mean=float(np.mean(arr)), + std_dev=float(np.std(arr)), + min=float(np.min(arr)), + p25=float(np.percentile(arr, 25)), + median=float(np.median(arr)), + p75=float(np.percentile(arr, 75)), + max=float(np.max(arr)), + ) + class AnalysisStats(AIPerfBaseModel): """Statistics extracted from trace analysis.""" diff --git a/src/aiperf/dataset/synthesis/prefix_analyzer.py b/src/aiperf/dataset/synthesis/prefix_analyzer.py index 22610e5a4..0356406da 100644 --- a/src/aiperf/dataset/synthesis/prefix_analyzer.py +++ b/src/aiperf/dataset/synthesis/prefix_analyzer.py @@ -4,10 +4,8 @@ import statistics from collections import Counter -from collections.abc import Sequence from pathlib import Path -import numpy as np import orjson from aiperf.common.config.config_defaults import InputTokensDefaults @@ -144,38 +142,12 @@ def _compute_context_lengths(self) -> None: for pos, hash_id in enumerate(hash_ids) if (pos, hash_id) in repeated_hash_ids ) - context_len = repeated_count * self.block_size + context_len = min(repeated_count * self.block_size, isl) unique_prompt_len = isl - context_len self.context_lengths.append(context_len) self.unique_prompt_lengths.append(unique_prompt_len) - def _compute_metric_stats( - self, values: Sequence[float | int] - ) -> MetricStats | None: - """Compute full statistics for a list of values. - - Args: - values: List of numeric values. - - Returns: - MetricStats with mean, std_dev, min, percentiles, max, or None if empty. - """ - if not values: - return None - - arr = np.asarray(values) - - return MetricStats( - mean=float(np.mean(arr)), - std_dev=float(np.std(arr)), - min=float(np.min(arr)), - p25=float(np.percentile(arr, 25)), - median=float(np.median(arr)), - p75=float(np.percentile(arr, 75)), - max=float(np.max(arr)), - ) - def _compute_stats(self) -> AnalysisStats: """Compute final statistics. 
@@ -202,13 +174,13 @@ def _compute_stats(self) -> AnalysisStats: avg_osl=sum(self.osls) / len(self.osls) if self.osls else 0.0, prefix_reuse_ratio=prefix_reuse, # Extended statistics - isl_stats=self._compute_metric_stats(self.isls), - osl_stats=self._compute_metric_stats(self.osls), - context_length_stats=self._compute_metric_stats(self.context_lengths), - unique_prompt_length_stats=self._compute_metric_stats( + isl_stats=MetricStats.from_values(self.isls), + osl_stats=MetricStats.from_values(self.osls), + context_length_stats=MetricStats.from_values(self.context_lengths), + unique_prompt_length_stats=MetricStats.from_values( self.unique_prompt_lengths ), - hit_rate_stats=self._compute_metric_stats(per_request_hit_rates), + hit_rate_stats=MetricStats.from_values(per_request_hit_rates), ) def _compute_per_request_hit_rates(self) -> list[float]: diff --git a/src/aiperf/endpoints/base_endpoint.py b/src/aiperf/endpoints/base_endpoint.py index ba0a66eed..bd3fac328 100644 --- a/src/aiperf/endpoints/base_endpoint.py +++ b/src/aiperf/endpoints/base_endpoint.py @@ -45,6 +45,25 @@ def get_endpoint_params(self, request_info: RequestInfo) -> dict[str, str]: cfg = self.model_endpoint.endpoint return dict(cfg.url_params) if cfg.url_params else {} + def merge_turn_params( + self, payload: dict[str, Any], extra_params: dict[str, Any] + ) -> dict[str, Any]: + """Merge per-turn extra parameters into a formatted payload. + + Called after format_payload() when the current turn has extra_params set. + Default behavior is a flat update; endpoints with nested parameter + structures (e.g. HuggingFace TGI) should override. + + Args: + payload: The formatted payload dict from format_payload(). + extra_params: Per-turn hyperparameter overrides to merge. + + Returns: + The updated payload dict. + """ + payload.update(extra_params) + return payload + @abstractmethod def format_payload(self, request_info: RequestInfo) -> RequestOutputT: """Format request payload from RequestInfo. 
diff --git a/src/aiperf/endpoints/openai_chat.py b/src/aiperf/endpoints/openai_chat.py index f430a87f4..867cf4db7 100644 --- a/src/aiperf/endpoints/openai_chat.py +++ b/src/aiperf/endpoints/openai_chat.py @@ -41,12 +41,9 @@ def format_payload(self, request_info: RequestInfo) -> dict[str, Any]: turns = request_info.turns model_endpoint = request_info.model_endpoint - if turns[-1].raw_messages is not None: - messages = turns[-1].raw_messages - else: - messages = self._create_messages( - turns, request_info.system_message, request_info.user_context_message - ) + messages = self._create_messages( + turns, request_info.system_message, request_info.user_context_message + ) payload = { "messages": messages, @@ -120,6 +117,9 @@ def _create_messages( ) for turn in turns: + if turn.raw_messages is not None: + messages.extend(turn.raw_messages) + continue message = { "role": turn.role or _DEFAULT_ROLE, } diff --git a/src/aiperf/plugin/enums.py b/src/aiperf/plugin/enums.py index d275252e6..174f462b7 100644 --- a/src/aiperf/plugin/enums.py +++ b/src/aiperf/plugin/enums.py @@ -59,7 +59,7 @@ CustomDatasetTypeStr: TypeAlias = str CustomDatasetType = plugins.create_enum(PluginType.CUSTOM_DATASET_LOADER, "CustomDatasetType", module=__name__) -"""Dynamic enum for custom dataset loader. Example: CustomDatasetType.BAILIAN_TRACE, CustomDatasetType.MOONCAKE_TRACE, CustomDatasetType.MULTI_TURN""" +"""Dynamic enum for custom dataset loader. 
Example: CustomDatasetType.BAILIAN_TRACE, CustomDatasetType.CONFLUX, CustomDatasetType.MOONCAKE_TRACE""" PublicDatasetTypeStr: TypeAlias = str PublicDatasetType = plugins.create_enum(PluginType.PUBLIC_DATASET_LOADER, "PublicDatasetType", module=__name__) diff --git a/src/aiperf/plugin/plugins.py b/src/aiperf/plugin/plugins.py index fe5ac8df0..dcce423ea 100644 --- a/src/aiperf/plugin/plugins.py +++ b/src/aiperf/plugin/plugins.py @@ -1201,6 +1201,18 @@ def is_trace_dataset(name: str) -> bool: return get_dataset_loader_metadata(name).is_trace +def supports_timing(name: str) -> bool: + """Check if a custom dataset loader declares timing support. + + Args: + name: Dataset loader plugin name. + + Returns: + True if the loader declares ``supports_timing: true`` in plugin metadata. + """ + return get_dataset_loader_metadata(name).supports_timing + + # Mapping of categories to their metadata classes (for categories with typed metadata) _CATEGORY_METADATA_CLASSES: dict[str, type] = { "endpoint": EndpointMetadata, diff --git a/src/aiperf/plugin/plugins.yaml b/src/aiperf/plugin/plugins.yaml index 064be8f84..f31bcff61 100644 --- a/src/aiperf/plugin/plugins.yaml +++ b/src/aiperf/plugin/plugins.yaml @@ -437,8 +437,19 @@ custom_dataset_loader: conversation chains and fixed_schedule timing mode. metadata: is_trace: true + supports_timing: true default_block_size: 16 + conflux: + class: aiperf.dataset.loader.conflux:ConfluxLoader + description: | + Conflux proxy capture loader for verbatim replay of Claude Code and Codex sessions. + Loads JSON arrays of API request records with explicit agent_id/is_subagent + fields for parent and subagent thread grouping with timestamp-based delays. + metadata: + is_trace: true + supports_timing: true + mooncake_trace: class: aiperf.dataset.loader.mooncake_trace:MooncakeTraceDatasetLoader description: | @@ -446,6 +457,7 @@ custom_dataset_loader: timestamp-based replay support. Designed for fixed_schedule timing mode. 
metadata: is_trace: true + supports_timing: true default_block_size: 512 multi_turn: diff --git a/src/aiperf/plugin/schema/plugins.schema.json b/src/aiperf/plugin/schema/plugins.schema.json index ea65888c6..9a450bc2b 100644 --- a/src/aiperf/plugin/schema/plugins.schema.json +++ b/src/aiperf/plugin/schema/plugins.schema.json @@ -587,6 +587,12 @@ "title": "Is Trace", "type": "boolean" }, + "supports_timing": { + "default": false, + "description": "Whether this loader produces datasets with embedded timing data (timestamps or delays). When True, fixed_schedule timing mode is auto-enabled if the file contains timestamp fields.", + "title": "Supports Timing", + "type": "boolean" + }, "default_block_size": { "anyOf": [ { diff --git a/src/aiperf/plugin/schema/schemas.py b/src/aiperf/plugin/schema/schemas.py index f4d1fc2e3..47abc7682 100644 --- a/src/aiperf/plugin/schema/schemas.py +++ b/src/aiperf/plugin/schema/schemas.py @@ -335,6 +335,14 @@ class CustomDatasetLoaderMetadata(BaseModel): "options, and prefer sequential sampling with fixed_schedule timing." ), ) + supports_timing: bool = Field( + default=False, + description=( + "Whether this loader produces datasets with embedded timing data " + "(timestamps or delays). When True, fixed_schedule timing mode is " + "auto-enabled if the file contains timestamp fields." + ), + ) default_block_size: int | None = Field( default=None, ge=1, diff --git a/src/aiperf/timing/config.py b/src/aiperf/timing/config.py index 281302764..49f0eb050 100644 --- a/src/aiperf/timing/config.py +++ b/src/aiperf/timing/config.py @@ -182,6 +182,12 @@ class CreditPhaseConfig(AIPerfBaseModel): ge=0, description="The fixed schedule end offset of the timing manager.", ) + fixed_schedule_speedup: float | None = Field( + default=None, + gt=0, + description="Scaling factor for fixed schedule timestamps. 
" + "2.0 = twice as fast, 0.5 = half speed.", + ) def _build_warmup_config(user_config: UserConfig) -> CreditPhaseConfig | None: @@ -267,4 +273,5 @@ def _build_profiling_config(user_config: UserConfig) -> CreditPhaseConfig: auto_offset_timestamps=input.fixed_schedule_auto_offset, fixed_schedule_start_offset=input.fixed_schedule_start_offset, fixed_schedule_end_offset=input.fixed_schedule_end_offset, + fixed_schedule_speedup=input.fixed_schedule_speedup, ) # fmt: skip diff --git a/src/aiperf/timing/conversation_source.py b/src/aiperf/timing/conversation_source.py index 5345e5bdc..7b5f76fe6 100644 --- a/src/aiperf/timing/conversation_source.py +++ b/src/aiperf/timing/conversation_source.py @@ -18,7 +18,12 @@ import uuid from dataclasses import dataclass -from aiperf.common.models import ConversationMetadata, DatasetMetadata, TurnMetadata +from aiperf.common.models import ( + ConversationMetadata, + DatasetMetadata, + SubagentSpawnInfo, + TurnMetadata, +) from aiperf.credit.structs import Credit, TurnToSend from aiperf.dataset.protocols import DatasetSamplingStrategyProtocol @@ -41,18 +46,31 @@ class SampledSession: metadata: ConversationMetadata x_correlation_id: str - def build_first_turn(self, max_turns: int | None = None) -> TurnToSend: + def build_first_turn( + self, + max_turns: int | None = None, + agent_depth: int | None = None, + parent_correlation_id: str | None = None, + ) -> TurnToSend: """Build first turn (turn_index=0) from sampled conversation. Args: max_turns: The maximum number of turns to send for this user. Simulates a user that is partially through a conversation. If None, the number of turns is determined by the conversation metadata. + agent_depth: Nesting depth of this session. 0=root, 1=child, 2=grandchild. + Non-zero depth sessions skip session slot acquisition. + If None, uses the value from conversation metadata. + parent_correlation_id: Runtime x_correlation_id of the parent session. None for root sessions. 
""" return TurnToSend( conversation_id=self.conversation_id, x_correlation_id=self.x_correlation_id, turn_index=0, - num_turns=max_turns or len(self.metadata.turns), + num_turns=max_turns if max_turns is not None else len(self.metadata.turns), + agent_depth=( + self.metadata.agent_depth if agent_depth is None else agent_depth + ), + parent_correlation_id=parent_correlation_id, ) @@ -112,3 +130,36 @@ def get_next_turn_metadata(self, credit: Credit) -> TurnMetadata: f"(only {len(metadata.turns)} turns exist)" ) return metadata.turns[next_index] + + def get_turn_metadata_at( + self, conversation_id: str, turn_index: int + ) -> TurnMetadata: + """Get metadata for a specific turn by index.""" + metadata = self.get_metadata(conversation_id) + if turn_index < 0 or turn_index >= len(metadata.turns): + raise ValueError( + f"No turn {turn_index} in conversation {conversation_id} " + f"(only {len(metadata.turns)} turns exist)" + ) + return metadata.turns[turn_index] + + def start_child_session(self, conversation_id: str) -> SampledSession: + """Start a specific child conversation as a new session (for subagent spawns).""" + metadata = self.get_metadata(conversation_id) + return SampledSession( + conversation_id=conversation_id, + metadata=metadata, + x_correlation_id=str(uuid.uuid4()), + ) + + def get_subagent_spawn( + self, conversation_id: str, spawn_id: str + ) -> SubagentSpawnInfo | None: + """Look up a SubagentSpawnInfo by conversation and spawn ID.""" + metadata = self._metadata_lookup.get(conversation_id) + if metadata is None: + return None + for spawn in metadata.subagent_spawns: + if spawn.spawn_id == spawn_id: + return spawn + return None diff --git a/src/aiperf/timing/phase/credit_counter.py b/src/aiperf/timing/phase/credit_counter.py index 938d16f7f..b0c3fe2f6 100644 --- a/src/aiperf/timing/phase/credit_counter.py +++ b/src/aiperf/timing/phase/credit_counter.py @@ -181,7 +181,7 @@ def increment_sent(self, turn_to_send: TurnToSend) -> tuple[int, bool]: 
new_sent_sessions_count = self._sent_sessions new_total_session_turns = self._total_session_turns - if turn_to_send.turn_index == 0: + if turn_to_send.turn_index == 0 and turn_to_send.agent_depth == 0: new_sent_sessions_count += 1 new_total_session_turns += turn_to_send.num_turns @@ -200,7 +200,9 @@ def increment_sent(self, turn_to_send: TurnToSend) -> tuple[int, bool]: return credit_index, is_final_credit - def increment_returned(self, is_final_turn: bool, cancelled: bool) -> bool: + def increment_returned( + self, is_final_turn: bool, cancelled: bool, agent_depth: int = 0 + ) -> bool: """Atomically increment returned count and check phase completion. Lock-free: no async calls. @@ -208,6 +210,8 @@ def increment_returned(self, is_final_turn: bool, cancelled: bool) -> bool: Args: is_final_turn: Whether the returned turn is the final turn of its session cancelled: Whether the credit was cancelled + agent_depth: Agent depth of the returned credit. Child sessions (depth > 0) + are excluded from session completion/cancellation counts. Returns: True if ALL sent credits have now been returned or cancelled @@ -215,11 +219,11 @@ def increment_returned(self, is_final_turn: bool, cancelled: bool) -> bool: """ if cancelled: self._requests_cancelled += 1 - if is_final_turn: + if is_final_turn and agent_depth == 0: self._cancelled_sessions += 1 else: self._requests_completed += 1 - if is_final_turn: + if is_final_turn and agent_depth == 0: self._completed_sessions += 1 return self.check_all_returned_or_cancelled() diff --git a/src/aiperf/timing/phase/progress_tracker.py b/src/aiperf/timing/phase/progress_tracker.py index afa035e1c..38918de02 100644 --- a/src/aiperf/timing/phase/progress_tracker.py +++ b/src/aiperf/timing/phase/progress_tracker.py @@ -106,12 +106,15 @@ def increment_returned( self, is_final_turn: bool, cancelled: bool, + agent_depth: int = 0, ) -> bool: """Atomically increment returned count. Args: is_final_turn: Whether this turn is the final turn of a session. 
cancelled: Whether the credit was cancelled. + agent_depth: Agent depth of the returned credit. Child sessions (depth > 0) + are excluded from session completion/cancellation counts. Returns: True if ALL credits returned (this was the final return). @@ -123,7 +126,7 @@ def increment_returned( Note: Late arrivals (after phase complete) are handled by caller checking lifecycle.is_complete before calling this method. """ - return self._counter.increment_returned(is_final_turn, cancelled) + return self._counter.increment_returned(is_final_turn, cancelled, agent_depth) def increment_prefill_released(self) -> None: """Increment prefill released count. diff --git a/src/aiperf/timing/phase/runner.py b/src/aiperf/timing/phase/runner.py index efbc8cf58..9732eb574 100644 --- a/src/aiperf/timing/phase/runner.py +++ b/src/aiperf/timing/phase/runner.py @@ -10,7 +10,7 @@ import asyncio from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from aiperf.common.enums import CreditPhase from aiperf.common.environment import Environment @@ -101,12 +101,18 @@ def __init__( # For FIXED_SCHEDULE mode, use actual dataset size instead of config values. # Config values may reflect pre-filtered file size, but dataset_metadata # reflects the actual filtered dataset after start/end offset filtering. + # + # Counts include ALL conversations (root + children). Child turns are + # dispatched dynamically by SubagentOrchestrator and go through + # issue_credit -> increment_sent, so they contribute to the total. + # The sending phase completes when all turns (root + child) are sent. 
metadata = conversation_source.dataset_metadata if config.timing_mode == TimingMode.FIXED_SCHEDULE and metadata: + root_convs = [c for c in metadata.conversations if c.agent_depth == 0] self._config = config.model_copy( update={ "total_expected_requests": metadata.total_turn_count, - "expected_num_sessions": len(metadata.conversations), + "expected_num_sessions": len(root_convs), } ) self._phase_publisher = phase_publisher @@ -202,6 +208,22 @@ async def run( StrategyClass = plugins.get_class( PluginType.TIMING_STRATEGY, self._config.timing_mode ) + strategy_kwargs: dict[str, Any] = {} + + metadata = self._conversation_source.dataset_metadata + has_subagents = bool( + metadata and any(c.subagent_spawns for c in metadata.conversations) + ) + if has_subagents: + from aiperf.timing.subagent_orchestrator import SubagentOrchestrator + + strategy_kwargs["subagents"] = SubagentOrchestrator( + conversation_source=self._conversation_source, + credit_issuer=self._credit_issuer, + stop_checker=self._stop_checker, + scheduler=self._scheduler, + ) + strategy: TimingStrategyProtocol = StrategyClass( config=self._config, conversation_source=self._conversation_source, @@ -209,6 +231,7 @@ async def run( stop_checker=self._stop_checker, credit_issuer=self._credit_issuer, lifecycle=self._lifecycle, + **strategy_kwargs, ) try: @@ -253,6 +276,7 @@ async def run( self._lifecycle.mark_complete(grace_period_triggered=True) self._progress.freeze_completed_counts() self._progress.all_credits_returned_event.set() + strategy.cleanup() return self._progress.create_stats(self._lifecycle) # 11. 
Seamless mode: phase flows into next without waiting for returns @@ -269,15 +293,12 @@ async def run( for ramper in self._rampers: ramper.stop() self._scheduler.cancel_all() + strategy.cleanup() return self._progress.create_stats(self._lifecycle) except Exception as e: - # TODO: This can be improved a bit by having a better way to notify other services - # and the system controller of a failure in the benchmark. - # If there is an error while setting up or executing the phase, - # we need to flush it through the lifecycle to ensure the other services - # are notified that the phase has ended, and the benchmark does not hang forever. + # Flush through the lifecycle so other services are notified the phase ended. self.error(f"Error executing phase {self._config.phase.title}: {e!r}") if not self._was_cancelled: self.cancel() @@ -301,8 +322,32 @@ async def run( stats = self._progress.create_stats(self._lifecycle) await self._phase_publisher.publish_phase_complete(stats) + strategy.cleanup() raise e + def _add_concurrency_ramper( + self, + label: str, + duration_sec: float, + target: int, + setter: Callable[[float], None], + ) -> None: + """Create and register a stepped concurrency ramper.""" + self.info( + f"Starting {label} concurrency ramp: 1 → {target} over {duration_sec}s" + ) + self._rampers.append( + Ramper( + setter=setter, + config=RampConfig( + ramp_type=RampType.LINEAR, + start=1, + target=target, + duration_sec=duration_sec, + ), + ) + ) + def _create_rampers(self, strategy: TimingStrategyProtocol) -> None: """Create rampers for concurrency and rate if ramp durations are configured. 
@@ -315,43 +360,25 @@ def _create_rampers(self, strategy: TimingStrategyProtocol) -> None: # Session concurrency ramper (stepped mode) if config.concurrency_ramp_duration_sec and config.concurrency: - self.info( - f"Starting session concurrency ramp: 1 → {config.concurrency} " - f"over {config.concurrency_ramp_duration_sec}s" - ) - ramp_config = RampConfig( - ramp_type=RampType.LINEAR, - start=1, - target=config.concurrency, - duration_sec=config.concurrency_ramp_duration_sec, - ) - - def setter(limit: float) -> None: - return self._concurrency_manager.set_session_limit( + self._add_concurrency_ramper( + "session", + config.concurrency_ramp_duration_sec, + config.concurrency, + lambda limit: self._concurrency_manager.set_session_limit( config.phase, int(limit) - ) - - self._rampers.append(Ramper(setter=setter, config=ramp_config)) + ), + ) # Prefill concurrency ramper (stepped mode) if config.prefill_concurrency_ramp_duration_sec and config.prefill_concurrency: - self.info( - f"Starting prefill concurrency ramp: 1 → {config.prefill_concurrency} " - f"over {config.prefill_concurrency_ramp_duration_sec}s" - ) - ramp_config = RampConfig( - ramp_type=RampType.LINEAR, - start=1, - target=config.prefill_concurrency, - duration_sec=config.prefill_concurrency_ramp_duration_sec, - ) - - def setter(limit: float) -> None: - return self._concurrency_manager.set_prefill_limit( + self._add_concurrency_ramper( + "prefill", + config.prefill_concurrency_ramp_duration_sec, + config.prefill_concurrency, + lambda limit: self._concurrency_manager.set_prefill_limit( config.phase, int(limit) - ) - - self._rampers.append(Ramper(setter=setter, config=ramp_config)) + ), + ) # Request rate ramper (continuous mode via update_interval) if config.request_rate_ramp_duration_sec and config.request_rate: diff --git a/src/aiperf/timing/strategies/core.py b/src/aiperf/timing/strategies/core.py index 5ff52fa47..b6e8a97e6 100644 --- a/src/aiperf/timing/strategies/core.py +++ 
b/src/aiperf/timing/strategies/core.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from aiperf.common.loop_scheduler import LoopScheduler from aiperf.credit.issuer import CreditIssuer + from aiperf.credit.messages import CreditReturn from aiperf.credit.structs import Credit from aiperf.timing.config import CreditPhaseConfig from aiperf.timing.conversation_source import ConversationSource @@ -55,31 +56,35 @@ async def execute_phase(self) -> None: """Execute the main timing loop for first turns. Sends first turns according to the timing strategy (rate, schedule, etc.). - Subsequent turns are handled by handle_credit_return via callbacks. - Subsequent turns can also be handled here if the strategy uses a queue. + Subsequent turns are dispatched by handle_credit_return via callbacks, + or pulled from a continuation queue here for queue-based strategies. - Return from this method once there are no more turns to send. In Queue-based strategies, - they must wait until all turns are sent. In non-queue-based strategies, this can - return once all first-turn credits are sent. + Must not return until all first-turn credits are sent (or queued). """ ... async def handle_credit_return(self, credit: Credit) -> None: """Handle credit return: dispatch next turn if applicable. - Called when a worker completes a turn. Determines if a subsequent turn - should be sent, and if so, dispatches it via the appropriate path - (immediate, scheduled, or queued). + Called by CreditCallbackHandler when can_send_any_turn() is True OR + credit.agent_depth > 0 (child bypass for gate accounting). + Implementations with subagents must call intercept() first -- the + orchestrator handles all child routing and stop-condition checks. + """ + ... - Note: CreditCallbackHandler checks can_send_any_turn() before calling. - Implementations only need to check conversation-specific conditions - (e.g., is_final_turn). 
+ def on_failed_credit(self, credit_return: CreditReturn) -> None: + """Handle errored/cancelled returns for subagent gate cleanup. - Args: - credit: Completed credit with conversation/turn info + Called BEFORE handle_credit_return, regardless of can_send_any_turn(). + No-op for strategies without subagents. """ ... + def cleanup(self) -> None: + """Release resources at phase end. Called by PhaseRunner after completion.""" + ... + # ============================================================================= # RateSettableProtocol - Protocol for strategies that support dynamic rate adjustment diff --git a/src/aiperf/timing/strategies/fixed_schedule.py b/src/aiperf/timing/strategies/fixed_schedule.py index a55713055..a6c2bc524 100644 --- a/src/aiperf/timing/strategies/fixed_schedule.py +++ b/src/aiperf/timing/strategies/fixed_schedule.py @@ -15,6 +15,7 @@ from aiperf.common.constants import MILLIS_PER_SECOND from aiperf.common.mixins import AIPerfLoggerMixin from aiperf.credit.structs import Credit, TurnToSend +from aiperf.timing.strategies.subagent_mixin import SubagentMixin if TYPE_CHECKING: from aiperf.common.loop_scheduler import LoopScheduler @@ -23,6 +24,7 @@ from aiperf.timing.conversation_source import ConversationSource from aiperf.timing.phase.lifecycle import PhaseLifecycle from aiperf.timing.phase.stop_conditions import StopConditionChecker + from aiperf.timing.subagent_orchestrator import SubagentOrchestrator class ScheduleEntry(NamedTuple): @@ -32,14 +34,11 @@ class ScheduleEntry(NamedTuple): turn: TurnToSend -class FixedScheduleStrategy(AIPerfLoggerMixin): +class FixedScheduleStrategy(SubagentMixin, AIPerfLoggerMixin): """Timing strategy for replaying conversation traces with absolute timestamps. Sends first turns at precise timestamps from conversation metadata. Subsequent turns dispatched immediately or after calculated delay. - - This is a pure timing strategy - no lifecycle or orchestration concerns. - The PhaseRunner handles all orchestration. 
""" def __init__( @@ -51,47 +50,38 @@ def __init__( credit_issuer: CreditIssuer, lifecycle: PhaseLifecycle, stop_checker: StopConditionChecker, + subagents: SubagentOrchestrator | None = None, **kwargs, ): - """Initialize fixed schedule timing strategy with all dependencies.""" super().__init__(logger_name="FixedScheduleTiming") self._config = config self._conversation_source = conversation_source self._scheduler = scheduler self._credit_issuer = credit_issuer self._lifecycle = lifecycle + self._init_subagents(subagents) + self._time_scale = 1.0 / (config.fixed_schedule_speedup or 1.0) - # Computed in setup_phase self._absolute_schedule: list[ScheduleEntry] = [] self._schedule_zero_ms: float = 0.0 def _timestamp_to_perf_sec(self, timestamp_ms: int | float) -> float: - """Convert trace timestamp in milliseconds to perf counter seconds. - - Uses the offset from the schedule zero to calculate the target performance seconds. - """ - target_offset_sec = (timestamp_ms - self._schedule_zero_ms) / MILLIS_PER_SECOND + """Convert trace timestamp in milliseconds to perf counter seconds.""" + scaled_offset_ms = (timestamp_ms - self._schedule_zero_ms) * self._time_scale + target_offset_sec = scaled_offset_ms / MILLIS_PER_SECOND return self._lifecycle.started_at_perf_sec + target_offset_sec async def setup_phase(self) -> None: - """Build absolute schedule from dataset metadata. + """Build absolute schedule from dataset metadata.""" + self._absolute_schedule = [] - Dataset is already filtered by loader (e.g., mooncake_trace._timestamp_within_offsets), - so we just validate and build the schedule. 
- """ - self._absolute_schedule = [] # Fresh schedule for each phase - - # Validate and build schedule for conv in self._conversation_source.dataset_metadata.conversations: - if not conv.turns: + if not conv.turns or conv.agent_depth > 0: continue - - # Validate first turn has timestamp (required for fixed schedule mode) if conv.turns[0].timestamp_ms is None: raise ValueError( f"First turn of {conv.conversation_id} missing timestamp_ms" ) - self._absolute_schedule.append( ScheduleEntry( timestamp_ms=conv.turns[0].timestamp_ms, @@ -108,7 +98,6 @@ async def setup_phase(self) -> None: raise ValueError("No conversations with valid first-turn timestamps found") self._absolute_schedule.sort(key=lambda x: x.timestamp_ms) - # Calculate schedule zero (dataset already filtered by loader) if self._config.auto_offset_timestamps: self._schedule_zero_ms = self._absolute_schedule[0].timestamp_ms elif self._config.fixed_schedule_start_offset is not None: @@ -116,20 +105,20 @@ async def setup_phase(self) -> None: else: self._schedule_zero_ms = 0.0 + speedup_msg = ( + f", speedup={self._config.fixed_schedule_speedup}x" + if self._config.fixed_schedule_speedup + else "" + ) self.info( f"Built schedule with {len(self._absolute_schedule)} timestamps, " f"zero_ms={self._schedule_zero_ms:.0f}, " f"auto_offset={self._config.auto_offset_timestamps}" + f"{speedup_msg}" ) async def execute_phase(self) -> None: - """Execute absolute schedule: send first turns at precise timestamps. - - Note: Subsequent turns are handled by handle_credit_return. 
- - Raises: - RuntimeError: If started_at_perf_ns is not set in the lifecycle - """ + """Execute absolute schedule: send first turns at precise timestamps.""" if self._lifecycle.started_at_perf_ns is None: raise RuntimeError("started_at_perf_ns is not set in the lifecycle") @@ -139,30 +128,43 @@ async def execute_phase(self) -> None: self._credit_issuer.issue_credit(entry.turn), ) - async def handle_credit_return( - self, - credit: Credit, - ) -> None: - """Handle credit return: dispatch next turn based on trace timing. + if self._subagents: + self._subagents.dispatch_turn0_background_spawns() + + async def handle_credit_return(self, credit: Credit) -> None: + """Handle credit return: dispatch next turn based on trace timing.""" + if self._subagents and self._subagents.intercept(credit): + return - Calculates delay from timestamp_ms or delay_ms metadata, then issues - credit immediately (delay=0) or schedules for later (delay>0). - """ if credit.is_final_turn: return - # This contains the delay_ms or timestamp_ms for the next turn next_meta = self._conversation_source.get_next_turn_metadata(credit) turn = TurnToSend.from_previous_credit(credit) + self._dispatch_by_timing(turn, next_meta.timestamp_ms, next_meta.delay_ms) - if next_meta.timestamp_ms is not None: + def _dispatch_turn(self, turn: TurnToSend) -> None: + """Dispatch callback for SubagentOrchestrator: look up timing and schedule.""" + meta = self._conversation_source.get_turn_metadata_at( + turn.conversation_id, turn.turn_index + ) + self._dispatch_by_timing(turn, meta.timestamp_ms, meta.delay_ms) + + def _dispatch_by_timing( + self, + turn: TurnToSend, + timestamp_ms: int | float | None, + delay_ms: int | float | None, + ) -> None: + """Dispatch a turn using timestamp, delay, or immediate execution.""" + if timestamp_ms is not None: self._scheduler.schedule_at_perf_sec( - self._timestamp_to_perf_sec(next_meta.timestamp_ms), + self._timestamp_to_perf_sec(timestamp_ms), 
self._credit_issuer.issue_credit(turn), ) - elif next_meta.delay_ms is not None: + elif delay_ms is not None: self._scheduler.schedule_later( - next_meta.delay_ms / MILLIS_PER_SECOND, + delay_ms * self._time_scale / MILLIS_PER_SECOND, self._credit_issuer.issue_credit(turn), ) else: diff --git a/src/aiperf/timing/strategies/request_rate.py b/src/aiperf/timing/strategies/request_rate.py index 084c85cda..cae9d2763 100644 --- a/src/aiperf/timing/strategies/request_rate.py +++ b/src/aiperf/timing/strategies/request_rate.py @@ -15,6 +15,7 @@ from aiperf.plugin import plugins from aiperf.plugin.enums import PluginType from aiperf.timing.intervals import IntervalGeneratorConfig +from aiperf.timing.strategies.subagent_mixin import SubagentMixin if TYPE_CHECKING: from aiperf.common.loop_scheduler import LoopScheduler @@ -23,9 +24,10 @@ from aiperf.timing.conversation_source import ConversationSource from aiperf.timing.phase.lifecycle import PhaseLifecycle from aiperf.timing.phase.stop_conditions import StopConditionChecker + from aiperf.timing.subagent_orchestrator import SubagentOrchestrator -class RequestRateStrategy(AIPerfLoggerMixin): +class RequestRateStrategy(SubagentMixin, AIPerfLoggerMixin): """Issues credits at a target average rate with configurable arrival patterns. The arrival pattern (Constant, Poisson, Gamma, ConcurrencyBurst) determines @@ -91,6 +93,7 @@ def __init__( stop_checker: StopConditionChecker, credit_issuer: CreditIssuer, lifecycle: PhaseLifecycle, + subagents: SubagentOrchestrator | None = None, **kwargs, ): """Initialize rate timing strategy with all dependencies.""" @@ -101,6 +104,7 @@ def __init__( self._stop_checker = stop_checker self._credit_issuer = credit_issuer self._lifecycle = lifecycle + self._init_subagents(subagents) # Queue for subsequent turns (turn_index > 0) waiting to be issued. # Populated by handle_credit_return when workers complete turns. 
@@ -118,8 +122,7 @@ def __init__( self._rate_generator = GeneratorClass(interval_config) async def setup_phase(self) -> None: - """Setup the phase.""" - pass # Already setup in __init__ + pass async def execute_phase(self) -> None: """Execute request rate main loop until stop condition reached. @@ -217,13 +220,28 @@ async def handle_credit_return(self, credit: Credit) -> None: The delay_ms from turn metadata (if present) is honored before queuing, simulating user "think time" between turns in a conversation. """ + if self._subagents and self._subagents.intercept(credit): + return + if credit.is_final_turn: return meta = self._conversation_source.get_next_turn_metadata(credit) turn = TurnToSend.from_previous_credit(credit) - # Honor think-time delay from dataset metadata before queuing + if meta.delay_ms is not None: + self._scheduler.schedule_later( + meta.delay_ms / MILLIS_PER_SECOND, + self._continuation_turns.put(turn), + ) + else: + self._continuation_turns.put_nowait(turn) + + def _dispatch_turn(self, turn: TurnToSend) -> None: + """Dispatch callback for SubagentOrchestrator: queue turn for rate-limited issuance.""" + meta = self._conversation_source.get_turn_metadata_at( + turn.conversation_id, turn.turn_index + ) if meta.delay_ms is not None: self._scheduler.schedule_later( meta.delay_ms / MILLIS_PER_SECOND, diff --git a/src/aiperf/timing/strategies/subagent_mixin.py b/src/aiperf/timing/strategies/subagent_mixin.py new file mode 100644 index 000000000..661cf6237 --- /dev/null +++ b/src/aiperf/timing/strategies/subagent_mixin.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Mixin providing subagent lifecycle delegation for timing strategies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from aiperf.credit.messages import CreditReturn + from aiperf.timing.subagent_orchestrator import SubagentOrchestrator + + +class SubagentMixin: + """Mixin that delegates subagent lifecycle methods to a SubagentOrchestrator. + + Strategies inherit this to avoid duplicating the same guard-and-delegate + pattern for error/cancel handling and cleanup. + + Subclasses must: + - Call ``_init_subagents(subagents)`` in their ``__init__`` + - Implement ``_dispatch_turn(turn)`` (strategy-specific dispatch callback) + + The dispatch callback is how the orchestrator re-enters the strategy to + schedule work. It receives a TurnToSend and must schedule it via the + strategy's own timing mechanism (e.g., schedule_at_perf_sec for + FixedSchedule, continuation queue for RequestRate, immediate for + UserCentric). The orchestrator calls this for: + - Gated parent turns (after all children complete) + - Non-final child next turns (continuing a child conversation) + - Background child first turns (fire-and-forget) + """ + + _subagents: SubagentOrchestrator | None + + def _init_subagents(self, subagents: SubagentOrchestrator | None) -> None: + self._subagents = subagents + if self._subagents is not None: + # Wire the dispatch callback. The orchestrator is created by + # PhaseRunner before the strategy, so it can't receive the + # callback at construction time. + self._subagents.set_dispatch(self._dispatch_turn) + + def on_failed_credit(self, credit_return: CreditReturn) -> None: + """Release errored/cancelled non-final children from gate tracking. + + Only non-final children: final turns do gate accounting in + _handle_child_credit. Calling terminate_child for final turns + would double-release from _child_to_gate. 
+ """ + if not self._subagents: + return + credit = credit_return.credit + if credit.agent_depth > 0 and not credit.is_final_turn: + self._subagents.terminate_child(credit) + + def cleanup(self) -> None: + if self._subagents: + self._subagents.cleanup() diff --git a/src/aiperf/timing/strategies/user_centric_rate.py b/src/aiperf/timing/strategies/user_centric_rate.py index 360e3770c..7646fbecd 100644 --- a/src/aiperf/timing/strategies/user_centric_rate.py +++ b/src/aiperf/timing/strategies/user_centric_rate.py @@ -26,8 +26,8 @@ ------------------------------------------- User | Turns | Time | Turn Visualization ------------------------------------------- - 1 | - | - | (All turns completed before t=0) ← User 1 is "virtually done" - 16 | 20 | 0s | ████████████████████ ← New user at t=0 with all turns remaining + 1 | - | - | (All turns completed before t=0) <- User 1 is "virtually done" + 16 | 20 | 0s | ████████████████████ <- New user at t=0 with all turns remaining 5 | 6 | 1s | ██████ 9 | 11 | 2s | ███████████ 13 | 16 | 3s | ████████████████ @@ -66,6 +66,7 @@ from aiperf.common.mixins import AIPerfLoggerMixin from aiperf.credit.structs import Credit, TurnToSend +from aiperf.timing.strategies.subagent_mixin import SubagentMixin if TYPE_CHECKING: from aiperf.common.loop_scheduler import LoopScheduler @@ -74,6 +75,7 @@ from aiperf.timing.conversation_source import ConversationSource, SampledSession from aiperf.timing.phase.lifecycle import PhaseLifecycle from aiperf.timing.phase.stop_conditions import StopConditionChecker + from aiperf.timing.subagent_orchestrator import SubagentOrchestrator def _find_alternate_spacing_step(n: int) -> int: @@ -118,7 +120,7 @@ def build_first_turn(self) -> TurnToSend: return self.sampled.build_first_turn(max_turns=self.max_turns) -class UserCentricStrategy(AIPerfLoggerMixin): +class UserCentricStrategy(SubagentMixin, AIPerfLoggerMixin): """User-centric timing strategy for KV cache benchmarking with realistic multi-user patterns.""" def 
__init__( @@ -130,6 +132,7 @@ def __init__( stop_checker: StopConditionChecker, credit_issuer: CreditIssuer, lifecycle: PhaseLifecycle, + subagents: SubagentOrchestrator | None = None, **kwargs, ): """Initialize user-centric timing strategy with all dependencies.""" @@ -140,6 +143,7 @@ def __init__( self._stop_checker = stop_checker self._credit_issuer = credit_issuer self._lifecycle = lifecycle + self._init_subagents(subagents) self._num_users = self._config.num_users self._request_rate = self._config.request_rate @@ -159,6 +163,7 @@ def __init__( # Computed in setup_phase self._turn_gap: float = 0.0 + self._session_turns: int = 0 self._session_to_user: dict[str, User] = {} self._initial_users: list[User] = [] self._next_user_id: int = 1 @@ -200,9 +205,10 @@ async def setup_phase(self) -> None: num_users = self._num_users qps = self._request_rate # We allow varying turn counts per conversation, so we use the average across the whole dataset. - session_turns = round( + self._session_turns = round( self._conversation_source.dataset_metadata.average_turn_count ) + session_turns = self._session_turns # num_users firing once per turn_gap gives: qps = num_users / turn_gap self._turn_gap = num_users / qps # Time between each user's consecutive turns @@ -272,7 +278,7 @@ async def execute_phase(self) -> None: f"User-centric mode: " f"qps={self._request_rate}, " f"{self._num_users} users, " - f"session_turns={round(self._conversation_source.dataset_metadata.average_turn_count)}, " + f"session_turns={self._session_turns}, " f"stagger={self._stagger:.3f}s, " f"turn_gap={self._turn_gap:.3f}s" ) @@ -306,7 +312,7 @@ async def execute_phase(self) -> None: # Continuously spawn new users at discrete intervals to maintain the target QPS. 
while True: spawn_sec = heapq.heappop(spawn_queue) - await asyncio.sleep(spawn_sec - time.perf_counter()) + await asyncio.sleep(max(0.0, spawn_sec - time.perf_counter())) user = self._generate_next_user(spawn_sec) turn = user.build_first_turn() @@ -329,8 +335,10 @@ async def handle_credit_return( This maintains ideal pacing when responses arrive on time, but if the response is late, the max() re-aligns to current time (sends immediately). """ + if self._subagents and self._subagents.intercept(credit): + return + if credit.is_final_turn: - # User finished all their turns. New users continue spawning in execute_phase. self._session_to_user.pop(credit.x_correlation_id, None) return @@ -342,10 +350,14 @@ async def handle_credit_return( ) turn = TurnToSend.from_previous_credit(credit) - # If the next turn time already passed, the max() will - # re-align their schedule to account for the delay. user.next_send_time = max(current_sec, user.next_send_time + self._turn_gap) self._scheduler.schedule_at_perf_sec( user.next_send_time, self._credit_issuer.issue_credit(turn), ) + + def _dispatch_turn(self, turn: TurnToSend) -> None: + """Dispatch callback for SubagentOrchestrator: schedule turn immediately.""" + self._scheduler.execute_async( + self._credit_issuer.issue_credit(turn), + ) diff --git a/src/aiperf/timing/subagent_orchestrator.py b/src/aiperf/timing/subagent_orchestrator.py new file mode 100644 index 000000000..eb6a68205 --- /dev/null +++ b/src/aiperf/timing/subagent_orchestrator.py @@ -0,0 +1,767 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""SubagentOrchestrator: composable subagent spawn/join component. + +Strategies own an instance and call ``intercept(credit)`` at the top of their +handle_credit_return. 
The orchestrator handles all child routing, spawn +resolution, child dispatch, and gated turn dispatch internally using a strategy- +provided dispatch callback. Strategies forward observer methods +(terminate_child, cleanup) as one-liners. + +Prerequisite-Based Turn Gating +============================== + +Turns declare explicit prerequisites (e.g. ``spawn_join``) that must be +satisfied before the turn dispatches. The orchestrator builds a prerequisite +index at init time from TurnPrerequisite entries on each turn. + +Credit Return Flow for Subagents +================================ + +CreditCallbackHandler.on_credit_return processing order:: + + 1. Atomic counting (increment_returned) + 2. Track prefill release + 3. Release concurrency slots + 4. on_failed_credit → terminate_child (errored/cancelled, ALL credits) + 5. Signal all_credits_returned_event + 6. handle_credit_return → intercept (can_send OR depth > 0) + +Step 4 runs BEFORE step 6 so that terminate_child marks the child in +_terminated_children before _handle_child_credit checks _is_terminated. +Step 6 uses a child bypass (depth > 0) so gate accounting always completes, +even after stop fires. + +Stop Condition Interaction +========================== + +Four coordinated guards achieve zero-overshoot, zero-deadlock: + +1. **Callback handler child bypass** (CreditCallbackHandler step 6): + Child returns (depth > 0) always reach handle_credit_return for gate + accounting. Without this, child final-turn returns would be silently + dropped when stop fires, leaving parent gates permanently unsatisfied. + +2. **Non-final child dispatch suppression** (_handle_child_credit): + Checks can_send_any_turn() before dispatching a child's next turn. + This prevents the child bypass from causing unbounded child work. + +3. **Gate dispatch suppression** (_release_blocked_gate): + Checks can_send_any_turn() before dispatching the gated parent turn. + +4. 
**Credit issuance failure** (_issue_child_credit_or_release): + When issue_credit returns False (stop fired, no slots), the child is + released from gate tracking to prevent the parent from hanging. + +Additionally, issue_credit itself checks stop conditions at slot +acquisition time, providing a belt-and-suspenders final guard. +""" + +from __future__ import annotations + +import time +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from aiperf.common.enums import PrerequisiteKind +from aiperf.common.mixins import AIPerfLoggerMixin +from aiperf.common.models.dataset_models import TurnPrerequisite +from aiperf.credit.structs import TurnToSend + +if TYPE_CHECKING: + from aiperf.common.loop_scheduler import LoopScheduler + from aiperf.credit.issuer import CreditIssuer + from aiperf.credit.structs import Credit + from aiperf.timing.conversation_source import ConversationSource, SampledSession + from aiperf.timing.phase.stop_conditions import StopConditionChecker + + +@dataclass(slots=True) +class PendingTurnGate: + """Tracks prerequisite completion before dispatching a gated turn.""" + + parent_conversation_id: str + """Template conversation ID of the parent.""" + + parent_correlation_id: str + """Runtime correlation ID of the parent session.""" + + gated_turn_index: int + """Turn index that is blocked until prerequisites are met.""" + + parent_num_turns: int = 0 + """Total turns in the parent conversation.""" + + parent_agent_depth: int = 0 + """Nesting depth of the parent (0 = root).""" + + parent_parent_correlation_id: str | None = None + """Correlation ID of the parent's parent (for nested subagents).""" + + created_at_ns: int = 0 + """Monotonic timestamp when the gate was created, for leak diagnostics.""" + + is_blocked: bool = False + """True once the parent has reached the gated turn and is waiting.""" + + outstanding: dict[str, list[int]] = field(default_factory=dict) + """Maps prereq_key -> 
[expected_count, completed_count].""" + + @property + def is_satisfied(self) -> bool: + """True when all prerequisites have been met.""" + return all(c >= e for e, c in self.outstanding.values()) + + +@dataclass(slots=True) +class ChildGateEntry: + """Tracks which parent gate a blocking child belongs to.""" + + parent_corr_id: str + """Runtime correlation ID of the parent session.""" + + gated_turn_index: int + """Turn index on the parent that this child gates.""" + + prereq_key: str + """Prerequisite key in the parent's PendingTurnGate.outstanding dict.""" + + +@dataclass(slots=True) +class SubagentStats: + """Counters for subagent observability.""" + + children_spawned: int = 0 + """Total child sessions started across all spawns.""" + + children_completed: int = 0 + """Children that finished their final turn normally.""" + + children_errored: int = 0 + """Children released from gate due to error/cancel or issuance failure.""" + + parents_suspended: int = 0 + """Parent sessions that blocked on a gated turn.""" + + parents_resumed: int = 0 + """Parent sessions resumed after all prerequisites were met.""" + + joins_suppressed: int = 0 + """Gated turns suppressed because stop condition fired.""" + + +class SubagentOrchestrator(AIPerfLoggerMixin): + """Composable subagent spawn/join component. + + Strategies call ``intercept(credit)`` from handle_credit_return. The + orchestrator routes child credits, manages gate tracking, and dispatches + gated parent turns via a strategy-provided callback. + + Integration: SubagentMixin wires _init_subagents, on_failed_credit, cleanup. + See module docstring for credit return flow ordering and stop-condition guards. + """ + + def __init__( + self, + *, + conversation_source: ConversationSource, + credit_issuer: CreditIssuer, + stop_checker: StopConditionChecker, + scheduler: LoopScheduler, + dispatch_fn: Callable[[TurnToSend], None] | None = None, + ) -> None: + """Initialize orchestrator. 
+ + Args: + conversation_source: Dataset metadata and session creation. + credit_issuer: Issues credits to workers. + stop_checker: Evaluates stop conditions (has can_send_any_turn()). + scheduler: For scheduling async coroutines (abandon wrapper). + dispatch_fn: Strategy-specific dispatch. Called with a TurnToSend + for gated turns, child next turns, and background child first turns. + """ + super().__init__(logger_name="SubagentOrchestrator") + self._conversation_source = conversation_source + self._credit_issuer = credit_issuer + self._stop_checker = stop_checker + self._scheduler = scheduler + self._dispatch = dispatch_fn + + self._gated_turns: dict[str, PendingTurnGate] = {} + self._future_gates: dict[str, dict[int, PendingTurnGate]] = {} + self._child_to_gate: dict[str, ChildGateEntry] = {} + self._terminated_children: set[str] = set() + self._cleaning_up: bool = False + self._stats = SubagentStats() + + self._prerequisite_index: dict[tuple[str, int], list[TurnPrerequisite]] = {} + self._spawn_join_index: dict[tuple[str, str], int] = {} + self._build_prerequisite_index() + + def set_dispatch(self, dispatch_fn: Callable[[TurnToSend], None]) -> None: + """Set the strategy-provided dispatch callback.""" + self._dispatch = dispatch_fn + + # ========================================================================= + # Primary entry point: intercept credit returns + # ========================================================================= + + def intercept(self, credit: Credit) -> bool: + """Handle subagent-related credit returns. + + Called at the top of every strategy's handle_credit_return. + Returns True if the credit was fully handled (strategy must NOT + dispatch next turn). Returns False for normal strategy dispatch. + + True only for: (a) child credits (depth > 0), (b) parent suspended + at a gate. Returning True incorrectly drops the parent's next turn. + """ + if self._cleaning_up: + return False + + # Child credits are fully handled here. 
Strategy must not look up + # child conversation_ids -- they're not in the strategy's schedule. + if credit.agent_depth > 0: + self._handle_child_credit(credit) + return True + + # Parent turn completed -- check if it spawned children. + spawn_ids = self._get_spawn_ids(credit) + if spawn_ids: + self.debug( + lambda: f"Parent {credit.x_correlation_id} turn[{credit.turn_index}] " + f"completed with spawns: {spawn_ids}" + ) + # Turn-0 background spawns are pre-dispatched in execute_phase. + # Filter them out here to avoid double-dispatch. + if credit.turn_index == 0: + spawn_ids = [ + sid + for sid in spawn_ids + if not self._is_background(credit.conversation_id, sid) + ] + if spawn_ids: + self._resolve_and_dispatch_spawns(credit, spawn_ids) + + # Suspend parent if its next turn has unsatisfied prerequisites. + suspended = self._maybe_suspend_parent(credit) + if suspended: + self.debug( + lambda: f"Parent {credit.x_correlation_id} suspended at " + f"turn[{credit.turn_index + 1}] waiting for prerequisites" + ) + return suspended + + # ========================================================================= + # Turn-0 background pre-dispatch (called from execute_phase) + # ========================================================================= + + def dispatch_turn0_background_spawns(self) -> None: + """Pre-dispatch background children for turn 0 of all root conversations. + + Called from execute_phase after first-turn credits are scheduled. + Coordinates with intercept() which filters out turn-0 background + spawns to avoid double-dispatch. Blocking turn-0 spawns go through + intercept() instead. 
+ """ + if self._dispatch is None: + self.warning("dispatch_turn0_background_spawns called before set_dispatch") + return + total_dispatched = 0 + for conv in self._conversation_source.dataset_metadata.conversations: + if conv.agent_depth > 0 or not conv.turns: + continue + bg_ids = self._get_turn0_background_ids(conv.conversation_id) + if not bg_ids: + continue + # Synthetic correlation ID -- these children are not gated so + # the parent_corr_id is only used for logging/tracing. + parent_corr_id = f"bg-turn0-{conv.conversation_id}" + for spawn_id in bg_ids: + spawn = self._conversation_source.get_subagent_spawn( + conv.conversation_id, spawn_id + ) + if spawn is None: + continue + for child_cid in spawn.child_conversation_ids: + session = self._conversation_source.start_child_session(child_cid) + turn = session.build_first_turn( + agent_depth=1, parent_correlation_id=parent_corr_id + ) + self._dispatch(turn) + self._stats.children_spawned += 1 + total_dispatched += 1 + if total_dispatched > 0: + self.info(f"Pre-dispatched {total_dispatched} turn-0 background children") + + # ========================================================================= + # Strategy-facing API (called via SubagentMixin) + # ========================================================================= + + def terminate_child(self, credit: Credit) -> None: + """Release an errored/cancelled child from gate tracking. + + Called from SubagentMixin.on_failed_credit (callback handler step 4). + Must run before _handle_child_credit (step 6) -- see module docstring. + + Skips final turns (gate accounting lives in _handle_child_credit) and + background children (not tracked in _child_to_gate). 
+ """ + if ( + self._cleaning_up + or credit.agent_depth == 0 + or credit.is_final_turn + or credit.x_correlation_id not in self._child_to_gate + ): + return + self._stats.children_errored += 1 + self._terminated_children.add(credit.x_correlation_id) + entry = self._child_to_gate.get(credit.x_correlation_id) + self.debug( + lambda: f"Terminating child {credit.x_correlation_id} " + f"(parent={entry.parent_corr_id if entry else '?'}, " + f"turn[{credit.turn_index}]/{credit.num_turns})" + ) + gated = self._release_child(credit.x_correlation_id) + if gated: + self._dispatch(gated) + + def cleanup(self) -> None: + """Log leaked state and clear all tracking. Idempotent.""" + if self._cleaning_up: + return + self._cleaning_up = True + s = self._stats + self.info( + f"Subagent stats: spawned={s.children_spawned}, " + f"completed={s.children_completed}, errored={s.children_errored}, " + f"suspended={s.parents_suspended}, resumed={s.parents_resumed}, " + f"suppressed={s.joins_suppressed}" + ) + if self._gated_turns or self._future_gates or self._child_to_gate: + self.warning( + f"Leaked state at cleanup: {len(self._gated_turns)} active gates, " + f"{sum(len(g) for g in self._future_gates.values())} future gates, " + f"{len(self._child_to_gate)} tracked children" + ) + now_ns = time.time_ns() + leaked_gates = list(self._gated_turns.items()) + leaked_gates.extend(self._iter_future_gates()) + for parent_corr_id, gate in leaked_gates: + age_ms = (now_ns - gate.created_at_ns) / 1_000_000 + total_expected = sum(e for e, _ in gate.outstanding.values()) + total_completed = sum(c for _, c in gate.outstanding.values()) + self.warning( + f"Abandoned pending gate for parent {parent_corr_id} " + f"(age={age_ms:.0f}ms, completed={total_completed}/" + f"{total_expected})" + ) + self._gated_turns.clear() + self._future_gates.clear() + self._child_to_gate.clear() + self._terminated_children.clear() + + def get_stats(self) -> dict[str, int]: + """Return subagent counters for phase stats.""" + 
s = self._stats + return { + "subagent_children_spawned": s.children_spawned, + "subagent_children_completed": s.children_completed, + "subagent_children_errored": s.children_errored, + "subagent_parents_suspended": s.parents_suspended, + "subagent_parents_resumed": s.parents_resumed, + "subagent_joins_suppressed": s.joins_suppressed, + } + + # ========================================================================= + # Internal: prerequisite index + # ========================================================================= + + def _build_prerequisite_index(self) -> None: + """Build (conv_id, turn_index) -> prerequisites and (conv_id, spawn_id) -> turn_index lookups.""" + for conv in self._conversation_source.dataset_metadata.conversations: + for idx, turn_meta in enumerate(conv.turns): + if turn_meta.prerequisites: + self._prerequisite_index[(conv.conversation_id, idx)] = list( + turn_meta.prerequisites + ) + for prereq in turn_meta.prerequisites: + if ( + prereq.kind == PrerequisiteKind.SPAWN_JOIN + and prereq.spawn_id + ): + self._spawn_join_index[ + (conv.conversation_id, prereq.spawn_id) + ] = idx + if self._prerequisite_index: + self.info( + f"Prerequisite index: {len(self._prerequisite_index)} gated turns, " + f"{len(self._spawn_join_index)} spawn-join entries" + ) + + def _find_gated_turn_index( + self, conversation_id: str, spawn_ids: list[str] + ) -> int | None: + """Find the turn index gated by the given spawn_ids via the spawn_join index.""" + for spawn_id in spawn_ids: + turn_idx = self._spawn_join_index.get((conversation_id, spawn_id)) + if turn_idx is not None: + return turn_idx + return None + + def _iter_future_gates(self) -> list[tuple[str, PendingTurnGate]]: + """Flatten future gates for cleanup/logging.""" + leaked: list[tuple[str, PendingTurnGate]] = [] + for parent_corr_id, gates in self._future_gates.items(): + for gate in gates.values(): + leaked.append((parent_corr_id, gate)) + return leaked + + def _get_gate( + self, parent_corr_id: str, 
gated_turn_index: int + ) -> PendingTurnGate | None: + """Look up either an active blocked gate or a future gate.""" + active_gate = self._gated_turns.get(parent_corr_id) + if active_gate is not None and active_gate.gated_turn_index == gated_turn_index: + return active_gate + return self._future_gates.get(parent_corr_id, {}).get(gated_turn_index) + + def _add_future_gate( + self, + *, + parent_corr_id: str, + gated_turn_index: int, + credit: Credit, + prereq_key: str, + expected_children: int, + ) -> None: + """Create or extend a future gate for a later parent turn.""" + gates_for_parent = self._future_gates.setdefault(parent_corr_id, {}) + gate = gates_for_parent.get(gated_turn_index) + if gate is None: + gate = PendingTurnGate( + parent_conversation_id=credit.conversation_id, + parent_correlation_id=parent_corr_id, + gated_turn_index=gated_turn_index, + parent_num_turns=credit.num_turns, + parent_agent_depth=credit.agent_depth, + parent_parent_correlation_id=credit.parent_correlation_id, + created_at_ns=time.time_ns(), + ) + gates_for_parent[gated_turn_index] = gate + gate.outstanding[prereq_key] = [expected_children, 0] + + def _pop_future_gate( + self, parent_corr_id: str, gated_turn_index: int + ) -> PendingTurnGate | None: + """Remove and return a future gate.""" + gates_for_parent = self._future_gates.get(parent_corr_id) + if gates_for_parent is None: + return None + gate = gates_for_parent.pop(gated_turn_index, None) + if not gates_for_parent: + self._future_gates.pop(parent_corr_id, None) + return gate + + def _maybe_suspend_parent(self, credit: Credit) -> bool: + """Suspend parent if its next turn has unsatisfied prerequisites. + + Returns True to block the parent (strategy won't dispatch next turn). + Promotes a "future gate" (created by _resolve_and_dispatch_spawns) to + an active blocked gate in _gated_turns. + + RACE: Children can complete before the parent reaches the gate. 
If the + future gate is already satisfied, we pop it and return False (no block). + """ + next_turn_index = credit.turn_index + 1 + + # Already blocked (multiple spawns targeting same gated turn). + active_gate = self._gated_turns.get(credit.x_correlation_id) + if ( + active_gate is not None + and active_gate.gated_turn_index == next_turn_index + and not active_gate.is_satisfied + ): + return True + + # Promote future gate to active. + future_gate = self._pop_future_gate(credit.x_correlation_id, next_turn_index) + if future_gate is None: + return False + + # Children already finished -- no need to block. + if future_gate.is_satisfied: + return False + + future_gate.is_blocked = True + self._gated_turns[credit.x_correlation_id] = future_gate + self._stats.parents_suspended += 1 + return True + + def _satisfy_prerequisite( + self, parent_corr_id: str, gated_turn_index: int, prereq_key: str + ) -> TurnToSend | None: + """Increment completion for a prerequisite; dispatch gated turn when all met. + + Returns TurnToSend only when the gate is fully satisfied AND the parent + is already blocked. If satisfied before the parent arrives, pops the + future gate so _maybe_suspend_parent sees no gate and skips suspension. + """ + gate = self._get_gate(parent_corr_id, gated_turn_index) + if gate is None: + return None + + if prereq_key not in gate.outstanding: + return None + + gate.outstanding[prereq_key][1] += 1 + if not gate.is_satisfied: + return None + + # Satisfied before parent arrived -- clean up future gate. + if not gate.is_blocked: + self._pop_future_gate(parent_corr_id, gated_turn_index) + return None + + return self._release_blocked_gate(parent_corr_id) + + def _release_blocked_gate(self, parent_corr_id: str) -> TurnToSend | None: + """Release a blocked parent gate and build its gated turn for dispatch. + + Checks can_send_any_turn() to suppress dispatch after stop fires + (one of the three stop-condition guards; see module docstring). 
+ """ + gate = self._gated_turns.pop(parent_corr_id, None) + if gate is None: + return None + + self._stats.parents_resumed += 1 + if gate.gated_turn_index >= gate.parent_num_turns: + return None + + if not self._stop_checker.can_send_any_turn(): + self._stats.joins_suppressed += 1 + self.debug( + lambda: f"Suppressed gated turn for parent {parent_corr_id} " + f"(stop fired, turn {gate.gated_turn_index}/" + f"{gate.parent_num_turns})" + ) + return None + + return TurnToSend( + conversation_id=gate.parent_conversation_id, + x_correlation_id=parent_corr_id, + turn_index=gate.gated_turn_index, + num_turns=gate.parent_num_turns, + agent_depth=gate.parent_agent_depth, + parent_correlation_id=gate.parent_parent_correlation_id, + ) + + # ========================================================================= + # Internal: child credit handling + # ========================================================================= + + def _handle_child_credit(self, credit: Credit) -> None: + """Route a child credit: gate accounting for final, next turn for non-final. + + Reached even after stop fires (callback handler child bypass), so both + paths check stop conditions before dispatching new work. + """ + if credit.is_final_turn: + # Gate accounting for blocking children only. Background children + # (not in _child_to_gate) are fire-and-forget. + if credit.x_correlation_id in self._child_to_gate: + self._stats.children_completed += 1 + entry = self._child_to_gate.get(credit.x_correlation_id) + self.debug( + lambda: f"Child {credit.x_correlation_id} completed final turn " + f"(parent={entry.parent_corr_id if entry else '?'})" + ) + gated = self._release_child(credit.x_correlation_id) + if gated: + self.debug( + lambda: f"Gate satisfied, resuming parent " + f"{gated.x_correlation_id} at turn[{gated.turn_index}]" + ) + self._dispatch(gated) + else: + # Consume terminated marker (set by terminate_child in step 4) + # to prevent dispatching next turn for errored children. 
+ if self._is_terminated(credit): + return + if not self._stop_checker.can_send_any_turn(): + return + turn = TurnToSend.from_previous_credit(credit) + self._dispatch(turn) + + def _is_terminated(self, credit: Credit) -> bool: + """Check and consume terminated marker for a non-final child credit. + + Consume-on-read: child turns are sequential (one in-flight at a time), + so the marker only needs to block one return. + """ + if credit.x_correlation_id in self._terminated_children: + self._terminated_children.discard(credit.x_correlation_id) + return True + return False + + # ========================================================================= + # Internal: spawn resolution and child dispatch + # ========================================================================= + + def _get_spawn_ids(self, credit: Credit) -> list[str]: + meta = self._conversation_source.get_turn_metadata_at( + credit.conversation_id, credit.turn_index + ) + return meta.subagent_spawn_ids + + def _get_turn0_background_ids(self, conversation_id: str) -> list[str]: + meta = self._conversation_source.get_metadata(conversation_id) + if not meta.turns or not meta.turns[0].subagent_spawn_ids: + return [] + return [ + sid + for sid in meta.turns[0].subagent_spawn_ids + if self._is_background(conversation_id, sid) + ] + + def _resolve_and_dispatch_spawns( + self, credit: Credit, spawn_ids: list[str] + ) -> None: + """Resolve spawns, register gate tracking, and dispatch children. + + Two-phase: register all gates THEN dispatch. No awaits between phases, + so _child_to_gate entries exist before any child credit can return. 
+ """ + parent_corr_id = credit.x_correlation_id + child_depth = credit.agent_depth + 1 + + resolved: list[tuple[bool, str, list[SampledSession], int | None]] = [] + + for spawn_id in spawn_ids: + spawn = self._conversation_source.get_subagent_spawn( + credit.conversation_id, spawn_id + ) + if spawn is None: + continue + + is_blocking = not spawn.is_background + child_sessions = [ + self._conversation_source.start_child_session(cid) + for cid in spawn.child_conversation_ids + ] + gated_turn_index = None + if is_blocking: + gated_turn_index = self._find_gated_turn_index( + credit.conversation_id, [spawn_id] + ) + # Sanity check: gate must be on a future turn. A gate on + # a past/current turn means the dataset is malformed. + if ( + gated_turn_index is not None + and gated_turn_index <= credit.turn_index + ): + self.warning( + f"Ignoring spawn gate on past turn {gated_turn_index} for " + f"parent {parent_corr_id} spawn {spawn_id}" + ) + gated_turn_index = None + resolved.append((is_blocking, spawn_id, child_sessions, gated_turn_index)) + + if not resolved: + return + + if self.is_debug_enabled: + n_blocking = sum(1 for b, _, _, _ in resolved if b) + n_background = len(resolved) - n_blocking + self.debug( + f"Resolved spawns for parent {parent_corr_id}: " + f"{n_blocking} blocking, {n_background} background" + ) + + # PHASE 1: Register all gate tracking before dispatching any children. + for is_blocking, spawn_id, child_sessions, gated_turn_index in resolved: + if not is_blocking: + continue + if gated_turn_index is None: + # Blocking spawn with no matching prerequisite in the dataset. + # The children will run but the parent won't wait for them. 
+ self.warning( + f"Blocking spawn {spawn_id} on parent {parent_corr_id} has no " + "matching prerequisite; parent will not be gated" + ) + continue + prereq_key = f"{PrerequisiteKind.SPAWN_JOIN}:{spawn_id}" + self._add_future_gate( + parent_corr_id=parent_corr_id, + gated_turn_index=gated_turn_index, + credit=credit, + prereq_key=prereq_key, + expected_children=len(child_sessions), + ) + + # PHASE 2: Dispatch children. + for is_blocking, spawn_id, child_sessions, gated_turn_index in resolved: + for session in child_sessions: + turn = session.build_first_turn( + agent_depth=child_depth, + parent_correlation_id=parent_corr_id, + ) + if is_blocking: + self._scheduler.execute_async( + self._issue_child_credit_or_release( + turn, session.x_correlation_id + ), + ) + else: + self._dispatch(turn) + if is_blocking and gated_turn_index is not None: + self._child_to_gate[session.x_correlation_id] = ChildGateEntry( + parent_corr_id=parent_corr_id, + gated_turn_index=gated_turn_index, + prereq_key=f"{PrerequisiteKind.SPAWN_JOIN}:{spawn_id}", + ) + self._stats.children_spawned += 1 + + async def _issue_child_credit_or_release( + self, turn: TurnToSend, corr_id: str + ) -> None: + """Issue a blocking child credit; release from gate if issuance fails. + + When issue_credit returns False (stop fired, no slots), releases the + child from gate tracking so the parent doesn't hang indefinitely. 
+        """
+        try:
+            issued = await self._credit_issuer.issue_credit(turn)
+        except Exception as e:
+            self.warning(
+                f"Exception issuing credit for child {corr_id}, "
+                f"releasing from gate: {e!r}"
+            )
+            issued = False
+        if not issued and corr_id in self._child_to_gate:
+            self._stats.children_errored += 1
+            gated = self._release_child(corr_id)
+            if gated:
+                self._dispatch(gated)
+
+    # =========================================================================
+    # Internal: gate accounting
+    # =========================================================================
+
+    def _is_background(self, conversation_id: str, spawn_id: str) -> bool:
+        spawn = self._conversation_source.get_subagent_spawn(conversation_id, spawn_id)
+        return spawn is not None and spawn.is_background
+
+    def _release_child(self, child_corr_id: str) -> TurnToSend | None:
+        """Pop child from gate tracking and satisfy its parent's prerequisite.
+
+        Returns TurnToSend if all prerequisites met and parent is blocked.
+        Safe to call twice (pop returns None on second call). 
+ """ + entry = self._child_to_gate.pop(child_corr_id, None) + if entry is None: + return None + return self._satisfy_prerequisite( + entry.parent_corr_id, entry.gated_turn_index, entry.prereq_key + ) diff --git a/src/aiperf/workers/inference_client.py b/src/aiperf/workers/inference_client.py index f52c433b1..745c07a5e 100644 --- a/src/aiperf/workers/inference_client.py +++ b/src/aiperf/workers/inference_client.py @@ -99,6 +99,17 @@ async def _send_request_to_transport( request_info.endpoint_headers = self.endpoint.get_endpoint_headers(request_info) request_info.endpoint_params = self.endpoint.get_endpoint_params(request_info) formatted_payload = self.endpoint.format_payload(request_info) + + current_turn = request_info.turns[-1] if request_info.turns else None + if ( + current_turn + and current_turn.extra_params + and isinstance(formatted_payload, dict) + ): + formatted_payload = self.endpoint.merge_turn_params( + formatted_payload, current_turn.extra_params + ) + return await self.transport.send_request( request_info, payload=formatted_payload, diff --git a/tests/component_integration/dataset/test_conflux_timing_modes.py b/tests/component_integration/dataset/test_conflux_timing_modes.py new file mode 100644 index 000000000..d4dff7be4 --- /dev/null +++ b/tests/component_integration/dataset/test_conflux_timing_modes.py @@ -0,0 +1,1049 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Component integration tests for Conflux subagent spawning across all timing modes. 
+ +Exercises the full pipeline: + Handcrafted DatasetMetadata (mimicking ConfluxLoader output) + -> ConversationSource -> SubagentOrchestrator -> Strategy-specific dispatch + -> Credit issuance -> Callback return -> Gate completion -> Gated turn dispatch + +Each timing mode (FIXED_SCHEDULE, REQUEST_RATE, USER_CENTRIC_RATE) dispatches +child turns and gated turns through its own _dispatch_turn callback, using +timestamps/delays from metadata. These tests verify: + +1. All timing modes correctly spawn children, gate the parent, and resume. +2. Timestamps on dispatched turns are plausible (absolute or relative). +3. Delayed joins (spawn_at != join_at - 1) work across all modes. +4. Background spawns do not gate across all modes. +5. Multiple spawns on different turns compose correctly. +6. Error/cancellation on children correctly releases gates. +""" + +from __future__ import annotations + +import itertools +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aiperf.common.enums import CreditPhase, PrerequisiteKind +from aiperf.common.models import ( + ConversationMetadata, + DatasetMetadata, + SubagentSpawnInfo, + TurnMetadata, + TurnPrerequisite, +) +from aiperf.credit.structs import Credit, TurnToSend +from aiperf.plugin import plugins +from aiperf.plugin.enums import DatasetSamplingStrategy, PluginType +from aiperf.timing.conversation_source import ConversationSource +from aiperf.timing.subagent_orchestrator import SubagentOrchestrator + +# ============================================================================= +# Helpers +# ============================================================================= + +_credit_counter = itertools.count(1) + + +def _make_sampler(conv_ids: list[str]) -> object: + cls = plugins.get_class( + PluginType.DATASET_SAMPLER, DatasetSamplingStrategy.SEQUENTIAL + ) + return cls(conversation_ids=conv_ids) + + +def _make_credit( + *, + conv_id: str, + corr_id: str, + turn_index: int, + num_turns: int, + 
agent_depth: int = 0, + parent_correlation_id: str | None = None, +) -> Credit: + return Credit( + id=next(_credit_counter), + phase=CreditPhase.PROFILING, + conversation_id=conv_id, + x_correlation_id=corr_id, + turn_index=turn_index, + num_turns=num_turns, + issued_at_ns=time.time_ns(), + agent_depth=agent_depth, + parent_correlation_id=parent_correlation_id, + ) + + +def _build_orchestrator( + ds: DatasetMetadata, + *, + stop_can_send: bool = True, +) -> tuple[SubagentOrchestrator, list[TurnToSend], MagicMock]: + """Build orchestrator from dataset with a capturing dispatch_fn.""" + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + scheduler = MagicMock() + scheduler.execute_async = MagicMock() + scheduler.schedule_at_perf_sec = MagicMock() + scheduler.schedule_later = MagicMock() + dispatched: list[TurnToSend] = [] + + orch = SubagentOrchestrator( + conversation_source=src, + credit_issuer=MagicMock( + issue_credit=AsyncMock(return_value=True), + try_issue_credit=AsyncMock(return_value=True), + ), + stop_checker=MagicMock( + can_send_any_turn=MagicMock(return_value=stop_can_send), + can_start_new_session=MagicMock(return_value=stop_can_send), + ), + scheduler=scheduler, + dispatch_fn=lambda turn: dispatched.append(turn), + ) + return orch, dispatched, scheduler + + +# ============================================================================= +# Dataset builders mimicking ConfluxLoader output +# ============================================================================= + + +def _make_conflux_dataset( + *, + parent_turns: int = 6, + spawn_at: int = 2, + join_at: int | None = None, + num_children: int = 2, + child_turns: int = 3, + is_background: bool = False, + timestamps: bool = False, + timestamp_base_ms: int = 1000, + timestamp_spacing_ms: int = 500, + delay_ms: float | None = None, +) -> DatasetMetadata: + """Build a dataset mimicking Conflux 
trace output. + + Args: + timestamps: If True, add absolute timestamps on all turns. + delay_ms: If set, add delay_ms on subsequent turns. + """ + join_at = spawn_at + 1 if join_at is None else join_at + child_conv_ids = [f"parent_s0_c{ci}" for ci in range(num_children)] + spawn = SubagentSpawnInfo( + spawn_id="s0", + child_conversation_ids=child_conv_ids, + is_background=is_background, + ) + + turns: list[TurnMetadata] = [] + for i in range(parent_turns): + spawn_ids = ["s0"] if i == spawn_at else [] + prereqs: list[TurnPrerequisite] = [] + if i == join_at and not is_background: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0") + ] + + ts_ms = timestamp_base_ms + (i * timestamp_spacing_ms) if timestamps else None + d_ms = delay_ms if i > 0 else None + + turns.append( + TurnMetadata( + timestamp_ms=ts_ms, + delay_ms=d_ms, + input_tokens=500 + i * 100, + subagent_spawn_ids=spawn_ids, + prerequisites=prereqs, + ) + ) + + convs = [ + ConversationMetadata( + conversation_id="parent", + turns=turns, + subagent_spawns=[spawn], + ) + ] + for _ci, child_id in enumerate(child_conv_ids): + child_turns_list: list[TurnMetadata] = [] + for j in range(child_turns): + c_ts = None + c_delay = None + if timestamps: + # Children start slightly after spawn turn + c_ts = ( + timestamp_base_ms + + (spawn_at * timestamp_spacing_ms) + + 100 + + (j * timestamp_spacing_ms) + ) + if j > 0 and delay_ms is not None: + c_delay = delay_ms + child_turns_list.append( + TurnMetadata( + timestamp_ms=c_ts, + delay_ms=c_delay, + input_tokens=300 + j * 50, + ) + ) + convs.append( + ConversationMetadata( + conversation_id=child_id, + turns=child_turns_list, + agent_depth=1, + parent_conversation_id="parent", + ) + ) + + return DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + + +def _make_multi_spawn_dataset( + *, + timestamps: bool = False, + delay_ms: float | None = None, +) -> DatasetMetadata: + """Parent with 2 
spawns on different turns: s0 at turn 1, s1 at turn 3. + + Parent: 6 turns + - Turn 1: spawn s0 (2 children), join at turn 2 + - Turn 3: spawn s1 (1 child), join at turn 4 + """ + child_s0_ids = ["parent_s0_c0", "parent_s0_c1"] + child_s1_ids = ["parent_s1_c0"] + spawn_s0 = SubagentSpawnInfo(spawn_id="s0", child_conversation_ids=child_s0_ids) + spawn_s1 = SubagentSpawnInfo(spawn_id="s1", child_conversation_ids=child_s1_ids) + + turns: list[TurnMetadata] = [] + for i in range(6): + spawn_ids: list[str] = [] + prereqs: list[TurnPrerequisite] = [] + if i == 1: + spawn_ids = ["s0"] + elif i == 2: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0") + ] + elif i == 3: + spawn_ids = ["s1"] + elif i == 4: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s1") + ] + + ts = (1000 + i * 500) if timestamps else None + d = delay_ms if i > 0 else None + turns.append( + TurnMetadata( + timestamp_ms=ts, + delay_ms=d, + input_tokens=400 + i * 50, + subagent_spawn_ids=spawn_ids, + prerequisites=prereqs, + ) + ) + + convs = [ + ConversationMetadata( + conversation_id="parent", + turns=turns, + subagent_spawns=[spawn_s0, spawn_s1], + ) + ] + + for child_id in child_s0_ids + child_s1_ids: + child_turns = [ + TurnMetadata( + timestamp_ms=(1000 + 600 + j * 300) if timestamps else None, + delay_ms=delay_ms if j > 0 else None, + input_tokens=200 + j * 30, + ) + for j in range(2) + ] + convs.append( + ConversationMetadata( + conversation_id=child_id, + turns=child_turns, + agent_depth=1, + parent_conversation_id="parent", + ) + ) + + return DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + + +def _run_spawn_lifecycle( + orch: SubagentOrchestrator, + dispatched: list[TurnToSend], + scheduler: MagicMock, + *, + parent_turns: int = 6, + spawn_at: int = 2, + join_at: int | None = None, + num_children: int = 2, + child_turns: int = 3, + is_background: bool = False, +) -> dict: + """Drive 
the orchestrator through a full spawn->children->join lifecycle. + + Returns a dict of verification data. + """ + join_at = spawn_at + 1 if join_at is None else join_at + + # Advance parent to spawn turn + for t in range(spawn_at + 1): + credit = _make_credit( + conv_id="parent", + corr_id="parent-corr", + turn_index=t, + num_turns=parent_turns, + ) + orch.intercept(credit) + + child_corr_ids = list(orch._child_to_gate.keys()) + child_conv_ids = [f"parent_s0_c{ci}" for ci in range(num_children)] + + # If delayed join, advance parent past spawn to the turn before join + if join_at > spawn_at + 1: + for t in range(spawn_at + 1, join_at): + credit = _make_credit( + conv_id="parent", + corr_id="parent-corr", + turn_index=t, + num_turns=parent_turns, + ) + orch.intercept(credit) + + # Complete all children + dispatched_before_children_complete = len(dispatched) + for ci in range(num_children): + corr_id = child_corr_ids[ci] if ci < len(child_corr_ids) else f"child-{ci}" + for t in range(child_turns): + child_credit = _make_credit( + conv_id=child_conv_ids[ci], + corr_id=corr_id, + turn_index=t, + num_turns=child_turns, + agent_depth=1, + ) + orch.intercept(child_credit) + + # Find gated turn dispatches + gated_dispatches = [ + d + for d in dispatched[dispatched_before_children_complete:] + if d.conversation_id == "parent" + ] + + return { + "child_corr_ids": child_corr_ids, + "child_conv_ids": child_conv_ids, + "gated_dispatches": gated_dispatches, + "all_dispatched": dispatched, + "stats": orch.get_stats(), + } + + +# ============================================================================= +# Tests: Blocking spawn lifecycle across all timing modes +# ============================================================================= + + +@pytest.mark.component_integration +class TestBlockingSpawnAllModes: + """Verify blocking spawn lifecycle works identically across timing modes. 
+ + The SubagentOrchestrator is mode-agnostic; these tests confirm that the + strategy's _dispatch_turn callback correctly routes gated turns regardless + of whether the strategy uses timestamps, rate-limiting, or user-centric pacing. + """ + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + @pytest.mark.parametrize("delay_ms", [None, 200.0], ids=["no-delay", "delay-200ms"]) + def test_blocking_spawn_dispatches_gated_turn( + self, timestamps: bool, delay_ms: float | None + ) -> None: + """Parent suspends at spawn, resumes at join after all children complete.""" + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + timestamps=timestamps, + delay_ms=delay_ms, + ) + orch, dispatched, scheduler = _build_orchestrator(ds) + result = _run_spawn_lifecycle( + orch, + dispatched, + scheduler, + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + ) + + assert result["stats"]["subagent_children_spawned"] == 2 + assert result["stats"]["subagent_children_completed"] == 2 + assert result["stats"]["subagent_parents_suspended"] == 1 + assert result["stats"]["subagent_parents_resumed"] == 1 + + assert len(result["gated_dispatches"]) == 1 + gated = result["gated_dispatches"][0] + assert gated.turn_index == 3 + assert gated.conversation_id == "parent" + assert gated.num_turns == 6 + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + def test_child_non_final_turns_dispatch_next(self, timestamps: bool) -> None: + """Non-final child turns dispatch the next child turn.""" + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + num_children=1, + child_turns=3, + timestamps=timestamps, + ) + orch, dispatched, scheduler = _build_orchestrator(ds) + + # Advance parent to spawn + for t in range(3): + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=6) + ) + + child_corr_ids = list(orch._child_to_gate.keys()) + assert 
len(child_corr_ids) == 1 + + # Child turn 0 -> should dispatch turn 1 + orch.intercept( + _make_credit( + conv_id="parent_s0_c0", + corr_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + agent_depth=1, + ) + ) + child_dispatches = [ + d for d in dispatched if d.conversation_id == "parent_s0_c0" + ] + assert len(child_dispatches) == 1 + assert child_dispatches[0].turn_index == 1 + + +# ============================================================================= +# Tests: Delayed join (join_at > spawn_at + 1) +# ============================================================================= + + +@pytest.mark.component_integration +class TestDelayedJoinAllModes: + """Verify delayed join works across all dataset configurations.""" + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + @pytest.mark.parametrize("delay_ms", [None, 200.0], ids=["no-delay", "delay-200ms"]) + def test_delayed_join_blocks_at_correct_turn( + self, timestamps: bool, delay_ms: float | None + ) -> None: + """Join at turn 5 (spawn at turn 2): parent flows through turns 3,4 freely.""" + ds = _make_conflux_dataset( + parent_turns=8, + spawn_at=2, + join_at=5, + num_children=2, + child_turns=3, + timestamps=timestamps, + delay_ms=delay_ms, + ) + orch, dispatched, scheduler = _build_orchestrator(ds) + result = _run_spawn_lifecycle( + orch, + dispatched, + scheduler, + parent_turns=8, + spawn_at=2, + join_at=5, + num_children=2, + child_turns=3, + ) + + assert result["stats"]["subagent_parents_suspended"] == 1 + assert result["stats"]["subagent_parents_resumed"] == 1 + assert len(result["gated_dispatches"]) == 1 + assert result["gated_dispatches"][0].turn_index == 5 + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + def test_delayed_join_children_complete_before_parent_reaches_gate( + self, timestamps: bool + ) -> None: + """If children complete before parent reaches gated turn, no suspension occurs.""" + ds = _make_conflux_dataset( 
+ parent_turns=8, + spawn_at=2, + join_at=6, + num_children=1, + child_turns=2, + timestamps=timestamps, + ) + orch, dispatched, _ = _build_orchestrator(ds) + + # Parent turn 2 (spawn) + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=2, num_turns=8) + ) + + # Children dispatched via future gate + child_corr_ids = list(orch._child_to_gate.keys()) + + # Complete child fully (both turns) + for t in range(2): + orch.intercept( + _make_credit( + conv_id="parent_s0_c0", + corr_id=child_corr_ids[0], + turn_index=t, + num_turns=2, + agent_depth=1, + ) + ) + + # Future gate should be cleaned up + assert ( + "p-1" not in orch._future_gates + or len(orch._future_gates.get("p-1", {})) == 0 + ) + + # Parent continues past spawn turn without suspension + for t in range(3, 6): + result = orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=8) + ) + assert result is False + + assert orch._stats.parents_suspended == 0 + + +# ============================================================================= +# Tests: Background spawns +# ============================================================================= + + +@pytest.mark.component_integration +class TestBackgroundSpawnAllModes: + """Background spawns do not gate the parent, across all dataset configurations.""" + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + @pytest.mark.parametrize("delay_ms", [None, 200.0], ids=["no-delay", "delay-200ms"]) + def test_background_spawn_does_not_gate( + self, timestamps: bool, delay_ms: float | None + ) -> None: + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + is_background=True, + timestamps=timestamps, + delay_ms=delay_ms, + ) + orch, dispatched, scheduler = _build_orchestrator(ds) + + # Advance parent to spawn + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=2, num_turns=6) + ) + + # Parent not suspended + assert 
"p-1" not in orch._gated_turns + assert len(orch._child_to_gate) == 0 + assert orch._stats.parents_suspended == 0 + + # Background children dispatched via dispatch_fn + bg_dispatches = [d for d in dispatched if d.agent_depth == 1] + assert len(bg_dispatches) == 2 + assert orch._stats.children_spawned == 2 + + # scheduler.execute_async NOT used for background + assert scheduler.execute_async.call_count == 0 + + +# ============================================================================= +# Tests: Multiple spawns on different turns +# ============================================================================= + + +@pytest.mark.component_integration +class TestMultiSpawnComposition: + """Multiple spawns on different turns compose correctly.""" + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + def test_two_sequential_spawns(self, timestamps: bool) -> None: + """s0 at turn 1 (join turn 2), s1 at turn 3 (join turn 4).""" + ds = _make_multi_spawn_dataset(timestamps=timestamps) + orch, dispatched, scheduler = _build_orchestrator(ds) + + # Turn 0: no spawn + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=0, num_turns=6) + ) + + # Turn 1: spawn s0 + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=1, num_turns=6) + ) + assert orch._stats.children_spawned == 2 + assert orch._stats.parents_suspended == 1 + + # Complete s0 children + s0_child_corr_ids = [ + cid + for cid, entry in orch._child_to_gate.items() + if entry.prereq_key == "spawn_join:s0" + ] + for ci, corr_id in enumerate(s0_child_corr_ids): + for t in range(2): + orch.intercept( + _make_credit( + conv_id=f"parent_s0_c{ci}", + corr_id=corr_id, + turn_index=t, + num_turns=2, + agent_depth=1, + ) + ) + + # Gated turn 2 dispatched + gated_s0 = [ + d for d in dispatched if d.conversation_id == "parent" and d.turn_index == 2 + ] + assert len(gated_s0) == 1 + assert orch._stats.parents_resumed == 1 + + # Turn 2 returns -> parent 
continues to turn 3 + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=2, num_turns=6) + ) + + # Turn 3: spawn s1 + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=3, num_turns=6) + ) + assert orch._stats.children_spawned == 3 # 2 from s0 + 1 from s1 + assert orch._stats.parents_suspended == 2 + + # Complete s1 child + s1_child_corr_ids = [ + cid + for cid, entry in orch._child_to_gate.items() + if entry.prereq_key == "spawn_join:s1" + ] + for corr_id in s1_child_corr_ids: + for t in range(2): + orch.intercept( + _make_credit( + conv_id="parent_s1_c0", + corr_id=corr_id, + turn_index=t, + num_turns=2, + agent_depth=1, + ) + ) + + # Gated turn 4 dispatched + gated_s1 = [ + d for d in dispatched if d.conversation_id == "parent" and d.turn_index == 4 + ] + assert len(gated_s1) == 1 + assert orch._stats.parents_resumed == 2 + + # Final stats + stats = orch.get_stats() + assert stats["subagent_children_spawned"] == 3 + assert stats["subagent_children_completed"] == 3 + assert stats["subagent_children_errored"] == 0 + assert stats["subagent_parents_suspended"] == 2 + assert stats["subagent_parents_resumed"] == 2 + + +# ============================================================================= +# Tests: Error/cancellation releases gates +# ============================================================================= + + +@pytest.mark.component_integration +class TestErrorCancellationReleasesGate: + """Errored/cancelled children release gates correctly.""" + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + def test_terminate_child_releases_gate(self, timestamps: bool) -> None: + """Terminating all children releases the gate and dispatches gated turn.""" + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + timestamps=timestamps, + ) + orch, dispatched, _ = _build_orchestrator(ds) + + # Advance parent to spawn + for t in range(3): + 
orch.intercept(
+                _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=6)
+            )
+
+        child_corr_ids = list(orch._child_to_gate.keys())
+        assert len(child_corr_ids) == 2
+
+        # Terminate both children (simulating error/cancel via SubagentMixin.on_failed_credit)
+        for ci, corr_id in enumerate(child_corr_ids):
+            credit = _make_credit(
+                conv_id=f"parent_s0_c{ci}",
+                corr_id=corr_id,
+                turn_index=0,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.terminate_child(credit)
+
+        # Gate released, gated turn dispatched
+        assert "p-1" not in orch._gated_turns
+        gated = [d for d in dispatched if d.conversation_id == "parent"]
+        assert len(gated) == 1
+        assert gated[0].turn_index == 3
+        assert orch._stats.parents_resumed == 1
+        assert orch._stats.children_errored == 2
+
+
+# =============================================================================
+# Tests: Stop condition suppresses gated turn dispatch
+# =============================================================================
+
+
+@pytest.mark.component_integration
+class TestStopConditionSuppression:
+    """When stop condition fires, gated turn dispatch is suppressed."""
+
+    @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"])
+    def test_gated_turn_suppressed_when_stop_fired(self, timestamps: bool) -> None:
+        ds = _make_conflux_dataset(
+            parent_turns=6,
+            spawn_at=2,
+            num_children=1,
+            child_turns=2,
+            timestamps=timestamps,
+        )
+        orch, dispatched, _ = _build_orchestrator(ds, stop_can_send=False)
+
+        # Advance parent to spawn
+        for t in range(3):
+            orch.intercept(
+                _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=6)
+            )
+
+        child_corr_ids = list(orch._child_to_gate.keys())
+
+        # Complete child
+        for t in range(2):
+            orch.intercept(
+                _make_credit(
+                    conv_id="parent_s0_c0",
+                    corr_id=child_corr_ids[0],
+                    turn_index=t,
+                    num_turns=2,
+                    agent_depth=1,
+                )
+            )
+
+        # Gate released but gated turn suppressed
+        assert "p-1" not in orch._gated_turns
+        parent_dispatches = [d for d in 
dispatched if d.conversation_id == "parent"] + assert len(parent_dispatches) == 0 + assert orch._stats.joins_suppressed == 1 + + +# ============================================================================= +# Tests: Timestamp verification +# ============================================================================= + + +@pytest.mark.component_integration +class TestTimestampPropagation: + """Verify timestamps on metadata are correctly preserved in the dataset.""" + + def test_timestamps_preserved_in_turn_metadata(self) -> None: + """Absolute timestamps on parent turns are accessible via ConversationSource.""" + ds = _make_conflux_dataset( + parent_turns=4, + spawn_at=1, + num_children=1, + child_turns=2, + timestamps=True, + timestamp_base_ms=2000, + timestamp_spacing_ms=300, + ) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + for i in range(4): + meta = src.get_turn_metadata_at("parent", i) + expected_ts = 2000 + i * 300 + assert meta.timestamp_ms == expected_ts, ( + f"Turn {i}: expected ts={expected_ts}, got {meta.timestamp_ms}" + ) + + def test_child_timestamps_offset_from_spawn(self) -> None: + """Child turn timestamps start after the spawn turn.""" + ds = _make_conflux_dataset( + parent_turns=4, + spawn_at=1, + num_children=1, + child_turns=3, + timestamps=True, + timestamp_base_ms=1000, + timestamp_spacing_ms=500, + ) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + parent_spawn_ts = src.get_turn_metadata_at("parent", 1).timestamp_ms + assert parent_spawn_ts == 1500 # 1000 + 1 * 500 + + for j in range(3): + child_meta = src.get_turn_metadata_at("parent_s0_c0", j) + assert child_meta.timestamp_ms is not None + # Children start 100ms after spawn turn + expected = 1500 + 100 + j * 500 + assert child_meta.timestamp_ms == expected, ( + 
f"Child turn {j}: expected ts={expected}, got {child_meta.timestamp_ms}" + ) + + def test_delay_ms_on_subsequent_turns(self) -> None: + """delay_ms is set on subsequent turns (turn_index > 0) when configured.""" + ds = _make_conflux_dataset( + parent_turns=4, + spawn_at=1, + num_children=1, + child_turns=2, + delay_ms=150.0, + ) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + # First turn has no delay + assert src.get_turn_metadata_at("parent", 0).delay_ms is None + # Subsequent turns have delay + for i in range(1, 4): + meta = src.get_turn_metadata_at("parent", i) + assert meta.delay_ms == 150.0, ( + f"Turn {i}: expected delay=150.0, got {meta.delay_ms}" + ) + + def test_prerequisites_on_gated_turns(self) -> None: + """TurnPrerequisite is correctly set on the gated turn.""" + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + join_at=4, + num_children=2, + child_turns=3, + ) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + # Turn 4 should have spawn_join prerequisite + gated_meta = src.get_turn_metadata_at("parent", 4) + assert len(gated_meta.prerequisites) == 1 + prereq = gated_meta.prerequisites[0] + assert prereq.kind == PrerequisiteKind.SPAWN_JOIN + assert prereq.spawn_id == "s0" + + # Other turns should not have prerequisites + for i in [0, 1, 2, 3, 5]: + meta = src.get_turn_metadata_at("parent", i) + assert len(meta.prerequisites) == 0, ( + f"Turn {i} should have no prerequisites" + ) + + def test_spawn_ids_on_spawn_turn(self) -> None: + """subagent_spawn_ids is set only on the spawn turn.""" + ds = _make_conflux_dataset(parent_turns=6, spawn_at=2) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + for i in range(6): + 
meta = src.get_turn_metadata_at("parent", i) + if i == 2: + assert meta.subagent_spawn_ids == ["s0"] + else: + assert meta.subagent_spawn_ids == [], ( + f"Turn {i} should have no spawn IDs" + ) + + def test_multi_spawn_timestamps_independent(self) -> None: + """Multi-spawn dataset has independent timestamps for each spawn's children.""" + ds = _make_multi_spawn_dataset(timestamps=True) + root_ids = [c.conversation_id for c in ds.conversations if c.agent_depth == 0] + sampler = _make_sampler(root_ids) + src = ConversationSource(ds, sampler) + + # Verify parent timestamps are sequential + parent_timestamps = [ + src.get_turn_metadata_at("parent", i).timestamp_ms for i in range(6) + ] + assert parent_timestamps == [1000, 1500, 2000, 2500, 3000, 3500] + + # s0 children share same timestamps + s0_c0_ts = [ + src.get_turn_metadata_at("parent_s0_c0", j).timestamp_ms for j in range(2) + ] + s0_c1_ts = [ + src.get_turn_metadata_at("parent_s0_c1", j).timestamp_ms for j in range(2) + ] + assert s0_c0_ts == s0_c1_ts + + # s1 child has same base timestamps (same formula) + s1_c0_ts = [ + src.get_turn_metadata_at("parent_s1_c0", j).timestamp_ms for j in range(2) + ] + assert s1_c0_ts == s0_c0_ts # same formula in helper + + +# ============================================================================= +# Tests: Gated turn metadata propagation +# ============================================================================= + + +@pytest.mark.component_integration +class TestGatedTurnMetadataPropagation: + """Verify dispatched gated turns have correct metadata.""" + + def test_gated_turn_preserves_parent_correlation_id(self) -> None: + """The gated turn's x_correlation_id matches the parent's.""" + ds = _make_conflux_dataset( + parent_turns=6, spawn_at=2, num_children=1, child_turns=2 + ) + orch, dispatched, _ = _build_orchestrator(ds) + + # Advance parent + for t in range(3): + orch.intercept( + _make_credit( + conv_id="parent", corr_id="p-42", turn_index=t, num_turns=6 + ) + ) 
+ + # Complete child + child_corr_ids = list(orch._child_to_gate.keys()) + for t in range(2): + orch.intercept( + _make_credit( + conv_id="parent_s0_c0", + corr_id=child_corr_ids[0], + turn_index=t, + num_turns=2, + agent_depth=1, + ) + ) + + gated = [d for d in dispatched if d.conversation_id == "parent"] + assert len(gated) == 1 + assert gated[0].x_correlation_id == "p-42" + assert gated[0].turn_index == 3 + assert gated[0].num_turns == 6 + + def test_child_dispatched_turns_have_correct_agent_depth(self) -> None: + """Children dispatched by the orchestrator have agent_depth=1.""" + ds = _make_conflux_dataset( + parent_turns=6, spawn_at=2, num_children=2, child_turns=3 + ) + orch, dispatched, _ = _build_orchestrator(ds) + + # Advance parent to spawn + for t in range(3): + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=6) + ) + + # Child next-turn dispatches from non-final child turns + child_corr_ids = list(orch._child_to_gate.keys()) + orch.intercept( + _make_credit( + conv_id="parent_s0_c0", + corr_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + agent_depth=1, + ) + ) + + child_dispatches = [d for d in dispatched if d.agent_depth == 1] + assert len(child_dispatches) >= 1 + for d in child_dispatches: + assert d.agent_depth == 1 + + +# ============================================================================= +# Tests: Cleanup after partial lifecycle +# ============================================================================= + + +@pytest.mark.component_integration +class TestCleanupWithTimestamps: + """Verify cleanup clears all state regardless of timestamp configuration.""" + + @pytest.mark.parametrize("timestamps", [False, True], ids=["no-ts", "with-ts"]) + def test_cleanup_clears_all_state(self, timestamps: bool) -> None: + ds = _make_conflux_dataset( + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + timestamps=timestamps, + ) + orch, _, _ = _build_orchestrator(ds) + + # Trigger spawn + 
for t in range(3): + orch.intercept( + _make_credit(conv_id="parent", corr_id="p-1", turn_index=t, num_turns=6) + ) + + assert len(orch._gated_turns) > 0 + assert len(orch._child_to_gate) > 0 + + orch.cleanup() + + assert len(orch._gated_turns) == 0 + assert len(orch._future_gates) == 0 + assert len(orch._child_to_gate) == 0 + assert len(orch._terminated_children) == 0 + assert orch._cleaning_up is True + + # Post-cleanup intercepts are no-ops + assert ( + orch.intercept( + _make_credit( + conv_id="parent", corr_id="post", turn_index=0, num_turns=6 + ) + ) + is False + ) diff --git a/tests/component_integration/dataset/test_dataset_sampling.py b/tests/component_integration/dataset/test_dataset_sampling.py index 8875631a5..0abdceed6 100644 --- a/tests/component_integration/dataset/test_dataset_sampling.py +++ b/tests/component_integration/dataset/test_dataset_sampling.py @@ -208,7 +208,7 @@ def test_sampling_with_multi_turn(self, cli: AIPerfCLI): --streaming \ --dataset-sampling-strategy shuffle \ --random-seed 42 \ - --num-sessions 10 \ + --num-sessions 6 \ --session-turns-mean 3 \ --session-turns-stddev 0 \ --workers-max {defaults.workers_max} \ @@ -217,8 +217,8 @@ def test_sampling_with_multi_turn(self, cli: AIPerfCLI): timeout=60.0, ) - # Should have 10 sessions × 3 turns = 30 requests - assert len(result.jsonl) == 30 + # Should have 6 sessions × 3 turns = 18 requests + assert len(result.jsonl) == 18 # Verify turn indices are sequential within each session analyzer = CreditFlowAnalyzer(result.runner_result) diff --git a/tests/component_integration/dataset/test_subagent_orchestrator.py b/tests/component_integration/dataset/test_subagent_orchestrator.py new file mode 100644 index 000000000..65d3a5c86 --- /dev/null +++ b/tests/component_integration/dataset/test_subagent_orchestrator.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Component integration tests for SubagentOrchestrator with real dataset pipeline. + +Exercises the full lifecycle: + DatasetMetadata -> ConversationSource -> SubagentOrchestrator + -> spawn resolution -> child dispatch -> child completion -> gated turn dispatch + +Uses real dataset generation (no mocks for data model) with a capturing dispatch_fn +to verify the orchestrator state machine end-to-end. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aiperf.common.enums import CreditPhase, PrerequisiteKind +from aiperf.common.models import ( + ConversationMetadata, + DatasetMetadata, + SubagentSpawnInfo, + TurnMetadata, + TurnPrerequisite, +) +from aiperf.credit.structs import Credit, TurnToSend +from aiperf.plugin import plugins +from aiperf.plugin.enums import DatasetSamplingStrategy, PluginType +from aiperf.timing.conversation_source import ConversationSource +from aiperf.timing.subagent_orchestrator import SubagentOrchestrator + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _make_sampler(conv_ids, strategy=DatasetSamplingStrategy.SEQUENTIAL): + SamplerClass = plugins.get_class(PluginType.DATASET_SAMPLER, strategy) + return SamplerClass(conversation_ids=conv_ids) + + +def _make_credit( + *, + conv_id: str, + corr_id: str, + turn_index: int, + num_turns: int, + agent_depth: int = 0, +) -> Credit: + return Credit( + id=1, + phase=CreditPhase.PROFILING, + conversation_id=conv_id, + x_correlation_id=corr_id, + turn_index=turn_index, + num_turns=num_turns, + issued_at_ns=0, + agent_depth=agent_depth, + ) + + +def _build_orchestrator_from_metadata( + ds: DatasetMetadata, +) -> tuple[SubagentOrchestrator, list[TurnToSend], MagicMock]: + """Build orchestrator from dataset metadata with capturing dispatch_fn.""" + root_ids = [c.conversation_id for c in ds.conversations 
if c.agent_depth == 0] + sampler = _make_sampler(root_ids, ds.sampling_strategy) + src = ConversationSource(ds, sampler) + + scheduler = MagicMock() + scheduler.execute_async = MagicMock() + dispatched: list[TurnToSend] = [] + + orch = SubagentOrchestrator( + conversation_source=src, + credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)), + stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)), + scheduler=scheduler, + dispatch_fn=lambda turn: dispatched.append(turn), + ) + return orch, dispatched, scheduler + + +def _make_handcrafted_dataset( + *, + parent_turns: int = 6, + spawn_at: int = 2, + join_at: int | None = None, + num_children: int = 2, + child_turns: int = 3, + is_background: bool = False, +) -> DatasetMetadata: + """Create a handcrafted dataset with one parent and N children.""" + join_at = spawn_at + 1 if join_at is None else join_at + child_conv_ids = [f"conv_0_s0_c{ci}" for ci in range(num_children)] + spawn = SubagentSpawnInfo( + spawn_id="s0", + child_conversation_ids=child_conv_ids, + is_background=is_background, + ) + + turns = [] + for i in range(parent_turns): + spawn_ids = ["s0"] if i == spawn_at else [] + prereqs = [] + if i == join_at and not is_background: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0") + ] + turns.append( + TurnMetadata( + delay_ms=200.0 if i > 0 else None, + input_tokens=500 + i * 100, + subagent_spawn_ids=spawn_ids, + prerequisites=prereqs, + ) + ) + + convs = [ + ConversationMetadata( + conversation_id="conv_0", + turns=turns, + subagent_spawns=[spawn], + ) + ] + for child_id in child_conv_ids: + convs.append( + ConversationMetadata( + conversation_id=child_id, + turns=[ + TurnMetadata(input_tokens=300 + j * 50) for j in range(child_turns) + ], + agent_depth=1, + ) + ) + + return DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + + +# 
============================================================================= +# Full lifecycle with handcrafted dataset +# ============================================================================= + + +@pytest.mark.component_integration +class TestSubagentOrchestratorFullLifecycle: + """Walk through the full orchestrator lifecycle with a handcrafted dataset.""" + + def test_full_lifecycle_blocking_spawn(self): + """Root -> spawn turn -> children dispatched -> children complete -> gated turn dispatched.""" + ds = _make_handcrafted_dataset( + parent_turns=6, spawn_at=2, num_children=2, child_turns=3 + ) + orch, dispatched, scheduler = _build_orchestrator_from_metadata(ds) + + # Step 1: Root turn 0 credit returns -- no spawn + credit_t0 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=0, num_turns=6 + ) + handled = orch.intercept(credit_t0) + assert handled is False + assert len(dispatched) == 0 + + # Step 2: Root turn 1 credit returns -- no spawn + credit_t1 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=1, num_turns=6 + ) + handled = orch.intercept(credit_t1) + assert handled is False + + # Step 3: Root turn 2 (spawn turn) credit returns -- intercept resolves spawns + credit_t2 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6 + ) + handled = orch.intercept(credit_t2) + assert handled is True + + # Blocking children dispatched via scheduler.execute_async + assert scheduler.execute_async.call_count == 2 + assert orch._stats.children_spawned == 2 + assert orch._stats.parents_suspended == 1 + + # Pending gate created + assert "parent-1" in orch._gated_turns + gate = orch._gated_turns["parent-1"] + assert gate.outstanding == {"spawn_join:s0": [2, 0]} + assert gate.gated_turn_index == 3 + + # Child-to-gate mapping + child_corr_ids = list(orch._child_to_gate.keys()) + assert len(child_corr_ids) == 2 + for entry in orch._child_to_gate.values(): + assert entry.parent_corr_id == "parent-1" + + 
child_conv_ids = ["conv_0_s0_c0", "conv_0_s0_c1"] + + # Step 4: Child 0 non-final turns dispatch next turn via dispatch_fn + child0_t0 = _make_credit( + conv_id=child_conv_ids[0], + corr_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child0_t0) + assert len(dispatched) == 1 + assert dispatched[0].conversation_id == child_conv_ids[0] + assert dispatched[0].turn_index == 1 + + # Child 0 turn 1 (non-final) + child0_t1 = _make_credit( + conv_id=child_conv_ids[0], + corr_id=child_corr_ids[0], + turn_index=1, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child0_t1) + assert len(dispatched) == 2 + assert dispatched[1].turn_index == 2 + + # Step 5: Child 0 final turn -> gate accounting + child0_t2 = _make_credit( + conv_id=child_conv_ids[0], + corr_id=child_corr_ids[0], + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child0_t2) + assert orch._stats.children_completed == 1 + assert gate.outstanding["spawn_join:s0"][1] == 1 + # No gated turn yet (1 of 2) + join_dispatches = [d for d in dispatched if d.conversation_id == "conv_0"] + assert len(join_dispatches) == 0 + + # Step 6: Child 1 all turns (fast-forward to final) + child1_final = _make_credit( + conv_id=child_conv_ids[1], + corr_id=child_corr_ids[1], + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child1_final) + + # Step 7: All children done -> gated turn dispatched + assert orch._stats.children_completed == 2 + assert orch._stats.parents_resumed == 1 + assert "parent-1" not in orch._gated_turns + + join_dispatches = [d for d in dispatched if d.conversation_id == "conv_0"] + assert len(join_dispatches) == 1 + join_turn = join_dispatches[0] + assert join_turn.turn_index == 3 # spawn_at(2) + 1 + assert join_turn.x_correlation_id == "parent-1" + assert join_turn.num_turns == 6 + + # Step 8: Verify stats + stats = orch.get_stats() + assert stats["subagent_children_spawned"] == 2 + assert stats["subagent_children_completed"] == 2 + 
assert stats["subagent_children_errored"] == 0 + assert stats["subagent_parents_suspended"] == 1 + assert stats["subagent_parents_resumed"] == 1 + assert stats["subagent_joins_suppressed"] == 0 + + def test_full_lifecycle_delayed_join_blocks_on_target_turn(self): + """A later spawn_join blocks only when the parent reaches that turn.""" + ds = _make_handcrafted_dataset( + parent_turns=7, + spawn_at=2, + join_at=5, + num_children=2, + child_turns=3, + ) + orch, dispatched, scheduler = _build_orchestrator_from_metadata(ds) + + credit_t2 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=7 + ) + handled = orch.intercept(credit_t2) + assert handled is False + assert scheduler.execute_async.call_count == 2 + assert orch._stats.parents_suspended == 0 + assert orch._future_gates["parent-1"][5].outstanding == { + "spawn_join:s0": [2, 0] + } + + credit_t3 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=3, num_turns=7 + ) + assert orch.intercept(credit_t3) is False + + credit_t4 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=4, num_turns=7 + ) + assert orch.intercept(credit_t4) is True + assert orch._stats.parents_suspended == 1 + assert orch._gated_turns["parent-1"].gated_turn_index == 5 + + child_corr_ids = list(orch._child_to_gate.keys()) + child_conv_ids = ["conv_0_s0_c0", "conv_0_s0_c1"] + for i, corr_id in enumerate(child_corr_ids): + child_credit = _make_credit( + conv_id=child_conv_ids[i], + corr_id=corr_id, + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit) + + assert "parent-1" not in orch._gated_turns + assert orch._stats.parents_resumed == 1 + + join_dispatches = [d for d in dispatched if d.conversation_id == "conv_0"] + assert len(join_dispatches) == 1 + assert join_dispatches[0].turn_index == 5 + stats = orch.get_stats() + assert stats["subagent_children_spawned"] == 2 + assert stats["subagent_children_completed"] == 2 + assert stats["subagent_parents_suspended"] 
== 1 + assert stats["subagent_parents_resumed"] == 1 + + def test_full_lifecycle_background_spawn(self): + """Background spawn: parent not suspended, children dispatched via dispatch_fn.""" + ds = _make_handcrafted_dataset( + parent_turns=6, + spawn_at=2, + num_children=2, + child_turns=3, + is_background=True, + ) + orch, dispatched, scheduler = _build_orchestrator_from_metadata(ds) + + credit_t2 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6 + ) + handled = orch.intercept(credit_t2) + + # Background: parent not suspended + assert handled is False + assert "parent-1" not in orch._gated_turns + assert len(orch._child_to_gate) == 0 + + # Background children dispatched via dispatch_fn + bg_dispatches = [d for d in dispatched if d.agent_depth == 1] + assert len(bg_dispatches) == 2 + assert orch._stats.children_spawned == 2 + + # scheduler.execute_async NOT used for background + assert scheduler.execute_async.call_count == 0 + + def test_cleanup_after_partial_lifecycle(self): + """Cleanup after partial lifecycle clears all state.""" + ds = _make_handcrafted_dataset( + parent_turns=6, spawn_at=2, num_children=2, child_turns=3 + ) + orch, _, _ = _build_orchestrator_from_metadata(ds) + + # Trigger a spawn + credit = _make_credit( + conv_id="conv_0", + corr_id="partial-corr", + turn_index=2, + num_turns=6, + ) + orch.intercept(credit) + + # Cleanup without completing children + orch.cleanup() + + assert len(orch._gated_turns) == 0 + assert len(orch._future_gates) == 0 + assert len(orch._child_to_gate) == 0 + assert len(orch._terminated_children) == 0 + assert orch._cleaning_up is True + + # Further intercepts are no-ops + credit = _make_credit( + conv_id="conv_0", corr_id="post-cleanup", turn_index=0, num_turns=5 + ) + assert orch.intercept(credit) is False diff --git a/tests/component_integration/sessions/test_sticky_routing.py b/tests/component_integration/sessions/test_sticky_routing.py index 091cfd797..1bdb8337c 100644 --- 
a/tests/component_integration/sessions/test_sticky_routing.py +++ b/tests/component_integration/sessions/test_sticky_routing.py @@ -64,21 +64,21 @@ def test_sticky_routing_high_concurrency_multi_turn(self, cli: AIPerfCLI): --model {defaults.model} \ --streaming \ --request-rate {TEST_QPS} \ - --num-sessions 15 \ - --session-turns-mean 3 \ + --num-sessions 10 \ + --session-turns-mean 2 \ --session-turns-stddev 0 \ - --workers-max 10 \ + --workers-max 6 \ --ui {defaults.ui} """ ) - # 15 sessions × 3 turns = 45 requests - assert result.request_count == 45 + # 10 sessions × 2 turns = 20 requests + assert result.request_count == 20 # Verify sticky routing and all sessions completed all turns sessions = assert_sticky_routing(result.jsonl) - assert len(sessions) == 15 + assert len(sessions) == 10 for session_id, records in sessions.items(): - assert len(records) == 3, f"Session {session_id} incomplete" + assert len(records) == 2, f"Session {session_id} incomplete" assert_jsonl_turns_sequential(result.jsonl) @pytest.mark.slow @@ -94,15 +94,15 @@ def test_sticky_routing_with_concurrency_limit(self, cli: AIPerfCLI): --streaming \ --request-rate {TEST_QPS} \ --concurrency 3 \ - --num-sessions 9 \ - --session-turns-mean 4 \ + --num-sessions 6 \ + --session-turns-mean 3 \ --session-turns-stddev 0 \ --workers-max {defaults.workers_max} \ --ui {defaults.ui} """ ) - # 9 sessions × 4 turns = 36 requests - assert result.request_count == 36 + # 6 sessions × 3 turns = 18 requests + assert result.request_count == 18 # Verify sticky routing assert_sticky_routing(result.jsonl) @@ -124,8 +124,8 @@ def test_sticky_routing_with_turn_delays(self, cli: AIPerfCLI): --num-sessions 6 \ --session-turns-mean 3 \ --session-turns-stddev 0 \ - --session-turn-delay-mean 100 \ - --session-turn-delay-stddev 20 \ + --session-turn-delay-mean 50 \ + --session-turn-delay-stddev 10 \ --workers-max {defaults.workers_max} \ --ui {defaults.ui} """ diff --git a/tests/component_integration/timing/conftest.py 
b/tests/component_integration/timing/conftest.py index b7531fac0..6a786bb4e 100644 --- a/tests/component_integration/timing/conftest.py +++ b/tests/component_integration/timing/conftest.py @@ -481,9 +481,9 @@ def test_credits_per_session(self, cli: AIPerfCLI): def test_turn_indices_sequential(self, cli: AIPerfCLI): """Verify turn indices are sequential per session.""" config = TimingTestConfig( - num_sessions=10, - qps=50.0, - turns_per_session=5, + num_sessions=8, + qps=80.0, + turns_per_session=3, ) cmd = self.build_command(config) result = cli.run_sync(cmd, timeout=config.timeout) @@ -544,7 +544,7 @@ def build_command(self, config: TimingTestConfig) -> str: def test_with_concurrency_limit(self, cli: AIPerfCLI, concurrency: int, qps: float): """Test timing mode respects and reaches concurrency limit.""" config = TimingTestConfig( - num_sessions=50, + num_sessions=30, qps=qps, concurrency=concurrency, osl=50, # Need longer OSL to hit concurrency limits diff --git a/tests/component_integration/timing/test_advanced_scenarios.py b/tests/component_integration/timing/test_advanced_scenarios.py index 3cfa49077..57c8ee58c 100644 --- a/tests/component_integration/timing/test_advanced_scenarios.py +++ b/tests/component_integration/timing/test_advanced_scenarios.py @@ -138,10 +138,10 @@ def test_cancellation_rate_multi_turn_basic(self, cli: AIPerfCLI): - Verify all credits returned (with errors) """ config = TimingTestConfig( - num_sessions=10, + num_sessions=8, qps=0, - turns_per_session=4, - concurrency=10, + turns_per_session=3, + concurrency=8, ) cmd = f""" diff --git a/tests/component_integration/timing/test_constant_rate.py b/tests/component_integration/timing/test_constant_rate.py index 3d2a4646a..1a70d9344 100644 --- a/tests/component_integration/timing/test_constant_rate.py +++ b/tests/component_integration/timing/test_constant_rate.py @@ -60,9 +60,9 @@ def test_constant_rate_completes( def test_constant_rate_multi_turn(self, cli: AIPerfCLI): """Test constant rate 
with multi-turn conversations.""" config = TimingTestConfig( - num_sessions=15, - qps=75.0, - turns_per_session=4, + num_sessions=8, + qps=100.0, + turns_per_session=3, ) cmd = build_timing_command(config, arrival_pattern="constant") result = cli.run_sync(cmd, timeout=config.timeout) @@ -157,9 +157,9 @@ def test_high_volume(self, cli: AIPerfCLI): def test_sustained_multi_turn(self, cli: AIPerfCLI): """Test sustained multi-turn workload.""" config = TimingTestConfig( - num_sessions=20, - qps=100.0, - turns_per_session=5, + num_sessions=10, + qps=150.0, + turns_per_session=3, timeout=90.0, ) cmd = build_timing_command(config, arrival_pattern="constant") diff --git a/tests/component_integration/timing/test_fixed_schedule.py b/tests/component_integration/timing/test_fixed_schedule.py index cfd5e4930..a58fa1e67 100644 --- a/tests/component_integration/timing/test_fixed_schedule.py +++ b/tests/component_integration/timing/test_fixed_schedule.py @@ -558,10 +558,10 @@ class TestFixedScheduleStress: def test_high_turn_count(self, cli: AIPerfCLI, tmp_path: Path): """Test with high number of turns per session.""" config = FixedScheduleTestConfig( - num_sessions=20, - turns_per_session=25, + num_sessions=5, + turns_per_session=20, delay_ms=1, - workers_max=5, + workers_max=3, timeout=120.0, ) trace_file = generate_trace_file(tmp_path, config) diff --git a/tests/component_integration/timing/test_poisson_rate.py b/tests/component_integration/timing/test_poisson_rate.py index 004646c17..7c343c9a1 100644 --- a/tests/component_integration/timing/test_poisson_rate.py +++ b/tests/component_integration/timing/test_poisson_rate.py @@ -64,9 +64,9 @@ def test_poisson_rate_completes( def test_poisson_rate_multi_turn(self, cli: AIPerfCLI): """Test Poisson rate with multi-turn conversations.""" config = TimingTestConfig( - num_sessions=15, - qps=75.0, - turns_per_session=4, + num_sessions=8, + qps=100.0, + turns_per_session=3, ) cmd = build_timing_command(config, arrival_pattern="poisson") 
result = cli.run_sync(cmd, timeout=config.timeout) @@ -168,9 +168,9 @@ def test_high_volume(self, cli: AIPerfCLI): def test_sustained_multi_turn(self, cli: AIPerfCLI): """Test sustained multi-turn Poisson workload.""" config = TimingTestConfig( - num_sessions=20, - qps=100.0, - turns_per_session=5, + num_sessions=10, + qps=150.0, + turns_per_session=3, timeout=90.0, ) cmd = build_timing_command(config, arrival_pattern="poisson") diff --git a/tests/component_integration/timing/test_session_turn_delays.py b/tests/component_integration/timing/test_session_turn_delays.py index b94b6c142..d7ca69b8f 100644 --- a/tests/component_integration/timing/test_session_turn_delays.py +++ b/tests/component_integration/timing/test_session_turn_delays.py @@ -137,13 +137,13 @@ def test_turn_delay_with_concurrency_limit(self, cli: AIPerfCLI): aiperf profile \ --model {defaults.model} \ --streaming \ - --num-sessions 20 \ - --session-turns-mean 4 \ + --num-sessions 10 \ + --session-turns-mean 3 \ --session-turns-stddev 0 \ - --session-turn-delay-mean 100 \ + --session-turn-delay-mean 50 \ --request-rate 200 \ --request-rate-mode constant \ - --concurrency 8 \ + --concurrency 5 \ --osl 50 \ --extra-inputs ignore_eos:true \ --ui {defaults.ui} @@ -151,7 +151,7 @@ def test_turn_delay_with_concurrency_limit(self, cli: AIPerfCLI): result = cli.run_sync(cmd, timeout=40.0) - assert result.request_count == 80 # 20 × 4 + assert result.request_count == 30 # 10 × 3 runner = result.runner_result credit_analyzer = CreditFlowAnalyzer(runner) diff --git a/tests/unit/common/config/test_input_config.py b/tests/unit/common/config/test_input_config.py index 2aa10f86f..041fd553b 100644 --- a/tests/unit/common/config/test_input_config.py +++ b/tests/unit/common/config/test_input_config.py @@ -359,3 +359,41 @@ def test_synthesis_max_osl_requires_trace_dataset(dataset_type): ) assert "require a trace dataset type" in str(exc.value) + + +def test_double_speedup_raises_error(): + """Test that combining synthesis 
speedup with fixed-schedule speedup is rejected.""" + with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: + with pytest.raises(ValidationError) as exc: + InputConfig( + file=temp_file.name, + custom_dataset_type="mooncake_trace", + synthesis=SynthesisConfig(speedup_ratio=2.0), + fixed_schedule_speedup=2.0, + ) + + assert "--synthesis-speedup-ratio and --fixed-schedule-speedup" in str( + exc.value + ) + + +def test_fixed_schedule_speedup_alone_succeeds(): + """Test that fixed-schedule speedup without synthesis speedup is fine.""" + with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: + config = InputConfig( + file=temp_file.name, + fixed_schedule=True, + fixed_schedule_speedup=2.0, + ) + assert config.fixed_schedule_speedup == 2.0 + + +def test_synthesis_speedup_alone_succeeds(): + """Test that synthesis speedup without fixed-schedule speedup is fine.""" + with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: + config = InputConfig( + file=temp_file.name, + custom_dataset_type="mooncake_trace", + synthesis=SynthesisConfig(speedup_ratio=2.0), + ) + assert config.synthesis.speedup_ratio == 2.0 diff --git a/tests/unit/common/config/test_user_config_mooncake_trace.py b/tests/unit/common/config/test_user_config_mooncake_trace.py index 8ae4c0e05..d5e5b7e58 100644 --- a/tests/unit/common/config/test_user_config_mooncake_trace.py +++ b/tests/unit/common/config/test_user_config_mooncake_trace.py @@ -188,7 +188,7 @@ def test_count_dataset_entries_with_edge_cases(self, mock_is_file, mock_exists): class TestTraceDatasetTimingDetection: - """Test _should_use_fixed_schedule_for_trace_dataset() for automatic timing detection.""" + """Test _should_auto_enable_fixed_schedule() for automatic timing detection.""" @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) @@ -210,7 +210,7 @@ def test_mooncake_trace_with_timestamps_enables_fixed_schedule( ) with patch("builtins.open", 
mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_trace_dataset() + result = config._should_auto_enable_fixed_schedule() assert result is True @patch("pathlib.Path.exists", return_value=True) @@ -233,7 +233,7 @@ def test_mooncake_trace_without_timestamps_no_fixed_schedule( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_trace_dataset() + result = config._should_auto_enable_fixed_schedule() assert result is False @patch("pathlib.Path.exists", return_value=True) @@ -251,7 +251,7 @@ def test_non_trace_dataset_no_auto_detection(self, mock_is_file, mock_exists): ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_trace_dataset() + result = config._should_auto_enable_fixed_schedule() assert result is False @patch("pathlib.Path.exists", return_value=True) @@ -281,7 +281,7 @@ def test_file_parsing_with_empty_lines_and_malformed_json( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - has_timestamps = config._should_use_fixed_schedule_for_trace_dataset() + has_timestamps = config._should_auto_enable_fixed_schedule() assert has_timestamps is True @patch("pathlib.Path.exists", return_value=True) @@ -299,7 +299,7 @@ def test_empty_file_timing_detection(self, mock_is_file, mock_exists): ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - assert config._should_use_fixed_schedule_for_trace_dataset() is False + assert config._should_auto_enable_fixed_schedule() is False @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) @@ -321,4 +321,4 @@ def test_only_malformed_json_timing_detection(self, mock_is_file, mock_exists): ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - assert config._should_use_fixed_schedule_for_trace_dataset() is False + assert config._should_auto_enable_fixed_schedule() is 
False diff --git a/tests/unit/credit/test_callback_handler.py b/tests/unit/credit/test_callback_handler.py index 401d324f3..1232c292b 100644 --- a/tests/unit/credit/test_callback_handler.py +++ b/tests/unit/credit/test_callback_handler.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from pytest import param from aiperf.common.enums import CreditPhase from aiperf.credit.callback_handler import CreditCallbackHandler @@ -62,6 +63,7 @@ def mock_strategy(): """Mock timing strategy.""" mock = MagicMock() mock.handle_credit_return = AsyncMock() + mock.on_failed_credit = MagicMock() return mock @@ -113,12 +115,14 @@ def make_credit_return( credit: Credit, cancelled: bool = False, first_token_sent: bool = True, + error: str | None = None, ) -> CreditReturn: """Create a CreditReturn for testing.""" return CreditReturn( credit=credit, cancelled=cancelled, first_token_sent=first_token_sent, + error=error, ) @@ -169,6 +173,7 @@ async def test_on_credit_return_increments_returned_count( mock_progress.increment_returned.assert_called_once_with( credit.is_final_turn, False, # cancelled=False + agent_depth=0, ) async def test_on_credit_return_tracks_cancelled_status( @@ -183,6 +188,7 @@ async def test_on_credit_return_tracks_cancelled_status( mock_progress.increment_returned.assert_called_once_with( credit.is_final_turn, True, # cancelled=True + agent_depth=0, ) async def test_on_credit_return_releases_session_slot_on_final_turn( @@ -391,9 +397,108 @@ async def test_return_state_combinations( await registered_handler.on_credit_return("worker-1", credit_return) mock_progress.increment_returned.assert_called_once_with( - credit.is_final_turn, cancelled + credit.is_final_turn, cancelled, agent_depth=0 ) if not first_token_sent: mock_concurrency.release_prefill_slot.assert_called_once() else: mock_concurrency.release_prefill_slot.assert_not_called() + + +# ============================================================================= +# Test: 
on_failed_credit Notification +# ============================================================================= + + +class TestOnFailedCreditNotification: + """Tests for strategy.on_failed_credit notification on errored/cancelled returns.""" + + async def test_on_credit_return_error_calls_on_failed_credit( + self, registered_handler, mock_strategy + ) -> None: + """on_failed_credit is called when credit_return has an error.""" + credit = make_credit() + credit_return = make_credit_return(credit, error="connection refused") + + await registered_handler.on_credit_return("worker-1", credit_return) + + mock_strategy.on_failed_credit.assert_called_once_with(credit_return) + + async def test_on_credit_return_cancelled_calls_on_failed_credit( + self, registered_handler, mock_strategy + ) -> None: + """on_failed_credit is called when credit_return is cancelled.""" + credit = make_credit() + credit_return = make_credit_return(credit, cancelled=True) + + await registered_handler.on_credit_return("worker-1", credit_return) + + mock_strategy.on_failed_credit.assert_called_once_with(credit_return) + + async def test_on_credit_return_success_does_not_call_on_failed_credit( + self, registered_handler, mock_strategy + ) -> None: + """on_failed_credit is NOT called for successful returns.""" + credit = make_credit() + credit_return = make_credit_return(credit) # no error, not cancelled + + await registered_handler.on_credit_return("worker-1", credit_return) + + mock_strategy.on_failed_credit.assert_not_called() + + async def test_on_failed_credit_called_before_handle_credit_return( + self, registered_handler, mock_strategy + ) -> None: + """on_failed_credit must be called BEFORE handle_credit_return (ordering).""" + call_order: list[str] = [] + mock_strategy.on_failed_credit.side_effect = lambda cr: call_order.append( + "on_failed_credit" + ) + mock_strategy.handle_credit_return.side_effect = lambda c: call_order.append( + "handle_credit_return" + ) + + credit = make_credit() + 
credit_return = make_credit_return(credit, error="timeout") + + await registered_handler.on_credit_return("worker-1", credit_return) + + assert call_order == ["on_failed_credit", "handle_credit_return"] + + async def test_on_failed_credit_called_regardless_of_can_send_any_turn( + self, registered_handler, mock_strategy, mock_stop_checker + ) -> None: + """on_failed_credit fires even when can_send_any_turn is False.""" + mock_stop_checker.can_send_any_turn.return_value = False + + credit = make_credit() + credit_return = make_credit_return(credit, cancelled=True) + + await registered_handler.on_credit_return("worker-1", credit_return) + + mock_strategy.on_failed_credit.assert_called_once_with(credit_return) + # handle_credit_return should NOT be called (gated by can_send_any_turn) + mock_strategy.handle_credit_return.assert_not_called() + + @pytest.mark.parametrize( + "error,cancelled", + [ + param("connection refused", False, id="error-only"), + param(None, True, id="cancelled-only"), + param("timeout", True, id="error-and-cancelled"), + ], + ) # fmt: skip + async def test_on_failed_credit_called_for_all_failure_combinations( + self, + registered_handler, + mock_strategy, + error: str | None, + cancelled: bool, + ) -> None: + """on_failed_credit fires for any combination of error/cancelled.""" + credit = make_credit() + credit_return = make_credit_return(credit, error=error, cancelled=cancelled) + + await registered_handler.on_credit_return("worker-1", credit_return) + + mock_strategy.on_failed_credit.assert_called_once_with(credit_return) diff --git a/tests/unit/credit/test_child_session_bypass.py b/tests/unit/credit/test_child_session_bypass.py new file mode 100644 index 000000000..e610a413b --- /dev/null +++ b/tests/unit/credit/test_child_session_bypass.py @@ -0,0 +1,729 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Adversarial tests for child session (agent_depth > 0) bypass behavior. + +These tests verify that child sessions spawned by SubagentOrchestrator +correctly skip session slot acquisition, session quota counting, and +session slot release. Without these fixes, children would: + +1. Steal session slots from root conversations → deadlock at concurrency limit +2. Inflate sent_sessions count → premature quota exhaustion → gate hangs +3. Release session slots they never acquired → negative slot counts +""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from aiperf.common.enums import CreditPhase +from aiperf.credit.callback_handler import CreditCallbackHandler +from aiperf.credit.issuer import CreditIssuer +from aiperf.credit.messages import CreditReturn +from aiperf.credit.structs import Credit, TurnToSend +from aiperf.timing.config import CreditPhaseConfig +from aiperf.timing.phase.credit_counter import CreditCounter + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _turn( + conv: str = "conv1", + idx: int = 0, + num: int = 1, + agent_depth: int = 0, + parent_correlation_id: str | None = None, +) -> TurnToSend: + return TurnToSend( + conversation_id=conv, + x_correlation_id=f"corr-{conv}", + turn_index=idx, + num_turns=num, + agent_depth=agent_depth, + parent_correlation_id=parent_correlation_id, + ) + + +def _credit( + credit_id: int = 1, + conv: str = "conv1", + turn_index: int = 0, + num_turns: int = 1, + agent_depth: int = 0, + phase: CreditPhase = CreditPhase.PROFILING, +) -> Credit: + return Credit( + id=credit_id, + phase=phase, + conversation_id=conv, + x_correlation_id=f"corr-{conv}", + turn_index=turn_index, + num_turns=num_turns, + issued_at_ns=time.time_ns(), + agent_depth=agent_depth, + ) + + +def _credit_return( + credit: Credit, + 
cancelled: bool = False, + first_token_sent: bool = True, + error: str | None = None, +) -> CreditReturn: + return CreditReturn( + credit=credit, + cancelled=cancelled, + first_token_sent=first_token_sent, + error=error, + ) + + +def _cfg( + reqs: int | None = None, + sessions: int | None = None, +) -> CreditPhaseConfig: + from aiperf.plugin.enums import TimingMode + + return CreditPhaseConfig( + phase=CreditPhase.PROFILING, + timing_mode=TimingMode.REQUEST_RATE, + total_expected_requests=reqs, + expected_num_sessions=sessions, + ) + + +def _make_issuer(**overrides) -> tuple: + """Create a CreditIssuer with mocked deps, returning (issuer, mocks_dict).""" + mocks = { + "stop_checker": MagicMock( + can_send_any_turn=MagicMock(return_value=True), + can_start_new_session=MagicMock(return_value=True), + ), + "progress": MagicMock( + increment_sent=MagicMock(return_value=(0, False)), + freeze_sent_counts=MagicMock(), + all_credits_sent_event=asyncio.Event(), + ), + "concurrency": MagicMock( + acquire_session_slot=AsyncMock(return_value=True), + acquire_prefill_slot=AsyncMock(return_value=True), + try_acquire_session_slot=MagicMock(return_value=True), + try_acquire_prefill_slot=MagicMock(return_value=True), + release_session_slot=MagicMock(), + ), + "router": MagicMock(send_credit=AsyncMock()), + "cancellation": MagicMock( + next_cancellation_delay_ns=MagicMock(return_value=None) + ), + "lifecycle": MagicMock( + started_at_ns=time.time_ns(), + started_at_perf_ns=time.perf_counter_ns(), + ), + } + mocks.update(overrides) + issuer = CreditIssuer( + phase=CreditPhase.PROFILING, + stop_checker=mocks["stop_checker"], + progress=mocks["progress"], + concurrency_manager=mocks["concurrency"], + credit_router=mocks["router"], + cancellation_policy=mocks["cancellation"], + lifecycle=mocks["lifecycle"], + ) + return issuer, mocks + + +# ============================================================================= +# Bug 1: Children must NOT acquire session slots +# 
============================================================================= + + +class TestChildSkipsSessionSlotAcquisition: + """Without this fix, children steal session slots from root conversations. + + Scenario: concurrency=2, 2 roots running, a child spawns → old code tries + to acquire a 3rd session slot → blocks forever because semaphore is full + → parent gate never fires → deadlock. + """ + + async def test_child_first_turn_skips_session_slot_acquire(self) -> None: + """Child first turn (turn_index=0, agent_depth=1) must NOT acquire session slot.""" + issuer, mocks = _make_issuer() + child_turn = _turn(conv="child-1", idx=0, num=2, agent_depth=1) + + await issuer.issue_credit(child_turn) + + mocks["concurrency"].acquire_session_slot.assert_not_called() + mocks["concurrency"].acquire_prefill_slot.assert_called_once() + mocks["router"].send_credit.assert_called_once() + + async def test_child_first_turn_skips_session_slot_try_acquire(self) -> None: + """try_issue_credit: child first turn must NOT try-acquire session slot.""" + issuer, mocks = _make_issuer() + child_turn = _turn(conv="child-1", idx=0, num=2, agent_depth=1) + + result = await issuer.try_issue_credit(child_turn) + + assert result is True + mocks["concurrency"].try_acquire_session_slot.assert_not_called() + mocks["concurrency"].try_acquire_prefill_slot.assert_called_once() + + async def test_root_first_turn_still_acquires_session_slot(self) -> None: + """Sanity: root first turn (agent_depth=0) must still acquire session slot.""" + issuer, mocks = _make_issuer() + root_turn = _turn(conv="root-1", idx=0, num=3, agent_depth=0) + + await issuer.issue_credit(root_turn) + + mocks["concurrency"].acquire_session_slot.assert_called_once() + + async def test_child_does_not_release_session_on_prefill_failure(self) -> None: + """If child's prefill slot fails, no session slot release (none was acquired).""" + issuer, mocks = _make_issuer() + mocks["concurrency"].acquire_prefill_slot.return_value = False + 
child_turn = _turn(conv="child-1", idx=0, num=2, agent_depth=1) + + result = await issuer.issue_credit(child_turn) + + assert result is False + mocks["concurrency"].release_session_slot.assert_not_called() + + async def test_root_releases_session_on_prefill_failure(self) -> None: + """Sanity: root releases session slot if prefill fails.""" + issuer, mocks = _make_issuer() + mocks["concurrency"].acquire_prefill_slot.return_value = False + root_turn = _turn(conv="root-1", idx=0, num=2, agent_depth=0) + + result = await issuer.issue_credit(root_turn) + + assert result is False + mocks["concurrency"].release_session_slot.assert_called_once() + + @pytest.mark.parametrize("depth", [1, 2, 5]) + async def test_deeply_nested_children_skip_session_slot(self, depth: int) -> None: + """Children at any nesting depth must skip session slot.""" + issuer, mocks = _make_issuer() + child_turn = _turn(conv=f"child-d{depth}", idx=0, num=1, agent_depth=depth) + + await issuer.issue_credit(child_turn) + + mocks["concurrency"].acquire_session_slot.assert_not_called() + + async def test_mixed_root_and_child_interleaved(self) -> None: + """Interleaved root/child issuance: only roots acquire session slots.""" + issuer, mocks = _make_issuer() + credit_idx = [0] + + def inc_sent(turn): + credit_idx[0] += 1 + return (credit_idx[0], False) + + mocks["progress"].increment_sent = inc_sent + + root_t0 = _turn(conv="root", idx=0, num=3, agent_depth=0) + child_t0 = _turn(conv="child-a", idx=0, num=2, agent_depth=1) + root_t1 = _turn(conv="root", idx=1, num=3, agent_depth=0) + child_t1 = _turn(conv="child-a", idx=1, num=2, agent_depth=1) + root_t2 = _turn(conv="root", idx=2, num=3, agent_depth=0) + + for turn in [root_t0, child_t0, root_t1, child_t1, root_t2]: + await issuer.issue_credit(turn) + + # Only the root first turn (root_t0) should have acquired session slot + assert mocks["concurrency"].acquire_session_slot.call_count == 1 + + +# 
============================================================================= +# Bug 2: Children must NOT use can_start_new_session stop check +# ============================================================================= + + +class TestChildUsesCanSendAnyTurn: + """Without this fix, child first turns use can_start_new_session which + checks session quota. When quota is exhausted by root sessions, children + are blocked → parent gates hang waiting for children that can never start. + """ + + async def test_child_first_turn_uses_can_send_any_turn_not_can_start_new_session( + self, + ) -> None: + """Child first turn must use can_send_any_turn for prefill slot check.""" + issuer, mocks = _make_issuer() + # Make session quota exhausted: can_start_new_session=False but can_send_any_turn=True + mocks["stop_checker"].can_start_new_session.return_value = False + mocks["stop_checker"].can_send_any_turn.return_value = True + + child_turn = _turn(conv="child-1", idx=0, num=1, agent_depth=1) + result = await issuer.issue_credit(child_turn) + + # Child must succeed despite session quota being exhausted + assert result is True + mocks["router"].send_credit.assert_called_once() + + async def test_child_first_turn_blocked_by_old_code_would_deadlock(self) -> None: + """Demonstrates the deadlock: session quota full, child can't start. + + Old code used can_start_new_session for child first turns. With + expected_num_sessions reached, can_start_new_session returns False. + The child issuance fails, parent gate never fires → deadlock. 
+ """ + issuer, mocks = _make_issuer() + mocks["stop_checker"].can_start_new_session.return_value = False + mocks["stop_checker"].can_send_any_turn.return_value = True + + child_turn = _turn(conv="child-1", idx=0, num=1, agent_depth=1) + + # try_issue_credit: old code would return False (stop condition), + # new code returns True (child uses can_send_any_turn) + result = await issuer.try_issue_credit(child_turn) + assert result is True + + async def test_root_first_turn_still_uses_can_start_new_session(self) -> None: + """Root first turn must still use can_start_new_session.""" + issuer, mocks = _make_issuer() + root_turn = _turn(conv="root-1", idx=0, num=1, agent_depth=0) + + await issuer.issue_credit(root_turn) + + # Verify can_start_new_session was passed to acquire_session_slot + call_args = mocks["concurrency"].acquire_session_slot.call_args + assert call_args[0][1] == mocks["stop_checker"].can_start_new_session + + async def test_try_issue_child_returns_false_only_when_can_send_any_turn_false( + self, + ) -> None: + """Child try_issue: returns False only when can_send_any_turn is False.""" + issuer, mocks = _make_issuer() + mocks["stop_checker"].can_send_any_turn.return_value = False + mocks["stop_checker"].can_start_new_session.return_value = False + + child_turn = _turn(conv="child-1", idx=0, num=1, agent_depth=1) + result = await issuer.try_issue_credit(child_turn) + + assert result is False # Correctly stopped by can_send_any_turn + + +# ============================================================================= +# Bug 3: Children must NOT inflate session counts in CreditCounter +# ============================================================================= + + +class TestChildSessionCountExclusion: + """Without this fix, children increment sent_sessions which is compared + against expected_num_sessions. 
In non-FIXED_SCHEDULE modes, expected_num_sessions + is user-specified and doesn't account for children → premature quota + exhaustion → is_final_credit fires too early or child issuance fails. + """ + + def test_child_first_turn_does_not_increment_sent_sessions(self) -> None: + """Child first turn must NOT increment sent_sessions.""" + c = CreditCounter(_cfg(sessions=5)) + + c.increment_sent(_turn(conv="root-1", idx=0, num=2, agent_depth=0)) + assert c.sent_sessions == 1 + + c.increment_sent(_turn(conv="child-1", idx=0, num=3, agent_depth=1)) + assert c.sent_sessions == 1 # Still 1, child didn't count + + c.increment_sent(_turn(conv="root-2", idx=0, num=1, agent_depth=0)) + assert c.sent_sessions == 2 # Root incremented + + def test_child_first_turn_does_not_increment_total_session_turns(self) -> None: + """Child first turn must NOT add to total_session_turns.""" + c = CreditCounter(_cfg(sessions=5)) + + c.increment_sent(_turn(conv="root-1", idx=0, num=3, agent_depth=0)) + assert c.total_session_turns == 3 + + c.increment_sent(_turn(conv="child-1", idx=0, num=5, agent_depth=1)) + assert c.total_session_turns == 3 # Child's turns not counted + + def test_premature_quota_exhaustion_prevented(self) -> None: + """With expected_num_sessions=2, 2 roots + 3 children must not hit quota early. + + Old code: sent_sessions would reach 2 after just root-1 + child-1, + triggering is_final_credit prematurely before root-2 is ever sent. 
+ """ + c = CreditCounter(_cfg(sessions=2)) + + # Root 1: 2-turn conversation + c.increment_sent(_turn(conv="root-1", idx=0, num=2, agent_depth=0)) + assert c.sent_sessions == 1 + + # Child spawned by root-1 + _, is_final = c.increment_sent( + _turn(conv="child-1a", idx=0, num=1, agent_depth=1) + ) + assert not is_final # Must NOT be final + assert c.sent_sessions == 1 # Children don't count + + # Another child + _, is_final = c.increment_sent( + _turn(conv="child-1b", idx=0, num=2, agent_depth=1) + ) + assert not is_final + assert c.sent_sessions == 1 + + # Root 1 turn 1 + c.increment_sent(_turn(conv="root-1", idx=1, num=2, agent_depth=0)) + + # Child turns + c.increment_sent(_turn(conv="child-1b", idx=1, num=2, agent_depth=1)) + c.increment_sent(_turn(conv="child-1a", idx=0, num=1, agent_depth=1)) + + # Root 2: 1-turn conversation (this is the 2nd and final session) + _, is_final = c.increment_sent( + _turn(conv="root-2", idx=0, num=1, agent_depth=0) + ) + assert c.sent_sessions == 2 + # Now is_final should be True (2 sessions sent, all turns sent) + assert is_final + + def test_child_completion_does_not_increment_completed_sessions(self) -> None: + """Child final turn return must NOT increment completed_sessions.""" + c = CreditCounter(_cfg()) + + # Send root + child + c.increment_sent(_turn(conv="root-1", idx=0, num=2, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + + # Child completes (final turn, depth=1) + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=1) + assert c.completed_sessions == 0 # Child doesn't count + + # Root non-final turn completes + c.increment_returned(is_final_turn=False, cancelled=False, agent_depth=0) + assert c.completed_sessions == 0 + + # Root final turn completes + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=0) + assert c.completed_sessions == 1 # Only root counts + + def test_child_cancellation_does_not_increment_cancelled_sessions(self) -> 
None: + """Child final turn cancellation must NOT increment cancelled_sessions.""" + c = CreditCounter(_cfg()) + + c.increment_sent(_turn(conv="root-1", idx=0, num=1, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + + # Child cancelled + c.increment_returned(is_final_turn=True, cancelled=True, agent_depth=1) + assert c.cancelled_sessions == 0 + + # Root cancelled + c.increment_returned(is_final_turn=True, cancelled=True, agent_depth=0) + assert c.cancelled_sessions == 1 + + def test_requests_sent_still_counts_all_turns(self) -> None: + """requests_sent must count ALL turns (root + child) for progress tracking.""" + c = CreditCounter(_cfg()) + + c.increment_sent(_turn(conv="root-1", idx=0, num=2, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + c.increment_sent(_turn(conv="root-1", idx=1, num=2, agent_depth=0)) + + assert c.requests_sent == 3 # All turns counted + + def test_requests_completed_still_counts_all_returns(self) -> None: + """requests_completed must count ALL returns (root + child).""" + c = CreditCounter(_cfg()) + + c.increment_sent(_turn(conv="root-1", idx=0, num=1, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=1) + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=0) + + assert c.requests_completed == 2 # Both counted + + def test_in_flight_sessions_excludes_children(self) -> None: + """in_flight_sessions = sent_sessions - completed_sessions - cancelled_sessions. + + Since children are excluded from all three, in_flight_sessions should + reflect only root sessions. 
+ """ + c = CreditCounter(_cfg()) + + # Start 2 roots + 1 child + c.increment_sent(_turn(conv="root-1", idx=0, num=2, agent_depth=0)) + c.increment_sent(_turn(conv="root-2", idx=0, num=1, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + + assert c.in_flight_sessions == 2 # Only roots + + # Child completes + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=1) + assert c.in_flight_sessions == 2 # Unchanged + + # Root-2 completes + c.increment_returned(is_final_turn=True, cancelled=False, agent_depth=0) + assert c.in_flight_sessions == 1 + + def test_freeze_sent_counts_only_reflects_root_sessions(self) -> None: + """Frozen sent_sessions should only count root sessions.""" + c = CreditCounter(_cfg()) + + c.increment_sent(_turn(conv="root-1", idx=0, num=1, agent_depth=0)) + c.increment_sent(_turn(conv="child-1", idx=0, num=1, agent_depth=1)) + c.increment_sent(_turn(conv="child-2", idx=0, num=1, agent_depth=2)) + c.increment_sent(_turn(conv="root-2", idx=0, num=1, agent_depth=0)) + + c.freeze_sent_counts() + + assert c.final_sent_sessions == 2 # Only roots + assert c.final_requests_sent == 4 # All requests + + +# ============================================================================= +# Bug 4: Children must NOT release session slots in callback handler +# ============================================================================= + + +class TestChildSkipsSessionSlotRelease: + """Without this fix, child final turns call release_session_slot even + though they never acquired one. This causes negative semaphore counts + or double-releases that corrupt the session slot pool. 
+ """ + + async def test_child_final_turn_does_not_release_session_slot(self) -> None: + """Child final turn return must NOT release session slot.""" + mock_concurrency = MagicMock( + release_session_slot=MagicMock(), + release_prefill_slot=MagicMock(), + ) + handler = CreditCallbackHandler(mock_concurrency) + handler.register_phase( + phase=CreditPhase.PROFILING, + progress=MagicMock( + increment_returned=MagicMock(return_value=False), + increment_prefill_released=MagicMock(), + in_flight_sessions=0, + all_credits_returned_event=asyncio.Event(), + ), + lifecycle=MagicMock(is_complete=False), + stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)), + strategy=MagicMock(handle_credit_return=AsyncMock()), + ) + + # Child final turn (depth=1) + child_credit = _credit(conv="child-1", turn_index=0, num_turns=1, agent_depth=1) + child_return = _credit_return(child_credit) + + await handler.on_credit_return("worker-1", child_return) + + mock_concurrency.release_session_slot.assert_not_called() + + async def test_root_final_turn_still_releases_session_slot(self) -> None: + """Sanity: root final turn must still release session slot.""" + mock_concurrency = MagicMock( + release_session_slot=MagicMock(), + release_prefill_slot=MagicMock(), + ) + handler = CreditCallbackHandler(mock_concurrency) + handler.register_phase( + phase=CreditPhase.PROFILING, + progress=MagicMock( + increment_returned=MagicMock(return_value=False), + increment_prefill_released=MagicMock(), + in_flight_sessions=0, + all_credits_returned_event=asyncio.Event(), + ), + lifecycle=MagicMock(is_complete=False), + stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)), + strategy=MagicMock(handle_credit_return=AsyncMock()), + ) + + # Root final turn (depth=0) + root_credit = _credit(conv="root-1", turn_index=2, num_turns=3, agent_depth=0) + root_return = _credit_return(root_credit) + + await handler.on_credit_return("worker-1", root_return) + + 
+        mock_concurrency.release_session_slot.assert_called_once_with(
+            CreditPhase.PROFILING
+        )
+
+    async def test_mixed_root_child_returns_only_root_releases_session(self) -> None:
+        """Interleaved root/child returns: only root final turns release session slots."""
+        mock_concurrency = MagicMock(
+            release_session_slot=MagicMock(),
+            release_prefill_slot=MagicMock(),
+        )
+        handler = CreditCallbackHandler(mock_concurrency)
+        handler.register_phase(
+            phase=CreditPhase.PROFILING,
+            progress=MagicMock(
+                increment_returned=MagicMock(return_value=False),
+                increment_prefill_released=MagicMock(),
+                in_flight_sessions=0,
+                all_credits_returned_event=asyncio.Event(),
+            ),
+            lifecycle=MagicMock(is_complete=False),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            strategy=MagicMock(handle_credit_return=AsyncMock()),
+        )
+
+        returns = [
+            # (conv, turn_index, num_turns, agent_depth, should_release_session)
+            ("root-1", 0, 3, 0, False),  # root non-final
+            ("child-a", 0, 1, 1, False),  # child final (depth=1)
+            ("child-b", 0, 2, 2, False),  # child non-final (depth=2)
+            ("child-b", 1, 2, 2, False),  # child final (depth=2)
+            ("root-1", 1, 3, 0, False),  # root non-final
+            ("root-1", 2, 3, 0, True),  # root final → release
+            ("root-2", 0, 1, 0, True),  # root final → release
+        ]
+
+        # Feed all returns through a single handler, without resetting mocks,
+        # so the release count accumulates across the whole sequence.
+        for i, (conv, tidx, nturns, depth, _) in enumerate(returns):
+            credit = _credit(
+                credit_id=i,
+                conv=conv,
+                turn_index=tidx,
+                num_turns=nturns,
+                agent_depth=depth,
+            )
+            await handler.on_credit_return("worker-1", _credit_return(credit))
+
+        # Exactly 2 session slot releases (root-1 final + root-2 final)
+        assert mock_concurrency.release_session_slot.call_count == 2
+
+    async def test_child_return_passes_agent_depth_to_increment_returned(self) -> None:
+        """Callback handler must pass agent_depth to progress.increment_returned."""
+        mock_progress = MagicMock(
+            increment_returned=MagicMock(return_value=False),
+            increment_prefill_released=MagicMock(),
+            in_flight_sessions=0,
+            all_credits_returned_event=asyncio.Event(),
+        )
+        handler = CreditCallbackHandler(MagicMock())
+        handler.register_phase(
+            phase=CreditPhase.PROFILING,
+            progress=mock_progress,
+            lifecycle=MagicMock(is_complete=False),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            strategy=MagicMock(handle_credit_return=AsyncMock()),
+        )
+
+        for depth in [0, 1, 2, 5]:
+            mock_progress.reset_mock()
+            credit = _credit(conv=f"d{depth}", agent_depth=depth)
+            await handler.on_credit_return("worker-1", _credit_return(credit))
+
+            mock_progress.increment_returned.assert_called_once_with(
+                credit.is_final_turn,
+                False,
+                agent_depth=depth,
+            )
+
+
+# 
=============================================================================
+# Integration: Full scenario that would deadlock old code
+# =============================================================================
+
+
+class TestDeadlockScenario:
+    """End-to-end scenario using real CreditCounter that demonstrates the
+    deadlock the old code would produce.
+
+    Setup: expected_num_sessions=2, concurrency=2
+    Trace: root-1 (3 turns) spawns child-a (2 turns) at turn 1, root-2 (1 turn)
+
+    Old behavior:
+    - root-1 turn 0: sent_sessions=1, acquires session slot
+    - child-a turn 0: sent_sessions=2 (BUG), acquires session slot (BUG)
+    - root-2 turn 0: can_start_new_session=False (sessions=2 >= limit=2)
+      AND no session slots available (2/2 used) → DEADLOCK
+    """
+
+    def test_counter_does_not_exhaust_session_quota_with_children(self) -> None:
+        """CreditCounter: children don't count toward session quota."""
+        c = CreditCounter(_cfg(sessions=2))
+
+        # Root 1 starts
+        _, is_final = c.increment_sent(
+            _turn(conv="root-1", idx=0, num=3, agent_depth=0)
+        )
+        assert not is_final
+        assert c.sent_sessions == 1
+
+        # Child spawned by root-1
+        _, is_final = c.increment_sent(
+            _turn(conv="child-a", idx=0, num=2, agent_depth=1)
+        )
+        assert not is_final
+        assert c.sent_sessions == 1  # CRITICAL: still 1
+
+        # Root 1 turn 1
+        c.increment_sent(_turn(conv="root-1", idx=1, num=3, agent_depth=0))
+
+        # Child turn 1
+        c.increment_sent(_turn(conv="child-a", idx=1, num=2, agent_depth=1))
+
+        # Root 2 starts -- this is the 2nd and final session
+        _, is_final = c.increment_sent(
+            _turn(conv="root-2", idx=0, num=1, agent_depth=0)
+        )
+        assert c.sent_sessions == 2
+
+        # Root 1 turn 2 (final turn of root-1)
+        _, is_final = c.increment_sent(
+            _turn(conv="root-1", idx=2, num=3, agent_depth=0)
+        )
+        # Now: 2 sessions started, total_session_turns=3+1=4, requests_sent=6
+        # is_final should be True (sessions >= 2 AND requests >= total_turns)
+        assert is_final
+
+    async def test_issuer_allows_child_when_session_quota_full(self) -> None:
+        """CreditIssuer: child issuance succeeds even when session quota is full."""
+        issuer, mocks = _make_issuer()
+
+        # Simulate: session quota is full
+        mocks["stop_checker"].can_start_new_session.return_value = False
+        mocks["stop_checker"].can_send_any_turn.return_value = True
+
+        # Child must still be issuable
+        child_turn = _turn(conv="child-a", idx=0, num=2, agent_depth=1)
+        result = await issuer.issue_credit(child_turn)
+
+        assert result is True
+        mocks["concurrency"].acquire_session_slot.assert_not_called()
+        mocks["router"].send_credit.assert_called_once()
+
+        # Verify the credit has correct agent_depth
+        sent_credit = mocks["router"].send_credit.call_args.kwargs["credit"]
+        assert sent_credit.agent_depth == 1
diff --git a/tests/unit/dataset/loader/test_conflux.py b/tests/unit/dataset/loader/test_conflux.py
new file mode 100644
index 000000000..cc446644f
--- /dev/null
+++ b/tests/unit/dataset/loader/test_conflux.py
@@ -0,0 +1,2489 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for ConfluxLoader."""
+
+import base64
+
+import orjson
+import pytest
+
+from aiperf.common.config import (
+    EndpointConfig,
+    InputConfig,
+    InputTokensConfig,
+    PromptConfig,
+    UserConfig,
+)
+from aiperf.common.enums import ConversationContextMode, PrerequisiteKind
+from aiperf.dataset.loader.conflux import (
+    ConfluxLoader,
+    _build_spawn_tuid_to_agent_id,
+    _detect_join_turn_from_content,
+    _extract_notification_joins,
+    _find_join_turn_index,
+    _iter_message_blocks,
+    _new_messages,
+    _parse_timestamp_ms,
+    _record_end_ms,
+    _stringify_block_content,
+)
+from aiperf.dataset.loader.models import ConfluxRecord
+from aiperf.plugin.enums import DatasetSamplingStrategy
+
+# =========================================================================
+# Test data builders
+# =========================================================================
+
+BASE_TS = "2026-02-25T02:02:00.000Z"
+PARENT_TOOLS = [{"name": f"tool_{i}", "description": f"Tool {i}"} for i in range(10)]
+CHILD_TOOLS = [
+    {"name": f"sub_tool_{i}", "description": f"Sub Tool {i}"} for i in range(5)
+]
+
+
+def _ts(offset_s: float) -> str:
+    """Generate ISO timestamp with offset from base."""
+    from datetime import datetime, timedelta, timezone
+
+    base = datetime(2026, 2, 25, 2, 2, 0, tzinfo=timezone.utc)
+    dt = base + timedelta(seconds=offset_s)
+    return dt.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
+
+
+_UNSET = object()
+
+
+def _make_record(
+    *,
+    record_id: str = "req_001",
+    agent_id: str | None = "claude",
+    is_subagent: bool = False,
+    timestamp: str = BASE_TS,
+    model: str = "claude-sonnet-4-6",
+    messages: list | object = _UNSET,
+    tools: list | object = _UNSET,
+    tokens: dict | object = _UNSET,
+    hyperparameters: dict | None | object = _UNSET,
+    is_streaming: bool | None = True,
+    duration_ms: int = 1000,
+) -> dict:
+    """Build a raw Conflux record dict."""
+    record: dict = {
+        "id": record_id,
+        "session_id": "sess-001",
+        "agent_id": agent_id,
+        "is_subagent": is_subagent,
+        "timestamp": timestamp,
+        "duration_ms": duration_ms,
+        "model": model,
+        "tokens": (
+            {
+                "input": 1000,
+                "input_cached": 800,
+                "input_cache_write": 100,
+                "output": 200,
+            }
+            if tokens is _UNSET
+            else tokens
+        ),
+        "tools": [] if tools is _UNSET else tools,
+        "messages": (
+            [{"role": "user", "content": "Hello"}] if messages is _UNSET else messages
+        ),
+        "output": [{"type": "text", "text": "Hi there"}],
+        "is_streaming": is_streaming,
+        "ttft_ms": 150,
+    }
+    if hyperparameters is not _UNSET:
+        record["hyperparameters"] = hyperparameters
+    else:
+        record["hyperparameters"] = {"max_tokens": 4096}
+    return record
+
+
+def _build_session_file(tmp_path, records: list[dict]) -> str:
+    """Write records to a JSON file and return the path."""
+    path = tmp_path / "session.json"
+    path.write_bytes(orjson.dumps(records))
+    return str(path)
+
+
+def _build_team_session(tmp_path) -> str:
+    """Build a session with parent (5 turns) + 2 subagents (3 turns each)."""
+    records = []
+
+    # Parent: 5 turns
+    for i in range(5):
+        records.append(
+            _make_record(
+                record_id=f"req_parent_{i:03d}",
+                agent_id="claude",
+                is_subagent=False,
+                timestamp=_ts(i * 5),
+                model="claude-opus-4-6",
+                messages=[{"role": "user", "content": f"parent turn {i}"}],
+                tools=PARENT_TOOLS,
+                tokens={
+                    "input": 1000 + i * 500,
+                    "input_cached": 800,
+                    "input_cache_write": 100,
+                    "output": 200 + i * 50,
+                },
+                hyperparameters={"max_tokens": 32000},
+            )
+        )
+
+    # Subagent A: 3 turns, spawned around parent turn 2
+    for i in range(3):
+        records.append(
+            _make_record(
+                record_id=f"req_sub_a_{i:03d}",
+                agent_id="sub_a",
+                is_subagent=True,
+                timestamp=_ts(10 + i * 3),
+                model="claude-opus-4-6",
+                messages=[{"role": "user", "content": f"sub_a turn {i}"}],
+                tools=CHILD_TOOLS,
+                tokens={
+                    "input": 500 + i * 200,
+                    "input_cached": 400,
+                    "input_cache_write": 50,
+                    "output": 100,
+                },
+                hyperparameters={"max_tokens": 16000},
+            )
+        )
+
+    # Subagent B: 3 turns, spawned around parent turn 3
+    for i in range(3):
+        records.append(
+            _make_record(
+                record_id=f"req_sub_b_{i:03d}",
+                agent_id="sub_b",
+                is_subagent=True,
+                timestamp=_ts(15 + i * 2),
+                model="claude-haiku-4-5-20251001",
+                messages=[{"role": "user", "content": f"sub_b turn {i}"}],
+                tools=[],
+                tokens={
+                    "input": 200,
+                    "input_cached": 150,
+                    "input_cache_write": 20,
+                    "output": 50,
+                },
+            )
+        )
+
+    # Add a record with agent_id=None (should be filtered out)
+    records.append(
+        _make_record(
+            record_id="req_haiku_tool",
+            agent_id=None,
+            timestamp=_ts(8),
+            model="claude-haiku-4-5-20251001",
+            messages=[{"role": "user", "content": "tool check"}],
+        )
+    )
+
+    return _build_session_file(tmp_path, records)
+
+
+def _build_delayed_join_session(tmp_path, *, with_content_signals: bool) -> str:
+    """Build a parent/child session where the join happens at parent turn 3."""
+    if with_content_signals:
+        turn0_messages = [{"role": "user", "content": "parent turn 0"}]
+        agent_tool_use = {
+            "type": "tool_use",
+            "id": "toolu_spawn_1",
+            "name": "Agent",
+            "input": {"task": "inspect auth flow"},
+        }
+        queued_result = {
+            "type": "tool_result",
+            "tool_use_id": "toolu_spawn_1",
+            "content": "queued for running",
+        }
+        full_result = {
+            "type": "tool_result",
+            "tool_use_id": "toolu_spawn_1",
+            "content": "child finished and found the root cause",
+        }
+        turn1_messages = turn0_messages + [
+            {"role": "assistant", "content": [agent_tool_use]},
+            {"role": "user", "content": [queued_result]},
+            {"role": "user", "content": "parent turn 1"},
+        ]
+        turn2_messages = turn1_messages + [
+            {"role": "assistant", "content": "doing unrelated work"},
+            {"role": "user", "content": "parent turn 2"},
+        ]
+        turn3_messages = turn2_messages + [
+            {"role": "assistant", "content": "ready to use child output"},
+            {"role": "user", "content": [full_result]},
+            {"role": "user", "content": "parent turn 3"},
+        ]
+        parent_messages = [
+            turn0_messages,
+            turn1_messages,
+            turn2_messages,
+            turn3_messages,
+        ]
+    else:
+        parent_messages = [
+            [{"role": "user", "content": f"parent turn {i}"}] for i in range(4)
+        ]
+
+    records = []
+    parent_offsets = [0, 2, 4, 8]
+    for i, offset in enumerate(parent_offsets):
+        records.append(
+            _make_record(
+                record_id=f"req_parent_delayed_{i:03d}",
+                agent_id="claude",
+                is_subagent=False,
+                timestamp=_ts(offset),
+                messages=parent_messages[i],
+                tools=PARENT_TOOLS,
+                duration_ms=1000,
+            )
+        )
+
+    child_offsets = [0.5, 3, 6]
+    for i, offset in enumerate(child_offsets):
+        records.append(
+            _make_record(
+                record_id=f"req_child_delayed_{i:03d}",
+                agent_id="sub_delayed",
+                is_subagent=True,
+                timestamp=_ts(offset),
+                messages=[{"role": "user", "content": f"child turn {i}"}],
+                tools=CHILD_TOOLS,
+                duration_ms=1000,
+            )
+        )
+
+    return _build_session_file(tmp_path, records)
+
+
+# =========================================================================
+# Model tests
+# =========================================================================
+
+
+class TestConfluxRecord:
+    def test_create(self):
+        record = ConfluxRecord.model_validate(_make_record())
+        assert record.id == "req_001"
+        assert record.agent_id == "claude"
+        assert not record.is_subagent
+        assert record.tokens is not None
+        assert record.tokens.input == 1000
+
+    def test_defaults(self):
+        raw = {
+            "id": "req_min",
+            "session_id": "sess",
+            "timestamp": BASE_TS,
+            "messages": [{"role": "user", "content": "hi"}],
+        }
+        record = ConfluxRecord.model_validate(raw)
+        assert record.agent_id is None
+        assert record.is_subagent is None
+        assert record.source is None
+        assert record.client is None
+        assert record.provider is None
+        assert record.completed_at is None
+        assert record.client_version is None
+        assert record.request_id is None
+        assert record.tokens is None
+        assert record.tools == []
+        assert record.output == []
+        assert record.hyperparameters is None
+        assert record.is_streaming is None
+
+    def test_unified_fields_roundtrip(self):
+        """Fields from the Conflux unified schema are preserved."""
+        raw = _make_record()
+        raw["source"] = "proxy"
+        raw["client"] = "claude"
+        raw["provider"] = "anthropic"
+        raw["completed_at"] = "2026-02-25T02:02:01.000Z"
+        raw["client_version"] = "1.2.3"
+        raw["request_id"] = "req_abc123"
+        record = ConfluxRecord.model_validate(raw)
+        assert record.source == "proxy"
+        assert record.client == "claude"
+        assert record.provider == "anthropic"
+        assert record.completed_at == "2026-02-25T02:02:01.000Z"
+        assert record.client_version == "1.2.3"
+        assert record.request_id == "req_abc123"
+
+
+# =========================================================================
+# can_load tests
+# =========================================================================
+
+
+class TestCanLoad:
+    def test_valid_json_file(self, tmp_path):
+        path = _build_session_file(tmp_path, [_make_record()])
+        assert ConfluxLoader.can_load(filename=path)
+
+    def test_not_json_extension(self, tmp_path):
+        path = tmp_path / "data.txt"
+        path.write_bytes(orjson.dumps([_make_record()]))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_empty_array(self, tmp_path):
+        path = tmp_path / "empty.json"
+        path.write_bytes(orjson.dumps([]))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_missing_agent_id_key(self, tmp_path):
+        path = tmp_path / "no_agent.json"
+        path.write_bytes(orjson.dumps([{"is_subagent": False, "messages": []}]))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_missing_is_subagent_key(self, tmp_path):
+        path = tmp_path / "no_sub.json"
+        path.write_bytes(orjson.dumps([{"agent_id": "x", "messages": []}]))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_missing_messages_key(self, tmp_path):
+        path = tmp_path / "no_msg.json"
+        path.write_bytes(orjson.dumps([{"agent_id": "x", "is_subagent": False}]))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_not_a_list(self, tmp_path):
+        path = tmp_path / "obj.json"
+        path.write_bytes(orjson.dumps({"agent_id": "x"}))
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_none_filename(self):
+        assert not ConfluxLoader.can_load(filename=None)
+
+    def test_directory(self, tmp_path):
+        assert not ConfluxLoader.can_load(filename=str(tmp_path))
+
+    def test_nonexistent_file(self):
+        assert not ConfluxLoader.can_load(filename="/nonexistent/file.json")
+
+    def test_preferred_sampling_strategy(self):
+        assert (
+            ConfluxLoader.get_preferred_sampling_strategy()
+            == DatasetSamplingStrategy.SEQUENTIAL
+        )
+
+    def test_large_file_byte_probe_fallback(self, tmp_path, monkeypatch):
+        """When file exceeds probe limit, byte-level detection is used."""
+        from aiperf.dataset.loader import conflux as conflux_mod
+
+        # 300 bytes is enough to contain signature fields in the first record
+        # but truncates the 2KB+ file, forcing the byte-level fallback path
+        monkeypatch.setattr(conflux_mod, "_CAN_LOAD_PROBE_BYTES", 300)
+
+        records = [_make_record() for _ in range(5)]
+        path = _build_session_file(tmp_path, records)
+        assert ConfluxLoader.can_load(filename=path)
+
+    def test_large_file_byte_probe_rejects_non_array(self, tmp_path, monkeypatch):
+        """Byte-level fallback rejects files that don't start with '['."""
+        from aiperf.dataset.loader import conflux as conflux_mod
+
+        monkeypatch.setattr(conflux_mod, "_CAN_LOAD_PROBE_BYTES", 8)
+
+        # Content is longer than 8 bytes but not a JSON array
+        path = tmp_path / "obj.json"
+        path.write_bytes(
+            b'{"agent_id": "x", "is_subagent": false, "messages": [],'
+            b' "padding": "' + b"x" * 100 + b'"}'
+        )
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+    def test_large_file_byte_probe_rejects_missing_fields(self, tmp_path, monkeypatch):
+        """Byte-level fallback rejects files without Conflux signature fields."""
+        from aiperf.dataset.loader import conflux as conflux_mod
+
+        monkeypatch.setattr(conflux_mod, "_CAN_LOAD_PROBE_BYTES", 8)
+
+        # Array that's long enough to truncate but has no Conflux fields
+        path = tmp_path / "other.json"
+        path.write_bytes(b'[{"type": "single_turn", "text": "' + b"x" * 100 + b'"}]')
+        assert not ConfluxLoader.can_load(filename=str(path))
+
+
+# =========================================================================
+# load_dataset tests
+# =========================================================================
+
+
+class TestLoadDataset:
+    @pytest.fixture
+    def default_user_config(self):
+        return UserConfig(
+            endpoint=EndpointConfig(model_names=["test-model"]),
+            input=InputConfig.model_construct(
+                prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)),
+                conflux_include_utility_calls=True,
+            ),
+        )
+
+    @pytest.fixture
+    def team_session(self, tmp_path):
+        return _build_team_session(tmp_path)
+
+    def test_group_count(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        assert len(data) == 4  # parent + 2 subagents + 1 orphan
+
+    def test_orphan_as_separate_group(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        total = sum(len(recs) for recs in data.values())
+        assert total == 12  # 5 + 3 + 3 + 1 orphan
+
+    def test_parent_has_5_records(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        assert len(data["claude"]) == 5
+
+    def test_subagent_a_has_3_records(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        assert len(data["sub_a"]) == 3
+
+    def test_sorted_by_timestamp(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        for records in data.values():
+            timestamps = [_parse_timestamp_ms(r.timestamp) for r in records]
+            assert timestamps == sorted(timestamps)
+
+    def test_single_record(self, tmp_path, default_user_config):
+        path = _build_session_file(tmp_path, [_make_record()])
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        assert len(data) == 1
+        assert len(data["claude"]) == 1
+
+
+# =========================================================================
+# convert_to_conversations tests
+# =========================================================================
+
+
+class TestConvertToConversations:
+    @pytest.fixture
+    def default_user_config(self):
+        return UserConfig(
+            endpoint=EndpointConfig(model_names=["test-model"]),
+            input=InputConfig.model_construct(
+                prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)),
+                conflux_include_utility_calls=True,
+            ),
+        )
+
+    @pytest.fixture
+    def team_session(self, tmp_path):
+        return _build_team_session(tmp_path)
+
+    def test_conversation_count(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        assert len(conversations) == 4  # parent + 2 children + 1 orphan child
+
+    def test_parent_turn_count(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.agent_depth == 0
+        assert len(parent.turns) == 5
+
+    def test_children_marked(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        children = [c for c in conversations if c.agent_depth > 0]
+        assert len(children) == 3  # 2 explicit subagents + 1 orphan
+
+    def test_every_turn_has_raw_tools(self, team_session, default_user_config):
+        """Every turn should have raw_tools set from its record's tools."""
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        for turn in parent.turns:
+            assert turn.raw_tools is not None
+
+    def test_fallback_uses_raw_messages(self, team_session, default_user_config):
+        """Without base64, turns use raw_messages for provider-agnostic replay."""
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        for turn in parent.turns:
+            assert turn.raw_messages is not None
+
+    def test_fallback_max_tokens_on_turn(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.turns[0].max_tokens == 32000
+
+    def test_fallback_raw_tools_present(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.turns[0].raw_tools is not None
+        assert len(parent.turns[0].raw_tools) == 10
+
+    def test_turn_model(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.turns[0].model == "claude-opus-4-6"
+
+    def test_turn_input_tokens(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.turns[0].input_tokens == 1000
+
+    def test_turn_has_absolute_timestamps(self, team_session, default_user_config):
+        """Turns use absolute timestamps, not relative delays."""
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert parent.turns[0].timestamp is not None
+        assert parent.turns[1].timestamp is not None
+        assert parent.turns[1].timestamp > parent.turns[0].timestamp
+        assert parent.turns[0].delay is None
+
+    def test_timestamp_spacing(self, team_session, default_user_config):
+        """Parent records are 5s apart, timestamps should reflect that."""
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        delta = parent.turns[1].timestamp - parent.turns[0].timestamp
+        assert delta == pytest.approx(5000, abs=1)
+
+    def test_raw_messages_present(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        for turn in parent.turns:
+            assert turn.raw_messages is not None
+            assert len(turn.raw_messages) >= 1
+
+    def test_subagent_spawns(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        assert len(parent.subagent_spawns) == 3  # 2 explicit + 1 orphan
+
+    def test_subagent_spawn_blocking_detection(self, team_session, default_user_config):
+        """Explicit subagent children are detected as blocking or background
+        based on timestamp analysis; orphans are always background."""
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        # Explicit children have gap_ms <= 2000 -> blocking
+        # Orphan is always background
+        explicit_spawns = [
+            s for s in parent.subagent_spawns if len(s.child_conversation_ids) >= 1
+        ]
+        orphan_spawns = [
+            s
+            for s in parent.subagent_spawns
+            if any("orphan" in cid for cid in s.child_conversation_ids)
+        ]
+        non_orphan_spawns = [s for s in explicit_spawns if s not in orphan_spawns]
+        for spawn in non_orphan_spawns:
+            assert spawn.is_background is False
+        for spawn in orphan_spawns:
+            assert spawn.is_background is True
+
+    def test_children_have_parent_conversation_id(
+        self, team_session, default_user_config
+    ):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        children = [c for c in conversations if c.agent_depth > 0]
+        assert len(children) >= 1
+        for child in children:
+            assert child.parent_conversation_id == parent.session_id
+
+    def test_subagent_spawn_ids_on_turns(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+        all_spawn_ids = [sid for t in parent.turns for sid in t.subagent_spawn_ids]
+        assert len(all_spawn_ids) == 3  # 2 explicit + 1 orphan
+        for sid in all_spawn_ids:
+            assert sid.startswith("s")
+
+    def test_delayed_join_prerequisite_inferred_from_content(
+        self, tmp_path, default_user_config
+    ):
+        path = _build_delayed_join_session(tmp_path, with_content_signals=True)
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+
+        blocking_spawns = [
+            spawn for spawn in parent.subagent_spawns if not spawn.is_background
+        ]
+        assert len(blocking_spawns) == 1
+        spawn_id = blocking_spawns[0].spawn_id
+
+        assert parent.turns[0].subagent_spawn_ids == [spawn_id]
+        assert parent.turns[1].prerequisites == []
+        assert parent.turns[2].prerequisites == []
+        assert len(parent.turns[3].prerequisites) == 1
+        assert parent.turns[3].prerequisites[0].kind == PrerequisiteKind.SPAWN_JOIN
+        assert parent.turns[3].prerequisites[0].spawn_id == spawn_id
+
+    def test_delayed_join_prerequisite_inferred_from_timing_fallback(
+        self, tmp_path, default_user_config
+    ):
+        path = _build_delayed_join_session(tmp_path, with_content_signals=False)
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        parent = conversations[0]
+
+        blocking_spawns = [
+            spawn for spawn in parent.subagent_spawns if not spawn.is_background
+        ]
+        assert len(blocking_spawns) == 1
+        spawn_id = blocking_spawns[0].spawn_id
+
+        assert parent.turns[0].subagent_spawn_ids == [spawn_id]
+        assert parent.turns[1].prerequisites == []
+        assert parent.turns[2].prerequisites == []
+        assert len(parent.turns[3].prerequisites) == 1
+        assert parent.turns[3].prerequisites[0].kind == PrerequisiteKind.SPAWN_JOIN
+        assert parent.turns[3].prerequisites[0].spawn_id == spawn_id
+
+    def test_empty_data(self, tmp_path, default_user_config):
+        """Empty dict produces no conversations."""
+        path = _build_session_file(tmp_path, [_make_record()])
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        conversations = loader.convert_to_conversations({})
+        assert conversations == []
+
+    def test_session_id_prefix(self, team_session, default_user_config):
+        loader = ConfluxLoader(filename=team_session, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        for conv in conversations:
+            assert conv.session_id.startswith("conflux_")
+
+    def test_no_hyperparameters_leaves_max_tokens_none(
+        self, tmp_path, default_user_config
+    ):
+        """When hyperparameters is None, max_tokens is None (server default)."""
+        records = [_make_record(hyperparameters=None)]
+        path = _build_session_file(tmp_path, records)
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        assert conversations[0].turns[0].max_tokens is None
+
+    def test_fallback_uses_raw_messages_single_record(
+        self, tmp_path, default_user_config
+    ):
+        """Without base64, fallback path uses raw_messages for a single record."""
+        records = [_make_record(is_streaming=None)]
+        path = _build_session_file(tmp_path, records)
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+        assert conversations[0].turns[0].raw_messages is not None
+
+
+# =========================================================================
+# _parse_timestamp_ms tests
+# =========================================================================
+
+
+class TestParseTimestampMs:
+    def test_utc_timestamp(self):
+        ms = _parse_timestamp_ms("2026-02-25T02:02:00.000Z")
+        assert ms > 0
+
+    def test_millisecond_precision(self):
+        ms1 = _parse_timestamp_ms("2026-02-25T02:02:00.000Z")
+        ms2 = _parse_timestamp_ms("2026-02-25T02:02:01.000Z")
+        assert ms2 - ms1 == pytest.approx(1000, abs=1)
+
+    def test_fractional_seconds(self):
+        ms1 = _parse_timestamp_ms("2026-02-25T02:02:00.000Z")
+        ms2 = _parse_timestamp_ms("2026-02-25T02:02:00.500Z")
+        assert ms2 - ms1 == pytest.approx(500, abs=1)
+
+
+# =========================================================================
+# _build_raw_payload tests (base64 vs fallback)
+# =========================================================================
+
+
+def _b64_encode(obj: dict) -> str:
+    """Base64-encode a dict as JSON."""
+    return base64.b64encode(orjson.dumps(obj)).decode()
+
+
+class TestExtractRecordFields:
+    """Tests for ConfluxLoader._extract_record_fields."""
+
+    def test_base64_extracts_messages_with_system_inline(self):
+        """Base64 path inlines system into messages array."""
+        payload = {
+            "messages": [{"role": "user", "content": "hello"}],
+            "system": [{"type": "text", "text": "Be helpful"}],
+            "tools": [{"name": "Bash"}],
+            "max_tokens": 32000,
+        }
+        raw = _make_record()
+        raw["base64"] = {"request_body": _b64_encode(payload)}
+        record = ConfluxRecord.model_validate(raw)
+        messages, tools, max_tokens = ConfluxLoader._extract_record_fields(record)
+        assert messages[0]["role"] == "system"
+        assert messages[1]["role"] == "user"
+        assert tools == [{"name": "Bash"}]
+        assert max_tokens == 32000
+
+    def test_base64_strips_metadata(self):
+        payload = {
+            "messages": [{"role": "user", "content": "hi"}],
+            "metadata": {"user_id": "secret"},
+        }
+        raw = _make_record()
+        raw["base64"] = {"request_body": _b64_encode(payload)}
+        record = ConfluxRecord.model_validate(raw)
+        messages, _, _ = ConfluxLoader._extract_record_fields(record)
+        assert len(messages) == 1
+
+    def test_fallback_uses_top_level_fields(self):
+        raw = _make_record(
+            messages=[
+                {"role": "system", "content": "sys"},
+                {"role": "user", "content": "hi"},
+            ],
+            tools=[{"name": "read_file", "input_schema": {}}],
+            hyperparameters={"max_tokens": 8192},
+        )
+        record = ConfluxRecord.model_validate(raw)
+        messages, tools, max_tokens = ConfluxLoader._extract_record_fields(record)
+        assert len(messages) == 2
+        assert tools == [{"name": "read_file", "input_schema": {}}]
+        assert max_tokens == 8192
+
+    def test_fallback_no_tools(self):
+        raw = _make_record(tools=[])
+        record = ConfluxRecord.model_validate(raw)
+        _, tools, _ = ConfluxLoader._extract_record_fields(record)
+        assert tools is None
+
+
+class TestBase64Integration:
+    """Integration tests: base64 payloads flow through to conversation turns."""
+
+    @pytest.fixture
+    def default_user_config(self):
+        return UserConfig(
+            endpoint=EndpointConfig(model_names=["test-model"]),
+            input=InputConfig.model_construct(
+                prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)),
+            ),
+        )
+
+    def test_base64_normalized_to_raw_messages(self, tmp_path, default_user_config):
+        """Base64 records are normalized to raw_messages (not raw_payload)."""
+        ground_truth = {
+            "model": "claude-sonnet-4-6",
+            "messages": [{"role": "user", "content": "hello"}],
+            "system": [{"type": "text", "text": "System"}],
+            "tools": [{"name": "Bash", "description": "run", "input_schema": {}}],
+            "max_tokens": 32000,
+            "stream": True,
+            "thinking": {"type": "adaptive"},
+        }
+        raw = _make_record(
+            messages=[
+                {"role": "system", "content": [{"type": "text", "text": "System"}]},
+                {"role": "user", "content": "hello"},
+            ],
+            tools=[{"name": "Bash", "description": "run", "input_schema": {}}],
+        )
+        raw["base64"] = {"request_body": _b64_encode(ground_truth)}
+        path = _build_session_file(tmp_path, [raw])
+
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+
+        turn = conversations[0].turns[0]
+        assert turn.raw_messages is not None
+        # System message should be normalized to inline
+        assert any(m["role"] == "system" for m in turn.raw_messages)
+        # Tools should be normalized to OpenAI format
+        assert turn.raw_tools is not None
+
+    def test_base64_max_tokens_propagates(self, tmp_path, default_user_config):
+        """max_tokens from base64 payload sets Turn.max_tokens."""
+        ground_truth = {
+            "model": "claude-sonnet-4-6",
+            "messages": [{"role": "user", "content": "hello"}],
+            "max_tokens": 16000,
+            "stream": True,
+        }
+        raw = _make_record()
+        raw["base64"] = {"request_body": _b64_encode(ground_truth)}
+        path = _build_session_file(tmp_path, [raw])
+
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+
+        assert conversations[0].turns[0].max_tokens == 16000
+
+
+# =========================================================================
+# Provider detection tests
+# =========================================================================
+
+
+class TestDetectConfluxProvider:
+    """Tests for ConfluxLoader._detect_conflux_provider."""
+
+    def test_client_claude(self):
+        raw = _make_record()
+        raw["client"] = "claude"
+        record = ConfluxRecord.model_validate(raw)
+        assert ConfluxLoader._detect_conflux_provider(record) == "anthropic"
+
+    def test_client_codex(self):
+        raw = _make_record()
+        raw["client"] = "codex"
+        record = ConfluxRecord.model_validate(raw)
+        assert ConfluxLoader._detect_conflux_provider(record) == "openai"
+
+    def test_provider_field_takes_precedence(self):
+        raw = _make_record()
+        raw["client"] = "codex"
+        raw["provider"] = "anthropic"
+        record = ConfluxRecord.model_validate(raw)
+        assert ConfluxLoader._detect_conflux_provider(record) == "anthropic"
+
+    def test_provider_field_openai(self):
+        raw = _make_record()
+        raw["provider"] = "OpenAI"
+        record = ConfluxRecord.model_validate(raw)
+        assert ConfluxLoader._detect_conflux_provider(record) == "openai"
+
+    def test_no_hints(self):
+        raw = _make_record()
+        record = ConfluxRecord.model_validate(raw)
+        assert ConfluxLoader._detect_conflux_provider(record) is None
+
+
+# =========================================================================
+# Speedup ratio tests
+# =========================================================================
+
+
+# =========================================================================
+# Spawn point overlap tests
+# =========================================================================
+
+
+class TestFindSpawnPoint:
+    """Tests for ConfluxLoader._find_spawn_point."""
+
+    def test_overlap_via_completed_at(self):
+        """Child spawned during parent turn 1's in-flight period."""
+        parent = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p0",
+                    timestamp=_ts(0),
+                    is_subagent=False,
+                    duration_ms=3000,
+                )
+                | {"completed_at": _ts(3)}
+            ),
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p1",
+                    timestamp=_ts(5),
+                    is_subagent=False,
+                    duration_ms=8000,
+                )
+                | {"completed_at": _ts(13)}
+            ),
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p2",
+                    timestamp=_ts(15),
+                    is_subagent=False,
+                    duration_ms=2000,
+                )
+                | {"completed_at": _ts(17)}
+            ),
+        ]
+        # Child spawned at t=7, which is during parent turn 1 (t=5 to t=13)
+        child = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="c0",
+                    timestamp=_ts(7),
+                    is_subagent=True,
+                )
+            ),
+        ]
+        assert ConfluxLoader._find_spawn_point(parent, child) == 1
+
+    def test_overlap_via_duration_ms(self):
+        """Uses duration_ms when completed_at is absent."""
+        parent = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p0",
+                    timestamp=_ts(0),
+                    is_subagent=False,
+                    duration_ms=3000,
+                )
+            ),
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p1",
+                    timestamp=_ts(5),
+                    is_subagent=False,
+                    duration_ms=10000,
+                )
+            ),
+        ]
+        child = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="c0",
+                    timestamp=_ts(8),
+                    is_subagent=True,
+                )
+            ),
+        ]
+        assert ConfluxLoader._find_spawn_point(parent, child) == 1
+
+    def test_falls_back_to_closest_timestamp(self):
+        """No overlap data, falls back to closest timestamp."""
+        parent = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p0",
+                    timestamp=_ts(0),
+                    is_subagent=False,
+                    duration_ms=0,
+                )
+            ),
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="p1",
+                    timestamp=_ts(10),
+                    is_subagent=False,
+                    duration_ms=0,
+                )
+            ),
+        ]
+        child = [
+            ConfluxRecord.model_validate(
+                _make_record(
+                    record_id="c0",
+                    timestamp=_ts(8),
+                    is_subagent=True,
+                )
+            ),
+        ]
+        assert ConfluxLoader._find_spawn_point(parent, child) == 1
+
+
+# =========================================================================
+# Un-enriched data (is_subagent=None) tests
+# =========================================================================
+
+
+class TestUnEnrichedData:
+    """Tests for handling un-enriched proxy data where is_subagent is None."""
+
+    @pytest.fixture
+    def default_user_config(self):
+        return UserConfig(
+            endpoint=EndpointConfig(model_names=["test-model"]),
+            input=InputConfig.model_construct(
+                prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)),
+            ),
+        )
+
+    def test_unenriched_elects_largest_group_as_parent(
+        self, tmp_path, default_user_config
+    ):
+        """When all groups have is_subagent=None, largest becomes parent."""
+        records = []
+        # Agent A: 5 turns (should become parent)
+        for i in range(5):
+            raw = _make_record(
+                record_id=f"a_{i}",
+                agent_id="agent_a",
+                timestamp=_ts(i * 5),
+            )
+            del raw["is_subagent"]
+            records.append(raw)
+        # Agent B: 2 turns (should become child)
+        for i in range(2):
+            raw = _make_record(
+                record_id=f"b_{i}",
+                agent_id="agent_b",
+                timestamp=_ts(3 + i * 3),
+            )
+            del raw["is_subagent"]
+            records.append(raw)
+
+        path = _build_session_file(tmp_path, records)
+        loader = ConfluxLoader(filename=path, user_config=default_user_config)
+        data = loader.load_dataset()
+        conversations = loader.convert_to_conversations(data)
+
+        assert len(conversations) == 2
+        parent = conversations[0]
+        assert parent.agent_depth == 0
+        assert len(parent.turns) == 5
+        child = conversations[1]
+        assert child.agent_depth == 1
+        assert len(child.turns) == 2
+
+    def test_proxy_source_can_load(self, tmp_path):
+        """Records with source=proxy but no is_subagent key are loadable."""
+        raw = _make_record()
+        del raw["is_subagent"]
+        del raw["agent_id"]
+        raw["source"] = "proxy"
+        path = _build_session_file(tmp_path, [raw])
+        assert ConfluxLoader.can_load(filename=path)
+
+
+# =========================================================================
+# Propagated fields tests (extra_params, ground_truth, origin)
+# =========================================================================
+
+
+class TestPropagatedFields:
+    """Tests for extra_params, ground_truth, and origin population."""
+
+    @pytest.fixture
+    def default_user_config(self):
+        return UserConfig(
+            endpoint=EndpointConfig(model_names=["test-model"]),
+            input=InputConfig.model_construct(
+                prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)),
+            ),
+        )
+
+    def _load_conversations(self, tmp_path, records, user_config):
+        path = _build_session_file(tmp_path, records)
+        loader = ConfluxLoader(filename=path, user_config=user_config)
+        data = loader.load_dataset()
+        return loader.convert_to_conversations(data)
+
+    def test_extra_params_from_hyperparameters(self, tmp_path, default_user_config):
+        """Hyperparameters beyond max_tokens populate extra_params."""
+        records = [
+            _make_record(
+                hyperparameters={
+                    "max_tokens": 4096,
+                    "temperature": 0.7,
+                    "top_p": 0.9,
+                    "seed": 42,
+                }
+            )
+        ]
+        convs = self._load_conversations(tmp_path, records, default_user_config)
+        turn = convs[0].turns[0]
+        assert turn.extra_params is not None
+        assert turn.extra_params["temperature"] == 0.7
+        assert turn.extra_params["top_p"] == 0.9
+        assert turn.extra_params["seed"] == 42
+        assert "max_tokens" not in turn.extra_params
+
+    def test_extra_params_none_when_only_max_tokens(
+        self, tmp_path, default_user_config
+    ):
+        """extra_params is None when hyperparameters only has max_tokens."""
+        records = [_make_record(hyperparameters={"max_tokens": 4096})]
+        convs = self._load_conversations(tmp_path, records, default_user_config)
+        assert convs[0].turns[0].extra_params is None
+
+    def test_extra_params_none_when_no_hyperparameters(
+        self, tmp_path, default_user_config
+    ):
+        """extra_params is None when hyperparameters is absent."""
+        records = [_make_record(hyperparameters=None)]
+        convs =
self._load_conversations(tmp_path, records, default_user_config) + assert convs[0].turns[0].extra_params is None + + def test_ground_truth_from_tokens_and_timing(self, tmp_path, default_user_config): + """Token breakdown, timing, and output populate ground_truth.""" + records = [ + _make_record( + tokens={ + "input": 1000, + "input_cached": 800, + "input_cache_write": 100, + "output": 200, + "output_reasoning": 50, + }, + is_streaming=True, + duration_ms=1500, + ) + ] + convs = self._load_conversations(tmp_path, records, default_user_config) + gt = convs[0].turns[0].ground_truth + assert gt is not None + assert gt.input_cached_tokens == 800 + assert gt.input_cache_write_tokens == 100 + assert gt.output_tokens == 200 + assert gt.output_reasoning_tokens == 50 + assert gt.ttft_ms == 150 # from _make_record default + assert gt.duration_ms == 1500 + assert gt.is_streaming is True + + def test_ground_truth_none_when_no_detail(self, tmp_path, default_user_config): + """ground_truth is None when no token detail, timing, or output.""" + raw = _make_record( + tokens={ + "input": 100, + "input_cached": 0, + "input_cache_write": 0, + "output": 0, + }, + is_streaming=None, + duration_ms=0, + ) + raw["ttft_ms"] = None + raw["output"] = [] + records = [raw] + convs = self._load_conversations(tmp_path, records, default_user_config) + assert convs[0].turns[0].ground_truth is None + + def test_origin_from_record_metadata(self, tmp_path, default_user_config): + """Provenance populated from first record's source/client/session fields.""" + raw = _make_record() + raw["source"] = "proxy" + raw["client"] = "claude" + raw["client_version"] = "1.2.3" + raw["request_id"] = "req_abc" + records = [raw] + convs = self._load_conversations(tmp_path, records, default_user_config) + prov = convs[0].origin + assert prov is not None + assert prov.source == "proxy" + assert prov.client == "claude" + assert prov.client_version == "1.2.3" + assert prov.original_session_id == "sess-001" + assert 
prov.original_request_ids == ["req_abc"] + + def test_origin_collects_all_request_ids(self, tmp_path, default_user_config): + """All request_ids across turns are collected in origin.""" + records = [ + _make_record(record_id="r0", timestamp=_ts(0)) | {"request_id": "req_0"}, + _make_record(record_id="r1", timestamp=_ts(5)) | {"request_id": "req_1"}, + _make_record(record_id="r2", timestamp=_ts(10)) | {"request_id": "req_2"}, + ] + convs = self._load_conversations(tmp_path, records, default_user_config) + assert convs[0].origin.original_request_ids == [ + "req_0", + "req_1", + "req_2", + ] + + def test_origin_skips_null_request_ids(self, tmp_path, default_user_config): + """request_ids that are None are excluded from origin.""" + records = [ + _make_record(record_id="r0", timestamp=_ts(0)), + _make_record(record_id="r1", timestamp=_ts(5)) | {"request_id": "req_1"}, + ] + convs = self._load_conversations(tmp_path, records, default_user_config) + assert convs[0].origin.original_request_ids == ["req_1"] + + +# ========================================================================= +# get_default_context_mode tests +# ========================================================================= + + +class TestDefaultContextMode: + def test_returns_message_array_with_responses(self) -> None: + assert ( + ConfluxLoader.get_default_context_mode() + == ConversationContextMode.MESSAGE_ARRAY_WITH_RESPONSES + ) + + +# ========================================================================= +# _new_messages helper tests +# ========================================================================= + + +class TestNewMessages: + def test_identical_prefix_returns_appended(self) -> None: + prev = [{"role": "user", "content": "a"}] + curr = [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}] + result = _new_messages(prev, curr) + assert result == [{"role": "assistant", "content": "b"}] + + def test_no_common_prefix(self) -> None: + prev = [{"role": "user", 
"content": "x"}] + curr = [{"role": "user", "content": "y"}] + result = _new_messages(prev, curr) + assert result == [{"role": "user", "content": "y"}] + + def test_empty_previous(self) -> None: + curr = [{"role": "user", "content": "a"}] + result = _new_messages([], curr) + assert result == curr + + def test_empty_current(self) -> None: + prev = [{"role": "user", "content": "a"}] + result = _new_messages(prev, []) + assert result == [] + + def test_both_empty(self) -> None: + assert _new_messages([], []) == [] + + def test_full_prefix_match(self) -> None: + msgs = [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}] + result = _new_messages(msgs, list(msgs)) + assert result == [] + + +# ========================================================================= +# _iter_message_blocks helper tests +# ========================================================================= + + +class TestIterMessageBlocks: + def test_extracts_dict_blocks_from_list_content(self) -> None: + messages = [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "t1", "name": "Agent"}, + {"type": "text", "text": "hello"}, + ], + }, + ] + blocks = _iter_message_blocks(messages) + assert len(blocks) == 2 + assert blocks[0]["type"] == "tool_use" + + def test_skips_string_content(self) -> None: + messages = [{"role": "user", "content": "just a string"}] + blocks = _iter_message_blocks(messages) + assert blocks == [] + + def test_skips_non_dict_items_in_list(self) -> None: + messages = [{"role": "user", "content": ["string_item", 42, {"type": "text"}]}] + blocks = _iter_message_blocks(messages) + assert len(blocks) == 1 + assert blocks[0]["type"] == "text" + + def test_empty_messages(self) -> None: + assert _iter_message_blocks([]) == [] + + +# ========================================================================= +# _stringify_block_content helper tests +# ========================================================================= + + +class 
TestStringifyBlockContent: + def test_string_passthrough(self) -> None: + assert _stringify_block_content("hello") == "hello" + + def test_list_joins(self) -> None: + result = _stringify_block_content(["a", "b"]) + assert result == "a b" + + def test_dict_with_text_key(self) -> None: + assert _stringify_block_content({"text": "found it"}) == "found it" + + def test_dict_with_content_key_recurses(self) -> None: + nested = {"content": [{"text": "inner"}]} + result = _stringify_block_content(nested) + assert "inner" in result + + def test_dict_fallback_to_json(self) -> None: + result = _stringify_block_content({"key": "value"}) + assert "key" in result + assert "value" in result + + def test_non_standard_type(self) -> None: + assert _stringify_block_content(42) == "42" + + def test_nested_list_with_dicts(self) -> None: + data = [{"text": "a"}, "b"] + result = _stringify_block_content(data) + assert "a" in result + assert "b" in result + + +# ========================================================================= +# _record_end_ms helper tests +# ========================================================================= + + +class TestRecordEndMs: + def test_uses_completed_at_when_available(self) -> None: + record = ConfluxRecord.model_validate( + _make_record(timestamp=_ts(0), duration_ms=5000) | {"completed_at": _ts(10)} + ) + end = _record_end_ms(record) + start = _parse_timestamp_ms(_ts(0)) + expected = _parse_timestamp_ms(_ts(10)) + assert end == expected + assert end != start + 5000 + + def test_uses_duration_ms_when_no_completed_at(self) -> None: + record = ConfluxRecord.model_validate( + _make_record(timestamp=_ts(0), duration_ms=3000) + ) + end = _record_end_ms(record) + start = _parse_timestamp_ms(_ts(0)) + assert end == pytest.approx(start + 3000, abs=1) + + def test_returns_start_when_no_timing_data(self) -> None: + record = ConfluxRecord.model_validate( + _make_record(timestamp=_ts(5), duration_ms=0) + ) + end = _record_end_ms(record) + start = 
_parse_timestamp_ms(_ts(5)) + assert end == start + + +# ========================================================================= +# _detect_join_turn_from_content tests +# ========================================================================= + + +def _make_content_records( + *, + spawn_tool_use_id: str = "toolu_spawn_x", + queued_text: str = "queued for running", + result_text: str = "child finished", + result_at_turn: int = 3, + num_turns: int = 4, + background_only: bool = False, +) -> list[ConfluxRecord]: + """Build parent records with Agent tool_use / tool_result signals.""" + turn0_msgs = [{"role": "user", "content": "turn 0"}] + agent_block = { + "type": "tool_use", + "id": spawn_tool_use_id, + "name": "Agent", + "input": {"task": "do something"}, + } + queued_block = { + "type": "tool_result", + "tool_use_id": spawn_tool_use_id, + "content": queued_text, + } + result_block = { + "type": "tool_result", + "tool_use_id": spawn_tool_use_id, + "content": result_text, + } + + all_messages: list[list[dict]] = [turn0_msgs] + for i in range(1, num_turns): + prev = all_messages[i - 1] + if i == 1: + new = prev + [ + {"role": "assistant", "content": [agent_block]}, + {"role": "user", "content": [queued_block]}, + {"role": "user", "content": f"turn {i}"}, + ] + elif not background_only and i == result_at_turn: + new = prev + [ + {"role": "assistant", "content": f"work at turn {i}"}, + {"role": "user", "content": [result_block]}, + {"role": "user", "content": f"turn {i}"}, + ] + else: + new = prev + [ + {"role": "assistant", "content": f"work at turn {i}"}, + {"role": "user", "content": f"turn {i}"}, + ] + all_messages.append(new) + + records = [] + for i, msgs in enumerate(all_messages): + records.append( + ConfluxRecord.model_validate( + _make_record( + record_id=f"p{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + messages=msgs, + duration_ms=2000, + ) + ) + ) + return records + + +class TestDetectJoinTurnFromContent: + def 
test_finds_join_at_result_turn(self) -> None: + records = _make_content_records(result_at_turn=3) + join_idx, saw_bg = _detect_join_turn_from_content(records, spawn_turn_index=0) + assert join_idx == 3 + assert saw_bg is False + + def test_background_signal_when_only_queued(self) -> None: + records = _make_content_records(background_only=True) + join_idx, saw_bg = _detect_join_turn_from_content(records, spawn_turn_index=0) + assert join_idx is None + assert saw_bg is True + + def test_spawn_at_last_turn_returns_none(self) -> None: + records = _make_content_records(num_turns=2) + join_idx, saw_bg = _detect_join_turn_from_content(records, spawn_turn_index=1) + assert join_idx is None + assert saw_bg is False + + def test_ignores_existing_agent_ids(self) -> None: + """Pre-existing Agent tool_use IDs in the spawn turn's history are not treated as new spawns.""" + old_agent_block = { + "type": "tool_use", + "id": "toolu_old", + "name": "Agent", + "input": {"task": "old task"}, + } + turn0_msgs = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": [old_agent_block]}, + ] + # Turn 1 re-surfaces the old tool_use (common-prefix diff edge case) + turn1_msgs = turn0_msgs + [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_old", + "content": "old result", + }, + ], + }, + {"role": "user", "content": "turn 1"}, + ] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ] + join_idx, saw_bg = _detect_join_turn_from_content(records, spawn_turn_index=0) + assert join_idx is None + assert saw_bg is False + + +# ========================================================================= +# _find_join_turn_index tests +# 
========================================================================= + + +class TestFindJoinTurnIndex: + def test_content_based_join_preferred(self) -> None: + """Content-based detection takes priority over timing.""" + records = _make_content_records(result_at_turn=3, num_turns=5) + child_records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="c0", + agent_id="sub", + is_subagent=True, + timestamp=_ts(1), + duration_ms=1000, + ) + ) + ] + from aiperf.common.models import Conversation + + child_conv = Conversation(session_id="conflux_sub") + children = [("sub", child_records, child_conv)] + result = _find_join_turn_index(records, 0, children) + assert result == 3 + + def test_timing_fallback_when_no_content(self) -> None: + """Falls back to timing when no content signals exist.""" + parent_records = [ + ConfluxRecord.model_validate( + _make_record( + record_id=f"p{i}", + timestamp=_ts(i * 10), + duration_ms=2000, + ) + | {"completed_at": _ts(i * 10 + 2)} + ) + for i in range(4) + ] + # Child completes at t=15 + child_records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="c0", + agent_id="sub", + is_subagent=True, + timestamp=_ts(1), + duration_ms=14000, + ) + | {"completed_at": _ts(15)} + ) + ] + from aiperf.common.models import Conversation + + child_conv = Conversation(session_id="conflux_sub") + children = [("sub", child_records, child_conv)] + # Parent turn 2 starts at t=20, which is after child ends at t=15 + result = _find_join_turn_index(parent_records, 0, children) + assert result == 2 + + def test_returns_none_when_background_signal(self) -> None: + """Returns None when content signals indicate background-only spawn.""" + records = _make_content_records(background_only=True, num_turns=4) + child_records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="c0", + agent_id="sub", + is_subagent=True, + timestamp=_ts(1), + duration_ms=500, + ) + ) + ] + from aiperf.common.models import Conversation + + 
child_conv = Conversation(session_id="conflux_sub") + children = [("sub", child_records, child_conv)] + result = _find_join_turn_index(records, 0, children) + assert result is None + + +# ========================================================================= +# _extract_notification_joins tests +# ========================================================================= + + +class TestExtractNotificationJoins: + def test_finds_task_notification(self) -> None: + turn0_msgs = [{"role": "user", "content": "start"}] + notification_text = ( + "Some preamble " + "toolu_abc123" + " more text" + ) + turn1_msgs = turn0_msgs + [ + {"role": "assistant", "content": "working"}, + {"role": "user", "content": notification_text}, + ] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ] + joins = _extract_notification_joins(records) + assert joins == {"toolu_abc123": 1} + + def test_no_notifications(self) -> None: + records = [ + ConfluxRecord.model_validate( + _make_record(record_id="p0", timestamp=_ts(0), duration_ms=1000) + ), + ] + joins = _extract_notification_joins(records) + assert joins == {} + + def test_multiple_notifications_first_wins(self) -> None: + """First occurrence of a tool_use_id is kept.""" + turn0_msgs = [{"role": "user", "content": "start"}] + notif = ( + "toolu_x" + ) + turn1_msgs = turn0_msgs + [{"role": "user", "content": notif}] + turn2_msgs = turn1_msgs + [{"role": "user", "content": notif}] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + 
_make_record( + record_id="p2", + timestamp=_ts(10), + messages=turn2_msgs, + duration_ms=1000, + ) + ), + ] + joins = _extract_notification_joins(records) + assert joins["toolu_x"] == 1 + + def test_notification_in_list_content_block(self) -> None: + """Handles notifications inside structured content blocks.""" + turn0_msgs = [{"role": "user", "content": "start"}] + notif_block = { + "type": "text", + "text": "toolu_y", + } + turn1_msgs = turn0_msgs + [{"role": "user", "content": [notif_block]}] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ] + joins = _extract_notification_joins(records) + assert joins == {"toolu_y": 1} + + +# ========================================================================= +# _build_spawn_tuid_to_agent_id tests +# ========================================================================= + + +class TestBuildSpawnTuidToAgentId: + def test_maps_async_agent_launched(self) -> None: + turn0_msgs = [{"role": "user", "content": "start"}] + agent_block = { + "type": "tool_use", + "id": "toolu_new_1", + "name": "Agent", + "input": {"task": "do work"}, + } + result_block = { + "type": "tool_result", + "tool_use_id": "toolu_new_1", + "content": "Async agent launched, agentId: agent_abc", + } + turn1_msgs = turn0_msgs + [ + {"role": "assistant", "content": [agent_block]}, + {"role": "user", "content": [result_block]}, + ] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ] + mapping = _build_spawn_tuid_to_agent_id(records, spawn_turn_index=0) + assert mapping == 
{"toolu_new_1": "agent_abc"} + + def test_ignores_pre_existing_agent_ids(self) -> None: + old_agent = { + "type": "tool_use", + "id": "toolu_old", + "name": "Agent", + "input": {"task": "old"}, + } + turn0_msgs = [ + {"role": "user", "content": "start"}, + {"role": "assistant", "content": [old_agent]}, + ] + # Turn 1: old agent completes (not a new spawn) + result_block = { + "type": "tool_result", + "tool_use_id": "toolu_old", + "content": "Async agent launched, agentId: old_agent", + } + turn1_msgs = turn0_msgs + [ + {"role": "user", "content": [result_block]}, + ] + records = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + messages=turn0_msgs, + duration_ms=1000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(5), + messages=turn1_msgs, + duration_ms=1000, + ) + ), + ] + mapping = _build_spawn_tuid_to_agent_id(records, spawn_turn_index=0) + assert mapping == {} + + def test_spawn_at_last_turn_returns_empty(self) -> None: + records = [ + ConfluxRecord.model_validate( + _make_record(record_id="p0", timestamp=_ts(0), duration_ms=1000) + ), + ] + mapping = _build_spawn_tuid_to_agent_id(records, spawn_turn_index=0) + assert mapping == {} + + +# ========================================================================= +# Orphan filtering tests +# ========================================================================= + + +class TestOrphanFiltering: + @pytest.fixture + def no_orphan_config(self): + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)), + conflux_include_utility_calls=False, + ), + ) + + def test_orphans_excluded_when_disabled(self, tmp_path, no_orphan_config) -> None: + """Orphan records are filtered out when conflux_include_utility_calls=False.""" + records = [ + _make_record( + record_id="p0", agent_id="claude", is_subagent=False, 
timestamp=_ts(0) + ), + _make_record(record_id="orphan0", agent_id=None, timestamp=_ts(5)), + ] + path = _build_session_file(tmp_path, records) + loader = ConfluxLoader(filename=path, user_config=no_orphan_config) + data = loader.load_dataset() + assert len(data) == 1 + assert "claude" in data + + def test_orphans_excluded_produce_no_child_conversations( + self, tmp_path, no_orphan_config + ) -> None: + records = [ + _make_record( + record_id="p0", agent_id="claude", is_subagent=False, timestamp=_ts(0) + ), + _make_record(record_id="orphan0", agent_id=None, timestamp=_ts(5)), + ] + path = _build_session_file(tmp_path, records) + loader = ConfluxLoader(filename=path, user_config=no_orphan_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + assert len(conversations) == 1 + assert conversations[0].agent_depth == 0 + assert conversations[0].subagent_spawns == [] + + +# ========================================================================= +# _find_spawn_point tier 2 (post-completion gap) tests +# ========================================================================= + + +class TestFindSpawnPointPostCompletionGap: + def test_gap_via_completed_at(self) -> None: + """Child spawned in gap between turn 0 completing and turn 1 starting.""" + parent = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + is_subagent=False, + duration_ms=2000, + ) + | {"completed_at": _ts(2)} + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(10), + is_subagent=False, + duration_ms=2000, + ) + | {"completed_at": _ts(12)} + ), + ] + # Child at t=3 is in gap (t=2 to t=10) + child = [ + ConfluxRecord.model_validate( + _make_record(record_id="c0", timestamp=_ts(3), is_subagent=True) + ) + ] + assert ConfluxLoader._find_spawn_point(parent, child) == 0 + + def test_gap_via_duration_ms(self) -> None: + """Gap detection using duration_ms when completed_at is absent.""" + parent = 
[ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + is_subagent=False, + duration_ms=2000, + ) + ), + ConfluxRecord.model_validate( + _make_record( + record_id="p1", + timestamp=_ts(10), + is_subagent=False, + duration_ms=2000, + ) + ), + ] + child = [ + ConfluxRecord.model_validate( + _make_record(record_id="c0", timestamp=_ts(5), is_subagent=True) + ) + ] + assert ConfluxLoader._find_spawn_point(parent, child) == 0 + + def test_gap_after_last_parent_turn(self) -> None: + """Child spawned after the last parent turn completes (open-ended gap).""" + parent = [ + ConfluxRecord.model_validate( + _make_record( + record_id="p0", + timestamp=_ts(0), + is_subagent=False, + duration_ms=2000, + ) + | {"completed_at": _ts(2)} + ), + ] + child = [ + ConfluxRecord.model_validate( + _make_record(record_id="c0", timestamp=_ts(5), is_subagent=True) + ) + ] + assert ConfluxLoader._find_spawn_point(parent, child) == 0 + + +# ========================================================================= +# Notification-based join splitting integration test +# ========================================================================= + + +def _build_notification_join_session(tmp_path) -> str: + """Build a session where async spawns complete via notification messages.""" + agent_tool_use_1 = { + "type": "tool_use", + "id": "toolu_async_1", + "name": "Agent", + "input": {"task": "research auth"}, + } + agent_tool_use_2 = { + "type": "tool_use", + "id": "toolu_async_2", + "name": "Agent", + "input": {"task": "research db"}, + } + async_result_1 = { + "type": "tool_result", + "tool_use_id": "toolu_async_1", + "content": "Async agent launched, agentId: sub_a", + } + async_result_2 = { + "type": "tool_result", + "tool_use_id": "toolu_async_2", + "content": "Async agent launched, agentId: sub_b", + } + + turn0_msgs = [{"role": "user", "content": "turn 0"}] + turn1_msgs = turn0_msgs + [ + {"role": "assistant", "content": [agent_tool_use_1, agent_tool_use_2]}, + {"role": "user",
"content": [async_result_1, async_result_2]}, + {"role": "user", "content": "turn 1"}, + ] + notif_a = ( + "" + "toolu_async_1" + "" + ) + turn2_msgs = turn1_msgs + [ + {"role": "assistant", "content": "working"}, + {"role": "user", "content": notif_a}, + {"role": "user", "content": "turn 2"}, + ] + notif_b = ( + "" + "toolu_async_2" + "" + ) + turn3_msgs = turn2_msgs + [ + {"role": "assistant", "content": "more work"}, + {"role": "user", "content": notif_b}, + {"role": "user", "content": "turn 3"}, + ] + + records = [] + parent_msgs = [turn0_msgs, turn1_msgs, turn2_msgs, turn3_msgs] + for i, msgs in enumerate(parent_msgs): + records.append( + _make_record( + record_id=f"p{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + messages=msgs, + duration_ms=2000, + ) + ) + + # Child A: spawned at t=1 + records.append( + _make_record( + record_id="ca0", + agent_id="sub_a", + is_subagent=True, + timestamp=_ts(1), + duration_ms=8000, + ) + ) + + # Child B: spawned at t=2 + records.append( + _make_record( + record_id="cb0", + agent_id="sub_b", + is_subagent=True, + timestamp=_ts(2), + duration_ms=12000, + ) + ) + + return _build_session_file(tmp_path, records) + + +class TestNotificationJoinIntegration: + @pytest.fixture + def default_user_config(self): + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)), + ), + ) + + def test_notification_splits_into_per_child_blocking_spawns( + self, tmp_path, default_user_config + ) -> None: + """Async spawns with notification joins produce per-child blocking spawns.""" + path = _build_notification_join_session(tmp_path) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + + parent = conversations[0] + # Two children => two separate blocking spawns (not one grouped background) + blocking = [s for s in
parent.subagent_spawns if not s.is_background] + assert len(blocking) == 2 + for spawn in blocking: + assert len(spawn.child_conversation_ids) == 1 + + def test_notification_join_prerequisites_on_correct_turns( + self, tmp_path, default_user_config + ) -> None: + """Each notification-based join creates a prerequisite on the notification turn.""" + path = _build_notification_join_session(tmp_path) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + + parent = conversations[0] + # Turn 0: spawn_ids should be set + assert len(parent.turns[0].subagent_spawn_ids) == 2 + # Turn 2: sub_a notification -> prerequisite + prereqs_2 = parent.turns[2].prerequisites + assert len(prereqs_2) == 1 + assert prereqs_2[0].kind == PrerequisiteKind.SPAWN_JOIN + # Turn 3: sub_b notification -> prerequisite + prereqs_3 = parent.turns[3].prerequisites + assert len(prereqs_3) == 1 + assert prereqs_3[0].kind == PrerequisiteKind.SPAWN_JOIN + # The two spawns should reference different spawn_ids + assert prereqs_2[0].spawn_id != prereqs_3[0].spawn_id + + +class TestZeroAlignTimestamps: + """Tests for zero-aligning timestamps in convert_to_conversations.""" + + @pytest.fixture() + def default_user_config(self): + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)), + ), + ) + + def test_single_conversation_timestamps_start_at_zero( + self, tmp_path, default_user_config + ) -> None: + """All timestamps shift so the earliest becomes 0.""" + records = [ + _make_record( + record_id=f"req_{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 10), + ) + for i in range(4) + ] + path = _build_session_file(tmp_path, records) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = 
loader.convert_to_conversations(data) + + timestamps = [t.timestamp for t in conversations[0].turns] + assert timestamps[0] == 0.0 + assert timestamps[1] == 10_000.0 + assert timestamps[2] == 20_000.0 + assert timestamps[3] == 30_000.0 + + def test_parent_and_children_all_zero_aligned( + self, tmp_path, default_user_config + ) -> None: + """Parent and subagent timestamps are shifted by the same global minimum.""" + path = _build_team_session(tmp_path) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + + all_timestamps = [ + t.timestamp + for c in conversations + for t in c.turns + if t.timestamp is not None + ] + assert min(all_timestamps) == 0.0 + # Parent turn 0 is the global minimum (_ts(0)), so it should be 0 + parent = conversations[0] + assert parent.turns[0].timestamp == 0.0 + + def test_relative_spacing_preserved(self, tmp_path, default_user_config) -> None: + """Inter-turn gaps are identical before and after alignment.""" + records = [ + _make_record( + record_id=f"req_{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 7), + ) + for i in range(3) + ] + path = _build_session_file(tmp_path, records) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + + ts = [t.timestamp for t in conversations[0].turns] + gaps = [ts[i + 1] - ts[i] for i in range(len(ts) - 1)] + assert gaps == [7_000.0, 7_000.0] + + def test_child_earlier_than_parent_becomes_zero( + self, tmp_path, default_user_config + ) -> None: + """When a child timestamp precedes the parent, the child becomes 0.""" + records = [ + _make_record( + record_id="req_parent_0", + agent_id="claude", + is_subagent=False, + timestamp=_ts(10), + duration_ms=2000, + ), + _make_record( + record_id="req_child_0", + agent_id="sub_a", + is_subagent=True, + timestamp=_ts(5), + ), + ] + 
path = _build_session_file(tmp_path, records) + loader = ConfluxLoader(filename=path, user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + + all_timestamps = [ + t.timestamp + for c in conversations + for t in c.turns + if t.timestamp is not None + ] + assert min(all_timestamps) == 0.0 + # Child at _ts(5) is earliest -> becomes 0; parent at _ts(10) -> 5000 + child = next(c for c in conversations if c.agent_depth == 1) + parent = next(c for c in conversations if c.agent_depth == 0) + assert child.turns[0].timestamp == 0.0 + assert parent.turns[0].timestamp == 5_000.0 + + def test_already_zero_aligned_is_noop(self) -> None: + """If earliest timestamp is already 0, nothing changes.""" + from aiperf.common.models import Conversation, Turn + + conv = Conversation(session_id="test") + conv.turns = [ + Turn(role="user", timestamp=0.0, max_tokens=100), + Turn(role="user", timestamp=5_000.0, max_tokens=100), + ] + ConfluxLoader._zero_align_timestamps([conv]) + assert conv.turns[0].timestamp == 0.0 + assert conv.turns[1].timestamp == 5_000.0 + + def test_empty_conversations_no_error(self) -> None: + """Empty conversation list does not raise.""" + ConfluxLoader._zero_align_timestamps([]) + + def test_single_turn_becomes_zero(self) -> None: + """A single-turn session normalizes to timestamp 0.""" + from aiperf.common.models import Conversation, Turn + + conv = Conversation(session_id="test") + conv.turns = [Turn(role="user", timestamp=999_999.0, max_tokens=100)] + ConfluxLoader._zero_align_timestamps([conv]) + assert conv.turns[0].timestamp == 0.0 + + +# ========================================================================= +# Directory loading tests +# ========================================================================= + + +class TestDirectoryCanLoad: + """Test can_load with directories.""" + + def test_directory_with_conflux_json(self, tmp_path): + """Directory containing a valid Conflux JSON file 
is accepted.""" + f = tmp_path / "session1.json" + f.write_bytes(orjson.dumps([_make_record()])) + assert ConfluxLoader.can_load(filename=str(tmp_path)) + + def test_empty_directory(self, tmp_path): + """Directory with no JSON files is rejected.""" + assert not ConfluxLoader.can_load(filename=str(tmp_path)) + + def test_directory_with_non_conflux_json(self, tmp_path): + """Directory with non-Conflux JSON files is rejected.""" + f = tmp_path / "other.json" + f.write_bytes(orjson.dumps({"key": "value"})) + assert not ConfluxLoader.can_load(filename=str(tmp_path)) + + def test_directory_with_mixed_files(self, tmp_path): + """Directory with one valid and one non-JSON file is accepted.""" + (tmp_path / "session.json").write_bytes(orjson.dumps([_make_record()])) + (tmp_path / "readme.txt").write_text("not json") + assert ConfluxLoader.can_load(filename=str(tmp_path)) + + +class TestDirectoryLoadDataset: + """Test load_dataset and convert_to_conversations with directory input. + + Each file in a directory is an independent session (separate capture). + Agent IDs are prefixed with ``f{index}_`` (e.g. ``f0_claude``) to avoid + cross-file collisions, and each file is zero-aligned independently. 
+ """ + + @pytest.fixture + def default_user_config(self): + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + prompt=PromptConfig(input_tokens=InputTokensConfig(mean=100)), + conflux_include_utility_calls=True, + ), + ) + + def _write_session(self, tmp_path, filename, records): + (tmp_path / filename).write_bytes(orjson.dumps(records)) + + def test_loads_records_from_multiple_files(self, tmp_path, default_user_config): + """Each file is an independent session with prefixed agent_ids.""" + # File 0: session with 3 turns + self._write_session( + tmp_path, + "session_a.json", + [ + _make_record( + record_id=f"req_a{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + ) + for i in range(3) + ], + ) + # File 1: separate session with 2 turns (same agent_id, different file) + self._write_session( + tmp_path, + "session_b.json", + [ + _make_record( + record_id=f"req_b{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(100 + i * 5), + ) + for i in range(2) + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + assert "f0_claude" in data + assert "f1_claude" in data + assert len(data["f0_claude"]) == 3 + assert len(data["f1_claude"]) == 2 + + def test_converts_directory_to_independent_conversations( + self, tmp_path, default_user_config + ): + """Each file produces its own parent+children conversations.""" + # File 0: parent + child session + self._write_session( + tmp_path, + "session_a.json", + [ + _make_record( + record_id=f"req_p{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + ) + for i in range(3) + ] + + [ + _make_record( + record_id=f"req_c{i}", + agent_id="sub_a", + is_subagent=True, + timestamp=_ts(5 + i * 3), + ) + for i in range(2) + ], + ) + # File 1: standalone session + self._write_session( + tmp_path, + "session_b.json", + [ + _make_record( + record_id=f"req_x{i}", + 
agent_id="claude", + is_subagent=False, + timestamp=_ts(200 + i * 5), + ) + for i in range(2) + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + # File 0: parent + child = 2 conversations + # File 1: standalone parent = 1 conversation + assert len(conversations) == 3 + + def test_per_file_zero_alignment(self, tmp_path, default_user_config): + """Each file's timestamps are zero-aligned independently.""" + # File 0: starts at t=1000s + self._write_session( + tmp_path, + "early.json", + [ + _make_record( + record_id=f"req_e{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(1000 + i * 10), + ) + for i in range(3) + ], + ) + # File 1: starts at t=5000s (completely different time origin) + self._write_session( + tmp_path, + "late.json", + [ + _make_record( + record_id=f"req_l{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(5000 + i * 10), + ) + for i in range(3) + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + assert len(conversations) == 2 + + # Both files should start at timestamp 0 independently + for conv in conversations: + assert conv.turns[0].timestamp == 0.0 + + def test_empty_directory_raises(self, tmp_path, default_user_config): + """Loading from an empty directory raises FileNotFoundError.""" + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + with pytest.raises(FileNotFoundError, match="No .json files found"): + loader.load_dataset() + + def test_skips_non_json_files(self, tmp_path, default_user_config): + """Non-JSON files in the directory are ignored.""" + self._write_session( + tmp_path, + "session.json", + [ + _make_record( + record_id=f"req_{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + ) + for i in range(2) + 
], + ) + (tmp_path / "notes.txt").write_text("not json") + (tmp_path / "data.csv").write_text("a,b,c") + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + assert len(data["f0_claude"]) == 2 + + def test_single_file_directory(self, tmp_path, default_user_config): + """Directory with one file behaves like a single session.""" + self._write_session( + tmp_path, + "all.json", + [ + _make_record( + record_id=f"req_{i}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 5), + ) + for i in range(4) + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + assert len(data["f0_claude"]) == 4 + + def test_same_agent_ids_across_files_no_collision( + self, tmp_path, default_user_config + ): + """Two files with identical agent_ids produce separate conversations.""" + for i, name in enumerate(["file_a.json", "file_b.json"]): + self._write_session( + tmp_path, + name, + [ + _make_record( + record_id=f"req_{name}_{j}", + agent_id="claude", + is_subagent=False, + timestamp=_ts(i * 1000 + j * 5), + ) + for j in range(3) + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=default_user_config) + data = loader.load_dataset() + conversations = loader.convert_to_conversations(data) + assert len(conversations) == 2 + session_ids = {c.session_id for c in conversations} + assert len(session_ids) == 2 diff --git a/tests/unit/dataset/test_message_normalizer.py b/tests/unit/dataset/test_message_normalizer.py new file mode 100644 index 000000000..d83ad74ad --- /dev/null +++ b/tests/unit/dataset/test_message_normalizer.py @@ -0,0 +1,2186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for dataset.message_normalizer. 
+ +Focuses on: +- Provider auto-detection from message/tool shapes +- Anthropic -> OpenAI canonical conversion (messages, tools) +- OpenAI canonical -> Anthropic wire conversion (emitters) +- Round-trip integrity (Anthropic -> canonical -> Anthropic) +- Content flattening, billing header stripping, tool_result extraction +- Image block bidirectional conversion +- Redacted thinking, server tools, citations, cache_control, is_error, caller +- Developer role, refusal, passthrough blocks +- OpenAI passthrough (messages unchanged) +""" + +from typing import Any + +import orjson +import pytest +from pytest import param + +from aiperf.dataset.message_normalizer import ( + _anthropic_image_to_openai, + _anthropic_tool_use_to_call, + _detect_provider, + _emit_anthropic_assistant, + _emit_anthropic_tool_result, + _emit_anthropic_user, + _flatten_text_content, + _merge_consecutive_roles, + _normalize_anthropic_assistant, + _normalize_anthropic_messages, + _normalize_anthropic_tools, + _normalize_anthropic_user, + _openai_image_to_anthropic, + _sanitize_tool_id, + _tool_call_to_anthropic_block, + normalize_messages, + to_anthropic_messages, + to_anthropic_tools, +) + +# ============================================================ +# Helpers -- reusable message builders +# ============================================================ + + +def _tool_use_block( + tool_id: str = "toolu_01", + name: str = "Read", + input_dict: dict[str, Any] | None = None, + **extra: Any, +) -> dict[str, Any]: + block: dict[str, Any] = { + "type": "tool_use", + "id": tool_id, + "name": name, + "input": input_dict or {"file_path": "/tmp/test.py"}, + } + block.update(extra) + return block + + +def _server_tool_use_block( + tool_id: str = "srvtoolu_01", + name: str = "web_search", + input_dict: dict[str, Any] | None = None, +) -> dict[str, Any]: + return { + "type": "server_tool_use", + "id": tool_id, + "name": name, + "input": input_dict or {"query": "test"}, + } + + +def _tool_result_block( + 
tool_use_id: str = "toolu_01", + content: str | list[dict[str, Any]] = "file contents here", + **extra: Any, +) -> dict[str, Any]: + block: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + } + block.update(extra) + return block + + +def _text_block(text: str, **extra: Any) -> dict[str, Any]: + block: dict[str, Any] = {"type": "text", "text": text} + block.update(extra) + return block + + +def _thinking_block(text: str = "Let me think...") -> dict[str, Any]: + return {"type": "thinking", "thinking": text} + + +def _redacted_thinking_block(data: str = "EqQBCgIYAh==") -> dict[str, Any]: + return {"type": "redacted_thinking", "data": data} + + +def _anthropic_image_block( + source_type: str = "base64", + media_type: str = "image/jpeg", + data: str = "abc123", + url: str = "", +) -> dict[str, Any]: + if source_type == "base64": + return { + "type": "image", + "source": {"type": "base64", "media_type": media_type, "data": data}, + } + return {"type": "image", "source": {"type": "url", "url": url}} + + +def _openai_image_url_part(url: str = "https://example.com/img.png") -> dict[str, Any]: + return {"type": "image_url", "image_url": {"url": url}} + + +def _document_block(title: str = "doc") -> dict[str, Any]: + return { + "type": "document", + "source": {"type": "text", "content": [_text_block("doc content")]}, + "title": title, + } + + +def _search_result_block(title: str = "result") -> dict[str, Any]: + return { + "type": "search_result", + "title": title, + "source": "https://example.com", + "content": [_text_block("search content")], + } + + +def _web_search_tool_result_block() -> dict[str, Any]: + return { + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_01", + "content": [_text_block("search results here")], + } + + +def _anthropic_tool_def( + name: str = "Read", + description: str = "Read a file", + properties: dict[str, Any] | None = None, +) -> dict[str, Any]: + return { + "name": name, + 
"description": description, + "input_schema": { + "type": "object", + "properties": properties or {"file_path": {"type": "string"}}, + }, + } + + +def _openai_tool_def( + name: str = "Read", + description: str = "Read a file", +) -> dict[str, Any]: + return { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": {"type": "object", "properties": {}}, + }, + } + + +def _versioned_tool_def(tool_type: str = "web_search_20260209") -> dict[str, Any]: + return {"type": tool_type, "name": "web_search"} + + +# ============================================================ +# _detect_provider +# ============================================================ + + +class TestDetectProvider: + def test_detect_provider_anthropic_tool_use_in_messages(self) -> None: + msgs = [{"role": "assistant", "content": [_tool_use_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_tool_result_in_messages(self) -> None: + msgs = [{"role": "user", "content": [_tool_result_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_thinking_in_messages(self) -> None: + msgs = [{"role": "assistant", "content": [_thinking_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_redacted_thinking(self) -> None: + msgs = [{"role": "assistant", "content": [_redacted_thinking_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_server_tool_use(self) -> None: + msgs = [{"role": "assistant", "content": [_server_tool_use_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_image_with_source(self) -> None: + msgs = [{"role": "user", "content": [_anthropic_image_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_document_block(self) -> None: + msgs = [{"role": "user", "content": [_document_block()]}] + assert 
_detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_server_tool_result(self) -> None: + msgs = [{"role": "assistant", "content": [_web_search_tool_result_block()]}] + assert _detect_provider(msgs) == "anthropic" + + def test_detect_provider_anthropic_versioned_tool(self) -> None: + tools = [_versioned_tool_def("computer_20251124")] + assert _detect_provider([], tools) == "anthropic" + + def test_detect_provider_anthropic_tools(self) -> None: + assert _detect_provider([], [_anthropic_tool_def()]) == "anthropic" + + def test_detect_provider_openai_tools(self) -> None: + assert _detect_provider([], [_openai_tool_def()]) == "openai" + + def test_detect_provider_openai_tool_calls_on_assistant(self) -> None: + msgs = [{"role": "assistant", "tool_calls": [{"id": "call_1"}]}] + assert _detect_provider(msgs) == "openai" + + def test_detect_provider_openai_tool_role_message(self) -> None: + msgs = [{"role": "tool", "tool_call_id": "call_1", "content": "ok"}] + assert _detect_provider(msgs) == "openai" + + def test_detect_provider_plain_text_defaults_to_openai(self) -> None: + msgs = [{"role": "user", "content": "hello"}] + assert _detect_provider(msgs) == "openai" + + def test_detect_provider_empty_messages_defaults_to_openai(self) -> None: + assert _detect_provider([]) == "openai" + + def test_detect_provider_tools_checked_before_messages(self) -> None: + """Tools are a more reliable signal and should be checked first.""" + msgs = [{"role": "assistant", "content": [_tool_use_block()]}] + tools = [_openai_tool_def()] + assert _detect_provider(msgs, tools) == "openai" + + def test_detect_provider_non_dict_content_blocks_skipped(self) -> None: + msgs = [{"role": "user", "content": [42, "a string", None]}] + assert _detect_provider(msgs) == "openai" + + def test_detect_provider_openai_tool_with_parameters_key(self) -> None: + tools = [{"name": "foo", "parameters": {"type": "object"}}] + assert _detect_provider([], tools) == "openai" + + +# 
============================================================ +# _flatten_text_content +# ============================================================ + + +class TestFlattenTextContent: + def test_flatten_plain_string(self) -> None: + assert _flatten_text_content("hello") == "hello" + + def test_flatten_list_of_text_blocks(self) -> None: + blocks = [_text_block("a"), _text_block("b")] + assert _flatten_text_content(blocks) == "a\n\nb" + + def test_flatten_list_of_strings(self) -> None: + assert _flatten_text_content(["x", "y"]) == "x\n\ny" + + def test_flatten_mixed_strings_and_text_blocks(self) -> None: + content = ["raw", _text_block("block")] + assert _flatten_text_content(content) == "raw\n\nblock" + + def test_flatten_none_returns_empty(self) -> None: + assert _flatten_text_content(None) == "" + + def test_flatten_non_string_non_list_returns_str(self) -> None: + assert _flatten_text_content(42) == "42" + + def test_flatten_empty_text_blocks_skipped(self) -> None: + content = [_text_block(""), _text_block("ok")] + assert _flatten_text_content(content) == "ok" + + def test_flatten_strip_billing_header_string(self) -> None: + billing = "x-anthropic-billing-header: some data" + assert _flatten_text_content(billing, strip_billing_headers=True) == "" + + def test_flatten_strip_billing_header_in_list(self) -> None: + content = [ + _text_block("x-anthropic-billing-header: metadata"), + _text_block("real system prompt"), + ] + result = _flatten_text_content(content, strip_billing_headers=True) + assert result == "real system prompt" + + def test_flatten_strip_billing_header_raw_string_in_list(self) -> None: + content = ["x-anthropic-billing-header: foo", "keep me"] + result = _flatten_text_content(content, strip_billing_headers=True) + assert result == "keep me" + + def test_flatten_billing_header_preserved_when_not_stripping(self) -> None: + billing = "x-anthropic-billing-header: data" + assert _flatten_text_content(billing) == billing + + +# 
============================================================ +# _flatten_text_content (also covers former _extract_tool_result_content) +# ============================================================ + + +# ============================================================ +# Image conversion helpers +# ============================================================ + + +class TestAnthropicImageToOpenAI: + def test_base64_image(self) -> None: + block = _anthropic_image_block("base64", "image/jpeg", "abc123") + result = _anthropic_image_to_openai(block) + assert result["type"] == "image_url" + assert result["image_url"]["url"] == "data:image/jpeg;base64,abc123" + + def test_url_image(self) -> None: + block = _anthropic_image_block("url", url="https://example.com/img.png") + result = _anthropic_image_to_openai(block) + assert result["type"] == "image_url" + assert result["image_url"]["url"] == "https://example.com/img.png" + + def test_unknown_source_type_passthrough(self) -> None: + block = {"type": "image", "source": {"type": "file", "file_id": "f1"}} + result = _anthropic_image_to_openai(block) + assert result == block + + def test_cache_control_preserved(self) -> None: + block = _anthropic_image_block("base64") + block["cache_control"] = {"type": "ephemeral"} + result = _anthropic_image_to_openai(block) + assert result["_meta"]["cache_control"] == {"type": "ephemeral"} + + def test_default_media_type(self) -> None: + block = {"type": "image", "source": {"type": "base64", "data": "x"}} + result = _anthropic_image_to_openai(block) + assert "data:image/png;base64," in result["image_url"]["url"] + + +class TestOpenAIImageToAnthropic: + def test_url_image(self) -> None: + part = _openai_image_url_part("https://example.com/img.png") + result = _openai_image_to_anthropic(part) + assert result["type"] == "image" + assert result["source"]["type"] == "url" + assert result["source"]["url"] == "https://example.com/img.png" + + def test_data_uri_base64_image(self) -> None: + url = 
"data:image/jpeg;base64,abc123" + part = _openai_image_url_part(url) + result = _openai_image_to_anthropic(part) + assert result["type"] == "image" + assert result["source"]["type"] == "base64" + assert result["source"]["media_type"] == "image/jpeg" + assert result["source"]["data"] == "abc123" + + def test_data_uri_png(self) -> None: + url = "data:image/png;base64,xyz" + result = _openai_image_to_anthropic( + {"type": "image_url", "image_url": {"url": url}} + ) + assert result["source"]["media_type"] == "image/png" + + def test_cache_control_restored(self) -> None: + part = { + "type": "image_url", + "image_url": {"url": "https://x.com/i.png"}, + "_meta": {"cache_control": {"type": "ephemeral"}}, + } + result = _openai_image_to_anthropic(part) + assert result["cache_control"] == {"type": "ephemeral"} + + +# ============================================================ +# _anthropic_tool_use_to_call +# ============================================================ + + +class TestAnthropicToolUseToCall: + def test_basic_conversion(self) -> None: + block = _tool_use_block("t1", "Read", {"file_path": "f.py"}) + result = _anthropic_tool_use_to_call(block) + assert result["id"] == "t1" + assert result["type"] == "function" + assert result["function"]["name"] == "Read" + assert orjson.loads(result["function"]["arguments"]) == {"file_path": "f.py"} + + def test_caller_preserved(self) -> None: + block = _tool_use_block() + block["caller"] = {"type": "code_execution_20260120", "tool_id": "srv1"} + result = _anthropic_tool_use_to_call(block) + assert result["_meta"]["caller"] == { + "type": "code_execution_20260120", + "tool_id": "srv1", + } + + def test_cache_control_preserved(self) -> None: + block = _tool_use_block() + block["cache_control"] = {"type": "ephemeral"} + result = _anthropic_tool_use_to_call(block) + assert result["_meta"]["cache_control"] == {"type": "ephemeral"} + + def test_non_dict_input(self) -> None: + block = {"type": "tool_use", "id": "t1", "name": "X", 
"input": "raw string"} + result = _anthropic_tool_use_to_call(block) + assert result["function"]["arguments"] == "raw string" + + +# ============================================================ +# _normalize_anthropic_assistant +# ============================================================ + + +class TestNormalizeAnthropicAssistant: + def test_text_only_flattened_to_string(self) -> None: + msg = {"role": "assistant", "content": [_text_block("hello")]} + result = _normalize_anthropic_assistant(msg) + assert len(result) == 1 + assert result[0]["content"] == "hello" + assert "tool_calls" not in result[0] + + def test_multiple_text_blocks_joined(self) -> None: + msg = {"role": "assistant", "content": [_text_block("a"), _text_block("b")]} + result = _normalize_anthropic_assistant(msg) + assert result[0]["content"] == "a\n\nb" + + def test_tool_use_converted_to_tool_calls(self) -> None: + tool = _tool_use_block("toolu_abc", "Bash", {"command": "ls"}) + msg = {"role": "assistant", "content": [tool]} + result = _normalize_anthropic_assistant(msg) + assert len(result) == 1 + out = result[0] + assert out["content"] == "" + assert len(out["tool_calls"]) == 1 + tc = out["tool_calls"][0] + assert tc["id"] == "toolu_abc" + assert tc["type"] == "function" + assert tc["function"]["name"] == "Bash" + parsed = orjson.loads(tc["function"]["arguments"]) + assert parsed == {"command": "ls"} + + def test_text_plus_tool_use(self) -> None: + msg = { + "role": "assistant", + "content": [ + _text_block("Let me check that file."), + _tool_use_block("toolu_01", "Read", {"file_path": "/etc/hosts"}), + ], + } + result = _normalize_anthropic_assistant(msg) + out = result[0] + assert out["content"] == "Let me check that file." 
+ assert len(out["tool_calls"]) == 1 + + def test_thinking_blocks_preserved(self) -> None: + msg = { + "role": "assistant", + "content": [_thinking_block("hmm"), _text_block("answer")], + } + result = _normalize_anthropic_assistant(msg) + out = result[0] + assert out["content"] == "answer" + assert len(out["thinking_blocks"]) == 1 + assert out["thinking_blocks"][0]["type"] == "thinking" + + def test_redacted_thinking_preserved(self) -> None: + msg = { + "role": "assistant", + "content": [_redacted_thinking_block("encrypted"), _text_block("answer")], + } + result = _normalize_anthropic_assistant(msg) + out = result[0] + assert len(out["thinking_blocks"]) == 1 + assert out["thinking_blocks"][0]["type"] == "redacted_thinking" + assert out["thinking_blocks"][0]["data"] == "encrypted" + + def test_mixed_thinking_and_redacted_thinking(self) -> None: + msg = { + "role": "assistant", + "content": [ + _thinking_block("visible"), + _redacted_thinking_block("hidden"), + _text_block("answer"), + ], + } + result = _normalize_anthropic_assistant(msg) + blocks = result[0]["thinking_blocks"] + assert len(blocks) == 2 + assert blocks[0]["type"] == "thinking" + assert blocks[1]["type"] == "redacted_thinking" + + def test_server_tool_use_converted_with_marker(self) -> None: + msg = { + "role": "assistant", + "content": [_server_tool_use_block("srv1", "web_search")], + } + result = _normalize_anthropic_assistant(msg) + tc = result[0]["tool_calls"][0] + assert tc["function"]["name"] == "web_search" + assert tc["_meta"]["server_tool"] is True + + def test_server_tool_result_preserved_as_passthrough(self) -> None: + msg = { + "role": "assistant", + "content": [ + _server_tool_use_block(), + _web_search_tool_result_block(), + _text_block("Based on my search..."), + ], + } + result = _normalize_anthropic_assistant(msg) + out = result[0] + assert len(out["tool_calls"]) == 1 + assert len(out["_meta"]["passthrough_blocks"]) == 1 + assert out["_meta"]["passthrough_blocks"][0]["type"] == 
"web_search_tool_result" + assert out["content"] == "Based on my search..." + + def test_caller_preserved_on_tool_call(self) -> None: + block = _tool_use_block() + block["caller"] = {"type": "code_execution_20260120", "tool_id": "srv1"} + msg = {"role": "assistant", "content": [block]} + result = _normalize_anthropic_assistant(msg) + assert ( + "_meta" in result[0]["tool_calls"][0] + and "caller" in result[0]["tool_calls"][0]["_meta"] + ) + + def test_citations_preserved(self) -> None: + citations = [ + {"type": "char_location", "cited_text": "hello", "document_index": 0} + ] + text = _text_block("hello") + text["citations"] = citations + msg = {"role": "assistant", "content": [text]} + result = _normalize_anthropic_assistant(msg) + assert result[0]["_meta"]["citations"] == citations + + def test_empty_content_list_returns_empty_string(self) -> None: + msg = {"role": "assistant", "content": []} + result = _normalize_anthropic_assistant(msg) + assert result[0]["content"] == "" + + def test_non_list_content_passthrough(self) -> None: + msg = {"role": "assistant", "content": "already a string"} + result = _normalize_anthropic_assistant(msg) + assert result == [msg] + + def test_tool_use_with_non_dict_input(self) -> None: + block = {"type": "tool_use", "id": "t1", "name": "X", "input": "raw string"} + msg = {"role": "assistant", "content": [block]} + result = _normalize_anthropic_assistant(msg) + assert result[0]["tool_calls"][0]["function"]["arguments"] == "raw string" + + def test_non_dict_blocks_in_content_skipped(self) -> None: + msg = {"role": "assistant", "content": [42, _text_block("ok")]} + result = _normalize_anthropic_assistant(msg) + assert result[0]["content"] == "ok" + + def test_multiple_tool_use_blocks(self) -> None: + msg = { + "role": "assistant", + "content": [ + _tool_use_block("t1", "Read", {"path": "a"}), + _tool_use_block("t2", "Bash", {"cmd": "b"}), + ], + } + result = _normalize_anthropic_assistant(msg) + assert len(result[0]["tool_calls"]) == 2 
+ names = [tc["function"]["name"] for tc in result[0]["tool_calls"]] + assert names == ["Read", "Bash"] + + +# ============================================================ +# _normalize_anthropic_user +# ============================================================ + + +class TestNormalizeAnthropicUser: + def test_plain_string_passthrough(self) -> None: + msg = {"role": "user", "content": "hello"} + result = _normalize_anthropic_user(msg) + assert result == [msg] + + def test_tool_result_becomes_tool_role(self) -> None: + msg = { + "role": "user", + "content": [_tool_result_block("toolu_01", "file data")], + } + result = _normalize_anthropic_user(msg) + assert len(result) == 1 + assert result[0]["role"] == "tool" + assert result[0]["tool_call_id"] == "toolu_01" + assert result[0]["content"] == "file data" + + def test_tool_result_is_error_preserved(self) -> None: + block = _tool_result_block("t1", "Error!", is_error=True) + msg = {"role": "user", "content": [block]} + result = _normalize_anthropic_user(msg) + assert result[0]["_meta"]["is_error"] is True + + def test_tool_result_cache_control_preserved(self) -> None: + block = _tool_result_block("t1", "data", cache_control={"type": "ephemeral"}) + msg = {"role": "user", "content": [block]} + result = _normalize_anthropic_user(msg) + assert result[0]["_meta"]["cache_control"] == {"type": "ephemeral"} + + def test_text_plus_tool_result_split(self) -> None: + """Text should appear as user message before tool result messages.""" + msg = { + "role": "user", + "content": [ + _text_block("Here is context"), + _tool_result_block("toolu_01", "result data"), + ], + } + result = _normalize_anthropic_user(msg) + assert len(result) == 2 + assert result[0] == {"role": "user", "content": "Here is context"} + assert result[1]["role"] == "tool" + assert result[1]["tool_call_id"] == "toolu_01" + + def test_tool_result_then_text_preserves_order(self) -> None: + """tool_result before text should keep tool entries first.""" + msg = { + 
"role": "user",
+            "content": [
+                _tool_result_block("toolu_01", "result data"),
+                _text_block("Follow-up context"),
+            ],
+        }
+        result = _normalize_anthropic_user(msg)
+        assert len(result) == 2
+        assert result[0]["role"] == "tool"
+        assert result[0]["tool_call_id"] == "toolu_01"
+        assert result[1] == {"role": "user", "content": "Follow-up context"}
+
+        # Round-trip through to_anthropic_messages
+        assistant = {
+            "role": "assistant",
+            "content": [
+                _tool_use_block("toolu_01", "Read", {"path": "/tmp/f"}),
+            ],
+        }
+        canonical = [assistant] + result
+        roundtrip_msgs, _ = to_anthropic_messages(canonical)
+        # tool_result should precede user text in the reconstructed Anthropic message
+        user_msgs = [m for m in roundtrip_msgs if m["role"] == "user"]
+        assert len(user_msgs) == 1
+        content = user_msgs[0]["content"]
+        assert isinstance(content, list)
+        tool_result_indices = [
+            i
+            for i, b in enumerate(content)
+            if isinstance(b, dict) and b.get("type") == "tool_result"
+        ]
+        text_indices = [
+            i
+            for i, b in enumerate(content)
+            if isinstance(b, dict) and b.get("type") == "text"
+        ]
+        assert tool_result_indices[0] < text_indices[0]
+
+    def test_multiple_tool_results(self) -> None:
+        msg = {
+            "role": "user",
+            "content": [
+                _tool_result_block("t1", "r1"),
+                _tool_result_block("t2", "r2"),
+            ],
+        }
+        result = _normalize_anthropic_user(msg)
+        assert len(result) == 2
+        assert result[0]["tool_call_id"] == "t1"
+        assert result[1]["tool_call_id"] == "t2"
+
+    def test_tool_result_with_nested_text_blocks(self) -> None:
+        nested = [_text_block("line 1"), _text_block("line 2")]
+        msg = {"role": "user", "content": [_tool_result_block("t1", nested)]}
+        result = _normalize_anthropic_user(msg)
+        assert result[0]["content"] == "line 1\n\nline 2"
+
+    def test_raw_string_in_content_list(self) -> None:
+        msg = {"role": "user", "content": ["a plain string"]}
+        result = _normalize_anthropic_user(msg)
+        assert result[0]["role"] == "user"
+        assert result[0]["content"] == "a plain string"
+
+    def test_empty_content_list_returns_original(self) -> None:
+        msg = {"role": "user", "content": []}
+        result = _normalize_anthropic_user(msg)
+        assert result == [msg]
+
+    def test_image_block_converted_to_image_url(self) -> None:
+        msg = {
+            "role": "user",
+            "content": [
+                _text_block("Look at this"),
+                _anthropic_image_block("base64", "image/png", "imgdata"),
+            ],
+        }
+        result = _normalize_anthropic_user(msg)
+        assert len(result) == 1
+        parts = result[0]["content"]
+        assert parts[0] == {"type": "text", "text": "Look at this"}
+        assert parts[1]["type"] == "image_url"
+        assert "data:image/png;base64,imgdata" in parts[1]["image_url"]["url"]
+
+    def test_image_url_block_converted(self) -> None:
+        msg = {
+            "role": "user",
+            "content": [_anthropic_image_block("url", url="https://example.com/i.jpg")],
+        }
+        result = _normalize_anthropic_user(msg)
+        assert result[0]["content"][0]["type"] == "image_url"
+        assert (
+            result[0]["content"][0]["image_url"]["url"] == "https://example.com/i.jpg"
+        )
+
+    def test_document_block_preserved(self) -> None:
+        msg = {"role": "user", "content": [_text_block("context"), _document_block()]}
+        result = _normalize_anthropic_user(msg)
+        parts = result[0]["content"]
+        assert any(p.get("type") == "document" for p in parts)
+
+    def test_search_result_block_preserved(self) -> None:
+        msg = {"role": "user", "content": [_search_result_block()]}
+        result = _normalize_anthropic_user(msg)
+        parts = result[0]["content"]
+        assert parts[0]["type"] == "search_result"
+
+    def test_other_block_types_kept_as_content_parts(self) -> None:
+        """Non-text, non-tool_result blocks are preserved as content parts."""
+        unknown = {"type": "custom_type", "data": "something"}
+        msg = {"role": "user", "content": [_text_block("caption"), unknown]}
+        result = _normalize_anthropic_user(msg)
+        assert len(result) == 1
+        parts = result[0]["content"]
+        assert len(parts) == 2
+        assert parts[1] == unknown
+
+
+# ============================================================
+# _normalize_anthropic_tools
+# ============================================================
+
+
+class TestNormalizeAnthropicTools:
+    def test_anthropic_tool_converted(self) -> None:
+        tools = [_anthropic_tool_def("Bash", "Run a command")]
+        result = _normalize_anthropic_tools(tools)
+        assert len(result) == 1
+        t = result[0]
+        assert t["type"] == "function"
+        assert t["function"]["name"] == "Bash"
+        assert t["function"]["description"] == "Run a command"
+        assert t["function"]["parameters"]["type"] == "object"
+
+    def test_openai_tool_passthrough(self) -> None:
+        tools = [_openai_tool_def("Read")]
+        result = _normalize_anthropic_tools(tools)
+        assert result == tools
+
+    def test_parameters_key_tool_wrapped(self) -> None:
+        """Tools with top-level name + parameters get wrapped in function envelope."""
+        tool = {"name": "foo", "description": "bar", "parameters": {"type": "object"}}
+        result = _normalize_anthropic_tools([tool])
+        assert result[0]["type"] == "function"
+        assert result[0]["function"]["name"] == "foo"
+
+    def test_unknown_format_passthrough(self) -> None:
+        tool = {"something": "else"}
+        result = _normalize_anthropic_tools([tool])
+        assert result == [tool]
+
+    def test_missing_description_gets_default(self) -> None:
+        tool = {"name": "X", "input_schema": {"type": "object"}}
+        result = _normalize_anthropic_tools([tool])
+        assert result[0]["function"]["description"] == "(no description)"
+
+    def test_mixed_tool_formats(self) -> None:
+        tools = [
+            _anthropic_tool_def("A"),
+            _openai_tool_def("B"),
+            {"name": "C", "parameters": {"type": "object"}},
+        ]
+        result = _normalize_anthropic_tools(tools)
+        names = [t["function"]["name"] for t in result]
+        assert names == ["A", "B", "C"]
+
+    def test_versioned_tool_passthrough(self) -> None:
+        tools = [_versioned_tool_def("computer_20251124")]
+        result = _normalize_anthropic_tools(tools)
+        assert result == tools
+
+    def test_versioned_tool_among_mixed(self) -> None:
+        tools = [
+            _anthropic_tool_def("A"),
+            _versioned_tool_def("bash_20250124"),
+            _versioned_tool_def("text_editor_20250429"),
+        ]
+        result = _normalize_anthropic_tools(tools)
+        assert result[0]["type"] == "function"
+        assert result[1]["type"] == "bash_20250124"
+        assert result[2]["type"] == "text_editor_20250429"
+
+    def test_cache_control_preserved(self) -> None:
+        tool = _anthropic_tool_def("Read")
+        tool["cache_control"] = {"type": "ephemeral"}
+        result = _normalize_anthropic_tools([tool])
+        assert result[0]["_meta"]["cache_control"] == {"type": "ephemeral"}
+
+
+# ============================================================
+# _normalize_anthropic_messages (integration of sub-functions)
+# ============================================================
+
+
+class TestNormalizeAnthropicMessages:
+    def test_system_message_flattened(self) -> None:
+        msgs = [
+            {
+                "role": "system",
+                "content": [
+                    _text_block("x-anthropic-billing-header: proj-123"),
+                    _text_block("You are a helpful assistant."),
+                ],
+            },
+        ]
+        result = _normalize_anthropic_messages(msgs)
+        assert len(result) == 1
+        assert result[0]["role"] == "system"
+        assert result[0]["content"] == "You are a helpful assistant."
+
+    def test_system_message_all_billing_headers_produces_no_output(self) -> None:
+        msgs = [
+            {
+                "role": "system",
+                "content": [_text_block("x-anthropic-billing-header: only")],
+            }
+        ]
+        result = _normalize_anthropic_messages(msgs)
+        assert len(result) == 0
+
+    def test_unknown_role_passthrough(self) -> None:
+        msg = {"role": "function", "content": "data"}
+        result = _normalize_anthropic_messages([msg])
+        assert result == [msg]
+
+    def test_full_conversation_round_trip(self) -> None:
+        """Realistic Claude Code multi-turn with system, assistant, user/tool_result."""
+        msgs = [
+            {
+                "role": "system",
+                "content": [
+                    _text_block("x-anthropic-billing-header: abc"),
+                    _text_block("System prompt here."),
+                ],
+            },
+            {"role": "user", "content": "Fix the bug in main.py"},
+            {
+                "role": "assistant",
+                "content": [
+                    _text_block("I will read the file first."),
+                    _tool_use_block("t1", "Read", {"file_path": "main.py"}),
+                ],
+            },
+            {
+                "role": "user",
+                "content": [_tool_result_block("t1", "def main(): pass")],
+            },
+            {
+                "role": "assistant",
+                "content": [_text_block("The file is empty. Done.")],
+            },
+        ]
+        result = _normalize_anthropic_messages(msgs)
+
+        assert result[0] == {"role": "system", "content": "System prompt here."}
+        assert result[1] == {"role": "user", "content": "Fix the bug in main.py"}
+
+        # Assistant with tool call
+        assert result[2]["role"] == "assistant"
+        assert result[2]["content"] == "I will read the file first."
+        assert len(result[2]["tool_calls"]) == 1
+
+        # Tool result
+        assert result[3]["role"] == "tool"
+        assert result[3]["tool_call_id"] == "t1"
+
+        # Final assistant
+        assert result[4] == {"role": "assistant", "content": "The file is empty. Done."}
+
+
+# ============================================================
+# normalize_messages (top-level entry point)
+# ============================================================
+
+
+class TestNormalizeMessages:
+    def test_anthropic_auto_detected_and_normalized(self) -> None:
+        msgs = [
+            {"role": "assistant", "content": [_tool_use_block("t1", "Read", {"p": 1})]},
+            {"role": "user", "content": [_tool_result_block("t1", "data")]},
+        ]
+        tools = [_anthropic_tool_def()]
+        result_msgs, result_tools = normalize_messages(msgs, tools)
+
+        assert result_msgs[0]["tool_calls"][0]["type"] == "function"
+        assert result_msgs[1]["role"] == "tool"
+        assert result_tools is not None
+        assert result_tools[0]["type"] == "function"
+
+    def test_openai_passthrough(self) -> None:
+        msgs = [
+            {"role": "user", "content": "hi"},
+            {"role": "assistant", "content": "hello"},
+        ]
+        result_msgs, result_tools = normalize_messages(msgs)
+        assert result_msgs == msgs
+        assert result_tools is None
+
+    def test_explicit_provider_skips_detection(self) -> None:
+        """When provider is given explicitly, detection is bypassed."""
+        msgs = [
+            {"role": "assistant", "content": [_tool_use_block()]},
+        ]
+        # Force openai -- should NOT normalize Anthropic content
+        result_msgs, _ = normalize_messages(msgs, provider="openai")
+        assert isinstance(result_msgs[0]["content"], list)
+
+    def test_tokens_metadata_stripped(self) -> None:
+        msgs = [{"role": "user", "content": "hi", "tokens": 42}]
+        result_msgs, _ = normalize_messages(msgs)
+        assert "tokens" not in result_msgs[0]
+
+    def test_none_tools_stay_none(self) -> None:
+        _, result_tools = normalize_messages([{"role": "user", "content": "x"}])
+        assert result_tools is None
+
+    def test_empty_tools_stay_empty(self) -> None:
+        _, result_tools = normalize_messages(
+            [{"role": "user", "content": "x"}], tools=[]
+        )
+        assert result_tools == []
+
+    @pytest.mark.parametrize(
+        "provider,msgs,tools",
+        [
+            param(
+                None,
+                [{"role": "user", "content": "hi"}],
+                None,
+                id="plain-text-no-tools",
+            ),
+            param(
+                "openai",
+                [{"role": "user", "content": "hi"}],
+                [_openai_tool_def()],
+                id="explicit-openai-with-tools",
+            ),
+        ],
+    )  # fmt: skip
+    def test_openai_messages_unchanged(
+        self,
+        provider: str | None,
+        msgs: list[dict[str, Any]],
+        tools: list[dict[str, Any]] | None,
+    ) -> None:
+        result_msgs, _ = normalize_messages(msgs, tools, provider=provider)
+        # Content should remain identical (only tokens key stripped)
+        for orig, res in zip(msgs, result_msgs, strict=False):
+            assert res["content"] == orig["content"]
+
+
+# ============================================================
+# _emit_anthropic_assistant (canonical -> Anthropic)
+# ============================================================
+
+
+class TestEmitAnthropicAssistant:
+    def test_text_only_becomes_text_block(self) -> None:
+        msg = {"role": "assistant", "content": "hello world"}
+        result = _emit_anthropic_assistant(msg)
+        assert result["role"] == "assistant"
+        assert result["content"] == [{"type": "text", "text": "hello world"}]
+
+    def test_empty_string_content_becomes_empty_text_block(self) -> None:
+        msg = {"role": "assistant", "content": ""}
+        result = _emit_anthropic_assistant(msg)
+        assert result["content"] == [{"type": "text", "text": ""}]
+
+    def test_none_content_produces_empty_text_block(self) -> None:
+        msg = {"role": "assistant", "content": None}
+        result = _emit_anthropic_assistant(msg)
+        assert result["content"] == [{"type": "text", "text": ""}]
+
+    def test_tool_calls_become_tool_use_blocks(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "call_1",
+                    "type": "function",
+                    "function": {
+                        "name": "Bash",
+                        "arguments": orjson.dumps({"command": "ls"}).decode(),
+                    },
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        blocks = result["content"]
+        assert len(blocks) == 1
+        assert blocks[0]["type"] == "tool_use"
+        assert blocks[0]["id"] == "call_1"
+        assert blocks[0]["name"] == "Bash"
+        assert blocks[0]["input"] == {"command": "ls"}
+
+    def test_text_plus_tool_calls(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": "Let me check.",
+            "tool_calls": [
+                {
+                    "id": "call_2",
+                    "type": "function",
+                    "function": {"name": "Read", "arguments": '{"path": "f.py"}'},
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        blocks = result["content"]
+        assert len(blocks) == 2
+        assert blocks[0] == {"type": "text", "text": "Let me check."}
+        assert blocks[1]["type"] == "tool_use"
+
+    def test_thinking_blocks_restored_first(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": "answer",
+            "thinking_blocks": [{"type": "thinking", "thinking": "hmm"}],
+        }
+        result = _emit_anthropic_assistant(msg)
+        blocks = result["content"]
+        assert blocks[0] == {"type": "thinking", "thinking": "hmm"}
+        assert blocks[1] == {"type": "text", "text": "answer"}
+
+    def test_redacted_thinking_restored(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": "answer",
+            "thinking_blocks": [{"type": "redacted_thinking", "data": "enc"}],
+        }
+        result = _emit_anthropic_assistant(msg)
+        blocks = result["content"]
+        assert blocks[0] == {"type": "redacted_thinking", "data": "enc"}
+
+    def test_server_tool_use_restored(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "srv1",
+                    "type": "function",
+                    "function": {"name": "web_search", "arguments": '{"q": "test"}'},
+                    "_meta": {"server_tool": True},
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        block = result["content"][0]
+        assert block["type"] == "server_tool_use"
+        assert block["name"] == "web_search"
+
+    def test_caller_restored_on_tool_use(self) -> None:
+        caller = {"type": "code_execution_20260120", "tool_id": "srv1"}
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "t1",
+                    "type": "function",
+                    "function": {"name": "Read", "arguments": "{}"},
+                    "_meta": {"caller": caller},
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        assert result["content"][0]["caller"] == caller
+
+    def test_cache_control_restored_on_tool_use(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "t1",
+                    "type": "function",
+                    "function": {"name": "X", "arguments": "{}"},
+                    "_meta": {"cache_control": {"type": "ephemeral"}},
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        assert result["content"][0]["cache_control"] == {"type": "ephemeral"}
+
+    def test_passthrough_blocks_restored(self) -> None:
+        sr = _web_search_tool_result_block()
+        msg = {
+            "role": "assistant",
+            "content": "result",
+            "_meta": {"passthrough_blocks": [sr]},
+        }
+        result = _emit_anthropic_assistant(msg)
+        assert sr in result["content"]
+
+    def test_refusal_becomes_text_block(self) -> None:
+        msg = {"role": "assistant", "content": None, "refusal": "I cannot do that."}
+        result = _emit_anthropic_assistant(msg)
+        text_blocks = [b for b in result["content"] if b.get("type") == "text"]
+        assert any(b["text"] == "I cannot do that." for b in text_blocks)
+
+    def test_multiple_tool_calls(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "c1",
+                    "type": "function",
+                    "function": {"name": "A", "arguments": "{}"},
+                },
+                {
+                    "id": "c2",
+                    "type": "function",
+                    "function": {"name": "B", "arguments": "{}"},
+                },
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        names = [b["name"] for b in result["content"] if b["type"] == "tool_use"]
+        assert names == ["A", "B"]
+
+    def test_invalid_json_arguments_passthrough(self) -> None:
+        msg = {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [
+                {
+                    "id": "c1",
+                    "type": "function",
+                    "function": {"name": "X", "arguments": "not valid json"},
+                }
+            ],
+        }
+        result = _emit_anthropic_assistant(msg)
+        assert result["content"][0]["input"] == "not valid json"
+
+    def test_list_content_passed_through(self) -> None:
+        """Content that is already a list of blocks passes through."""
+        blocks = [{"type": "text", "text": "already blocks"}]
+        msg = {"role": "assistant", "content": blocks}
+        result = _emit_anthropic_assistant(msg)
+        assert {"type": "text", "text": "already blocks"} in result["content"]
+
+
+# ============================================================
+# _emit_anthropic_tool_result (canonical -> Anthropic)
+# ============================================================
+
+
+class TestEmitAnthropicToolResult:
+    def test_basic_tool_result(self) -> None:
+        msg = {"role": "tool", "tool_call_id": "call_1", "content": "file data"}
+        result = _emit_anthropic_tool_result(msg)
+        assert result["role"] == "user"
+        assert len(result["content"]) == 1
+        block = result["content"][0]
+        assert block["type"] == "tool_result"
+        assert block["tool_use_id"] == "call_1"
+        assert block["content"] == "file data"
+
+    def test_missing_tool_call_id_defaults_empty(self) -> None:
+        msg = {"role": "tool", "content": "data"}
+        result = _emit_anthropic_tool_result(msg)
+        assert result["content"][0]["tool_use_id"] == ""
+
+    def test_missing_content_defaults_empty(self) -> None:
+        msg = {"role": "tool", "tool_call_id": "c1"}
+        result = _emit_anthropic_tool_result(msg)
+        assert result["content"][0]["content"] == ""
+
+    def test_is_error_restored(self) -> None:
+        msg = {
+            "role": "tool",
+            "tool_call_id": "c1",
+            "content": "Error!",
+            "_meta": {"is_error": True},
+        }
+        result = _emit_anthropic_tool_result(msg)
+        assert result["content"][0]["is_error"] is True
+
+    def test_is_error_not_set_when_false(self) -> None:
+        msg = {"role": "tool", "tool_call_id": "c1", "content": "ok"}
+        result = _emit_anthropic_tool_result(msg)
+        assert "is_error" not in result["content"][0]
+
+    def test_cache_control_restored(self) -> None:
+        msg = {
+            "role": "tool",
+            "tool_call_id": "c1",
+            "content": "data",
+            "_meta": {"cache_control": {"type": "ephemeral"}},
+        }
+        result = _emit_anthropic_tool_result(msg)
+        assert result["content"][0]["cache_control"] == {"type": "ephemeral"}
+
+
+# ============================================================
+# _emit_anthropic_user (canonical -> Anthropic)
+# ============================================================
+
+
+class TestEmitAnthropicUser:
+    def test_string_content_passthrough(self) -> None:
+        msg = {"role": "user", "content": "hello"}
+        result = _emit_anthropic_user(msg)
+        assert result == {"role": "user", "content": "hello"}
+
+    def test_image_url_converted_to_anthropic_image(self) -> None:
+        part = _openai_image_url_part("https://example.com/img.png")
+        msg = {"role": "user", "content": [{"type": "text", "text": "look"}, part]}
+        result = _emit_anthropic_user(msg)
+        blocks = result["content"]
+        assert blocks[0] == {"type": "text", "text": "look"}
+        assert blocks[1]["type"] == "image"
+        assert blocks[1]["source"]["type"] == "url"
+        assert blocks[1]["source"]["url"] == "https://example.com/img.png"
+
+    def test_data_uri_image_url_converted(self) -> None:
+        url = "data:image/jpeg;base64,abc123"
+        msg = {"role": "user", "content": [_openai_image_url_part(url)]}
+        result = _emit_anthropic_user(msg)
+        block = result["content"][0]
+        assert block["type"] == "image"
+        assert block["source"]["type"] == "base64"
+        assert block["source"]["media_type"] == "image/jpeg"
+        assert block["source"]["data"] == "abc123"
+
+    def test_non_image_parts_passthrough(self) -> None:
+        msg = {"role": "user", "content": [{"type": "text", "text": "hi"}]}
+        result = _emit_anthropic_user(msg)
+        assert result["content"] == [{"type": "text", "text": "hi"}]
+
+    def test_non_dict_content_passthrough(self) -> None:
+        msg = {"role": "user", "content": 42}
+        result = _emit_anthropic_user(msg)
+        assert result == msg
+
+    def test_mixed_content_parts(self) -> None:
+        msg = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "caption"},
+                _openai_image_url_part("https://x.com/i.png"),
+                {"type": "document", "title": "doc"},
+            ],
+        }
+        result = _emit_anthropic_user(msg)
+        assert result["content"][0]["type"] == "text"
+        assert result["content"][1]["type"] == "image"
+        assert result["content"][2]["type"] == "document"
+
+
+# ============================================================
+# _merge_consecutive_roles
+# ============================================================
+
+
+class TestMergeConsecutiveRoles:
+    def test_empty_input(self) -> None:
+        assert _merge_consecutive_roles([]) == []
+
+    def test_single_message(self) -> None:
+        msgs = [{"role": "user", "content": "hi"}]
+        assert _merge_consecutive_roles(msgs) == msgs
+
+    def test_alternating_roles_unchanged(self) -> None:
+        msgs = [
+            {"role": "user", "content": "q"},
+            {"role": "assistant", "content": "a"},
+            {"role": "user", "content": "q2"},
+        ]
+        assert _merge_consecutive_roles(msgs) == msgs
+
+    def test_consecutive_user_list_plus_list(self) -> None:
+        msgs = [
+            {"role": "user", "content": [{"type": "tool_result", "content": "r1"}]},
+            {"role": "user", "content": [{"type": "tool_result", "content": "r2"}]},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 1
+        assert len(result[0]["content"]) == 2
+
+    def test_consecutive_user_string_plus_string(self) -> None:
+        msgs = [
+            {"role": "user", "content": "hello"},
+            {"role": "user", "content": "world"},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 1
+        assert result[0]["content"] == [
+            {"type": "text", "text": "hello"},
+            {"type": "text", "text": "world"},
+        ]
+
+    def test_consecutive_user_list_plus_string(self) -> None:
+        msgs = [
+            {"role": "user", "content": [{"type": "text", "text": "a"}]},
+            {"role": "user", "content": "b"},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 1
+        assert result[0]["content"] == [
+            {"type": "text", "text": "a"},
+            {"type": "text", "text": "b"},
+        ]
+
+    def test_consecutive_user_string_plus_list(self) -> None:
+        msgs = [
+            {"role": "user", "content": "a"},
+            {"role": "user", "content": [{"type": "text", "text": "b"}]},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 1
+        assert result[0]["content"] == [
+            {"type": "text", "text": "a"},
+            {"type": "text", "text": "b"},
+        ]
+
+    def test_three_consecutive_same_role(self) -> None:
+        msgs = [
+            {"role": "user", "content": [{"type": "tool_result", "content": "r1"}]},
+            {"role": "user", "content": [{"type": "tool_result", "content": "r2"}]},
+            {"role": "user", "content": [{"type": "tool_result", "content": "r3"}]},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 1
+        assert len(result[0]["content"]) == 3
+
+    def test_different_roles_not_merged(self) -> None:
+        msgs = [
+            {"role": "user", "content": "a"},
+            {"role": "assistant", "content": "b"},
+            {"role": "user", "content": "c"},
+        ]
+        result = _merge_consecutive_roles(msgs)
+        assert len(result) == 3
+
+
+# ============================================================
+# to_anthropic_messages (full emitter)
+# ============================================================
+
+
+class TestToAnthropicMessages:
+    def test_system_extracted(self) -> None:
+        msgs = [
+            {"role": "system", "content": "Be helpful."},
+            {"role": "user", "content": "hi"},
+        ]
+        result, system = to_anthropic_messages(msgs)
+        assert system == "Be helpful."
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+
+    def test_developer_role_extracted_as_system(self) -> None:
+        msgs = [
+            {"role": "developer", "content": "Instructions"},
+            {"role": "user", "content": "hi"},
+        ]
+        result, system = to_anthropic_messages(msgs)
+        assert system == "Instructions"
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+
+    def test_multiple_system_messages_merged(self) -> None:
+        msgs = [
+            {"role": "system", "content": "first"},
+            {"role": "user", "content": "q"},
+            {"role": "system", "content": "second"},
+            {"role": "user", "content": "q2"},
+        ]
+        _, system = to_anthropic_messages(msgs)
+        assert system == "first\n\nsecond"
+
+    def test_developer_merged_with_system(self) -> None:
+        msgs = [
+            {"role": "system", "content": "old"},
+            {"role": "developer", "content": "new"},
+            {"role": "user", "content": "q"},
+        ]
+        _, system = to_anthropic_messages(msgs)
+        assert system == "old\n\nnew"
+
+    def test_no_system_returns_none(self) -> None:
+        msgs = [{"role": "user", "content": "hi"}]
+        _, system = to_anthropic_messages(msgs)
+        assert system is None
+
+    def test_tool_role_becomes_user_tool_result(self) -> None:
+        msgs = [{"role": "tool", "tool_call_id": "c1", "content": "result data"}]
+        result, _ = to_anthropic_messages(msgs)
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+        block = result[0]["content"][0]
+        assert block["type"] == "tool_result"
+        assert block["tool_use_id"] == "c1"
+
+    def test_assistant_with_tool_calls_emitted(self) -> None:
+        msgs = [
+            {
+                "role": "assistant",
+                "content": "checking",
+                "tool_calls": [
+                    {
+                        "id": "c1",
+                        "type": "function",
+                        "function": {"name": "Read", "arguments": '{"p": 1}'},
+                    }
+                ],
+            }
+        ]
+        result, _ = to_anthropic_messages(msgs)
+        blocks = result[0]["content"]
+        types = [b["type"] for b in blocks]
+        assert "text" in types
+        assert "tool_use" in types
+
+    def test_consecutive_tool_results_merged(self) -> None:
+        """Multiple role:tool messages should merge into one user message."""
+        msgs = [
+            {"role": "tool", "tool_call_id": "c1", "content": "r1"},
+            {"role": "tool", "tool_call_id": "c2", "content": "r2"},
+        ]
+        result, _ = to_anthropic_messages(msgs)
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+        assert len(result[0]["content"]) == 2
+
+    def test_user_string_content_passthrough(self) -> None:
+        msgs = [{"role": "user", "content": "hello"}]
+        result, _ = to_anthropic_messages(msgs)
+        assert result[0] == {"role": "user", "content": "hello"}
+
+    def test_user_list_content_passthrough(self) -> None:
+        blocks = [{"type": "text", "text": "hi"}, {"type": "document", "title": "d"}]
+        msgs = [{"role": "user", "content": blocks}]
+        result, _ = to_anthropic_messages(msgs)
+        assert result[0]["content"] == blocks
+
+    def test_user_image_url_converted(self) -> None:
+        msg = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "look"},
+                _openai_image_url_part("https://x.com/i.png"),
+            ],
+        }
+        result, _ = to_anthropic_messages([msg])
+        blocks = result[0]["content"]
+        assert blocks[1]["type"] == "image"
+
+    def test_unknown_role_passthrough(self) -> None:
+        msgs = [{"role": "function", "content": "legacy"}]
+        result, _ = to_anthropic_messages(msgs)
+        assert result[0] == msgs[0]
+
+    def test_full_conversation(self) -> None:
+        msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Fix main.py"},
+            {
+                "role": "assistant",
+                "content": "Reading file.",
+                "tool_calls": [
+                    {
+                        "id": "c1",
+                        "type": "function",
+                        "function": {
+                            "name": "Read",
+                            "arguments": '{"file_path": "main.py"}',
+                        },
+                    }
+                ],
+            },
+            {"role": "tool", "tool_call_id": "c1", "content": "def main(): pass"},
+            {"role": "assistant", "content": "Done."},
+        ]
+        result, system = to_anthropic_messages(msgs)
+        assert system == "You are helpful."
+        assert result[0] == {"role": "user", "content": "Fix main.py"}
+        assert result[1]["role"] == "assistant"
+        assert any(b["type"] == "tool_use" for b in result[1]["content"])
+        assert result[2]["role"] == "user"
+        assert result[2]["content"][0]["type"] == "tool_result"
+        assert result[3]["role"] == "assistant"
+
+
+# ============================================================
+# to_anthropic_tools (canonical -> Anthropic)
+# ============================================================
+
+
+class TestToAnthropicTools:
+    def test_openai_tool_converted(self) -> None:
+        tools = [_openai_tool_def("Bash", "Run a command")]
+        result = to_anthropic_tools(tools)
+        assert len(result) == 1
+        t = result[0]
+        assert t["name"] == "Bash"
+        assert t["description"] == "Run a command"
+        assert "input_schema" in t
+
+    def test_anthropic_tool_passthrough(self) -> None:
+        tools = [_anthropic_tool_def("Read")]
+        result = to_anthropic_tools(tools)
+        assert result == tools
+
+    def test_missing_description_omitted(self) -> None:
+        tool = {
+            "type": "function",
+            "function": {
+                "name": "X",
+                "parameters": {"type": "object"},
+            },
+        }
+        result = to_anthropic_tools([tool])
+        assert result[0]["name"] == "X"
+        assert "description" not in result[0]
+
+    def test_missing_parameters_gets_empty_schema(self) -> None:
+        tool = {
+            "type": "function",
+            "function": {"name": "Y", "description": "desc"},
+        }
+        result = to_anthropic_tools([tool])
+        assert result[0]["input_schema"] == {}
+
+    def test_unknown_format_passthrough(self) -> None:
+        tool = {"something": "unknown"}
+        assert to_anthropic_tools([tool]) == [tool]
+
+    def test_mixed_formats(self) -> None:
+        tools = [
+            _openai_tool_def("A"),
+            _anthropic_tool_def("B"),
+        ]
+        result = to_anthropic_tools(tools)
+        assert result[0]["name"] == "A"
+        assert "input_schema" in result[0]
+        assert result[1] == tools[1]
+
+    def test_versioned_tool_passthrough(self) -> None:
+        tools = [_versioned_tool_def("computer_20251124")]
+        result = to_anthropic_tools(tools)
+        assert result == tools
+
+    def test_cache_control_restored(self) -> None:
+        tool = _openai_tool_def("X")
+        tool["_meta"] = {"cache_control": {"type": "ephemeral"}}
+        result = to_anthropic_tools([tool])
+        assert result[0]["cache_control"] == {"type": "ephemeral"}
+
+
+# ============================================================
+# Round-trip: Anthropic -> canonical -> Anthropic
+# ============================================================
+
+
+class TestRoundTrip:
+    def test_text_only_assistant_round_trip(self) -> None:
+        original = [
+            {"role": "user", "content": "hi"},
+            {"role": "assistant", "content": [_text_block("hello")]},
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        restored, _ = to_anthropic_messages(canonical)
+
+        assert restored[0] == {"role": "user", "content": "hi"}
+        assert restored[1]["role"] == "assistant"
+        assert any(
+            b.get("type") == "text" and b.get("text") == "hello"
+            for b in restored[1]["content"]
+        )
+
+    def test_tool_use_round_trip(self) -> None:
+        original = [
+            {
+                "role": "assistant",
+                "content": [
+                    _text_block("Let me check."),
+                    _tool_use_block("t1", "Read", {"file_path": "/tmp/f"}),
+                ],
+            },
+            {
+                "role": "user",
+                "content": [_tool_result_block("t1", "file data")],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+
+        assert canonical[0]["role"] == "assistant"
+        assert len(canonical[0]["tool_calls"]) == 1
+        assert canonical[1]["role"] == "tool"
+
+        restored, _ = to_anthropic_messages(canonical)
+
+        assistant_blocks = restored[0]["content"]
+        tool_use_blocks = [b for b in assistant_blocks if b.get("type") == "tool_use"]
+        assert len(tool_use_blocks) == 1
+        assert tool_use_blocks[0]["name"] == "Read"
+        assert tool_use_blocks[0]["input"] == {"file_path": "/tmp/f"}
+
+        user_msg = restored[1]
+        assert user_msg["role"] == "user"
+        tr_blocks = [b for b in user_msg["content"] if b.get("type") == "tool_result"]
+        assert len(tr_blocks) == 1
+        assert tr_blocks[0]["tool_use_id"] == "t1"
+
+    def test_thinking_blocks_round_trip(self) -> None:
+        original = [
+            {
+                "role": "assistant",
+                "content": [_thinking_block("deep thought"), _text_block("42")],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[0]["thinking_blocks"][0]["type"] == "thinking"
+
+        restored, _ = to_anthropic_messages(canonical)
+        blocks = restored[0]["content"]
+        thinking = [b for b in blocks if b.get("type") == "thinking"]
+        text = [b for b in blocks if b.get("type") == "text"]
+        assert len(thinking) == 1
+        assert thinking[0]["thinking"] == "deep thought"
+        assert len(text) == 1
+        assert text[0]["text"] == "42"
+
+    def test_redacted_thinking_round_trip(self) -> None:
+        original = [
+            {
+                "role": "assistant",
+                "content": [
+                    _redacted_thinking_block("encrypted_data"),
+                    _text_block("answer"),
+                ],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        restored, _ = to_anthropic_messages(canonical)
+        blocks = restored[0]["content"]
+        redacted = [b for b in blocks if b.get("type") == "redacted_thinking"]
+        assert len(redacted) == 1
+        assert redacted[0]["data"] == "encrypted_data"
+
+    def test_system_message_round_trip(self) -> None:
+        original = [
+            {"role": "system", "content": [_text_block("Be helpful.")]},
+            {"role": "user", "content": "hi"},
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[0] == {"role": "system", "content": "Be helpful."}
+
+        restored, system = to_anthropic_messages(canonical)
+        assert system == "Be helpful."
+        assert restored[0]["role"] == "user"
+
+    def test_tools_round_trip(self) -> None:
+        anthropic_tools = [
+            _anthropic_tool_def("Bash", "Run command", {"cmd": {"type": "string"}}),
+        ]
+        _, canonical_tools = normalize_messages(
+            [{"role": "user", "content": "hi"}],
+            tools=anthropic_tools,
+            provider="anthropic",
+        )
+        assert canonical_tools is not None
+        assert canonical_tools[0]["type"] == "function"
+
+        restored_tools = to_anthropic_tools(canonical_tools)
+        assert restored_tools[0]["name"] == "Bash"
+        assert restored_tools[0]["description"] == "Run command"
+        assert restored_tools[0]["input_schema"]["properties"] == {
+            "cmd": {"type": "string"}
+        }
+
+    def test_multiple_tool_results_round_trip(self) -> None:
+        original = [
+            {
+                "role": "assistant",
+                "content": [
+                    _tool_use_block("t1", "A", {"x": 1}),
+                    _tool_use_block("t2", "B", {"y": 2}),
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    _tool_result_block("t1", "res1"),
+                    _tool_result_block("t2", "res2"),
+                ],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[1]["role"] == "tool"
+        assert canonical[2]["role"] == "tool"
+
+        restored, _ = to_anthropic_messages(canonical)
+        user_msg = restored[1]
+        assert user_msg["role"] == "user"
+        tr_ids = [
+            b["tool_use_id"]
+            for b in user_msg["content"]
+            if b.get("type") == "tool_result"
+        ]
+        assert tr_ids == ["t1", "t2"]
+
+    def test_image_round_trip_base64(self) -> None:
+        original = [
+            {
+                "role": "user",
+                "content": [
+                    _text_block("Look at this"),
+                    _anthropic_image_block("base64", "image/png", "imgdata"),
+                ],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        # Canonical should have image_url
+        parts = canonical[0]["content"]
+        assert parts[1]["type"] == "image_url"
+
+        restored, _ = to_anthropic_messages(canonical)
+        blocks = restored[0]["content"]
+        img = [b for b in blocks if b.get("type") == "image"]
+        assert len(img) == 1
+        assert img[0]["source"]["type"] == "base64"
+        assert img[0]["source"]["media_type"] == "image/png"
+        assert img[0]["source"]["data"] == "imgdata"
+
+    def test_image_round_trip_url(self) -> None:
+        original = [
+            {
+                "role": "user",
+                "content": [
+                    _anthropic_image_block("url", url="https://example.com/i.jpg")
+                ],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        restored, _ = to_anthropic_messages(canonical)
+        block = restored[0]["content"][0]
+        assert block["type"] == "image"
+        assert block["source"]["url"] == "https://example.com/i.jpg"
+
+    def test_server_tool_use_round_trip(self) -> None:
+        original = [
+            {
+                "role": "assistant",
+                "content": [
+                    _server_tool_use_block("srv1", "web_search", {"query": "test"}),
+                    _web_search_tool_result_block(),
+                    _text_block("Found it."),
+                ],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[0]["tool_calls"][0]["_meta"]["server_tool"] is True
+        assert len(canonical[0]["_meta"]["passthrough_blocks"]) == 1
+
+        restored, _ = to_anthropic_messages(canonical)
+        blocks = restored[0]["content"]
+        types = [b["type"] for b in blocks]
+        assert "server_tool_use" in types
+        assert "web_search_tool_result" in types
+        assert "text" in types
+
+    def test_is_error_round_trip(self) -> None:
+        original = [
+            {
+                "role": "user",
+                "content": [_tool_result_block("t1", "Error!", is_error=True)],
+            },
+        ]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[0]["_meta"]["is_error"] is True
+
+        restored, _ = to_anthropic_messages(canonical)
+        block = restored[0]["content"][0]
+        assert block["is_error"] is True
+
+    def test_caller_round_trip(self) -> None:
+        caller = {"type": "code_execution_20260120", "tool_id": "srv1"}
+        block = _tool_use_block("t1", "Read", {"path": "f"})
+        block["caller"] = caller
+        original = [{"role": "assistant", "content": [block]}]
+        canonical, _ = normalize_messages(original, provider="anthropic")
+        assert canonical[0]["tool_calls"][0]["_meta"]["caller"] == caller
+
restored, _ = to_anthropic_messages(canonical) + assert restored[0]["content"][0]["caller"] == caller + + def test_cache_control_on_tool_def_round_trip(self) -> None: + tool = _anthropic_tool_def("X") + tool["cache_control"] = {"type": "ephemeral"} + _, canonical_tools = normalize_messages( + [{"role": "user", "content": "hi"}], tools=[tool], provider="anthropic" + ) + assert canonical_tools[0]["_meta"]["cache_control"] == {"type": "ephemeral"} + + restored = to_anthropic_tools(canonical_tools) + assert restored[0]["cache_control"] == {"type": "ephemeral"} + + def test_versioned_tool_round_trip(self) -> None: + tools = [_versioned_tool_def("computer_20251124"), _anthropic_tool_def("Read")] + _, canonical_tools = normalize_messages( + [{"role": "user", "content": "hi"}], tools=tools, provider="anthropic" + ) + # Versioned tool passes through, regular tool is converted + assert canonical_tools[0]["type"] == "computer_20251124" + assert canonical_tools[1]["type"] == "function" + + restored = to_anthropic_tools(canonical_tools) + assert restored[0]["type"] == "computer_20251124" + assert restored[1]["name"] == "Read" + + def test_document_block_round_trip(self) -> None: + doc = _document_block("my doc") + original = [{"role": "user", "content": [_text_block("context"), doc]}] + canonical, _ = normalize_messages(original, provider="anthropic") + # Document preserved as content part + parts = canonical[0]["content"] + assert any(p.get("type") == "document" for p in parts) + + restored, _ = to_anthropic_messages(canonical) + blocks = restored[0]["content"] + assert any(b.get("type") == "document" for b in blocks) + + def test_full_multi_turn_conversation_round_trip(self) -> None: + """Realistic multi-turn Claude Code conversation with extended features.""" + original = [ + { + "role": "system", + "content": [ + _text_block("x-anthropic-billing-header: proj-123"), + _text_block("You are a coding assistant."), + ], + }, + {"role": "user", "content": "Fix the bug"}, + { + 
"role": "assistant", + "content": [ + _thinking_block("Need to read the file first"), + _redacted_thinking_block("classified"), + _text_block("I will read the file."), + _tool_use_block("t1", "Read", {"file_path": "bug.py"}), + ], + }, + { + "role": "user", + "content": [_tool_result_block("t1", "def buggy(): return None")], + }, + { + "role": "assistant", + "content": [_text_block("Fixed!")], + }, + ] + canonical, _ = normalize_messages(original, provider="anthropic") + restored, system = to_anthropic_messages(canonical) + + assert system == "You are a coding assistant." + assert restored[0] == {"role": "user", "content": "Fix the bug"} + + asst = restored[1] + assert asst["role"] == "assistant" + types = [b["type"] for b in asst["content"]] + assert "thinking" in types + assert "redacted_thinking" in types + assert "text" in types + assert "tool_use" in types + + assert restored[2]["role"] == "user" + assert restored[2]["content"][0]["type"] == "tool_result" + + assert restored[3]["role"] == "assistant" + + +# ============================================================ +# _sanitize_tool_id +# ============================================================ + + +class TestSanitizeToolId: + def test_valid_id_unchanged(self) -> None: + assert ( + _sanitize_tool_id("toolu_01A09q90qw90lq917835lq9") + == "toolu_01A09q90qw90lq917835lq9" + ) + + def test_empty_string_unchanged(self) -> None: + assert _sanitize_tool_id("") == "" + + def test_server_tool_id_unchanged(self) -> None: + assert _sanitize_tool_id("srvtoolu_01abc") == "srvtoolu_01abc" + + def test_replaces_invalid_characters(self) -> None: + assert _sanitize_tool_id("tool.call:123!@#") == "tool_call_123___" + + def test_hyphens_preserved(self) -> None: + assert _sanitize_tool_id("tool-call-123") == "tool-call-123" + + def test_spaces_replaced(self) -> None: + assert _sanitize_tool_id("tool call 123") == "tool_call_123" + + +# ============================================================ +# Tool ID sanitization in 
emitters +# ============================================================ + + +class TestToolIdSanitizationInEmitters: + def test_emit_anthropic_assistant_sanitizes_tool_ids(self) -> None: + msg: dict[str, Any] = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call.123!bad", + "type": "function", + "function": {"name": "test", "arguments": "{}"}, + } + ], + } + result = _emit_anthropic_assistant(msg) + assert result["content"][0]["id"] == "call_123_bad" + + def test_emit_anthropic_tool_result_sanitizes_tool_ids(self) -> None: + msg: dict[str, Any] = { + "role": "tool", + "tool_call_id": "call.123!bad", + "content": "result", + } + result = _emit_anthropic_tool_result(msg) + assert result["content"][0]["tool_use_id"] == "call_123_bad" + + +# ============================================================ +# Thinking block interleaving with server tools +# ============================================================ + + +class TestThinkingBlockInterleaving: + def test_interleaved_thinking_server_tool_preserves_order(self) -> None: + """Anthropic requires thinking blocks interleaved with server_tool_use + to maintain signature verification. 
Verify round-trip preserves order.""" + original = [ + {"role": "user", "content": "Search for Python docs"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "I'll search for this"}, + { + "type": "server_tool_use", + "id": "srvtoolu_01", + "name": "web_search", + "input": {"query": "Python docs"}, + }, + { + "type": "web_search_tool_result", + "tool_use_id": "srvtoolu_01", + "content": [{"type": "text", "text": "Results here"}], + }, + {"type": "thinking", "thinking": "Now I'll summarize"}, + {"type": "text", "text": "Here are the results"}, + ], + }, + ] + canonical, _ = normalize_messages(original, provider="anthropic") + + # Verify block_order is preserved + asst = canonical[1] + assert "_meta" in asst and "block_order" in asst["_meta"] + order = asst["_meta"]["block_order"] + kinds = [k for k, _ in order] + assert kinds == ["thinking", "tool_call", "passthrough", "thinking", "text"] + + # Round-trip back to Anthropic + restored, _ = to_anthropic_messages(canonical) + asst_restored = restored[1] + types = [b["type"] for b in asst_restored["content"]] + assert types == [ + "thinking", + "server_tool_use", + "web_search_tool_result", + "thinking", + "text", + ] + + def test_no_block_order_when_no_server_tools(self) -> None: + """Regular tool_use (not server) should not get _block_order.""" + original = [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think"}, + _tool_use_block(), + ], + } + ] + canonical = _normalize_anthropic_messages(original) + assert "_meta" not in canonical[0] or "block_order" not in canonical[0].get( + "_meta", {} + ) + + def test_no_block_order_when_no_thinking(self) -> None: + """Server tools without thinking should not get _block_order.""" + original = [ + { + "role": "assistant", + "content": [ + _server_tool_use_block(), + ], + } + ] + canonical = _normalize_anthropic_messages(original) + assert "_meta" not in canonical[0] or "block_order" not in canonical[0].get( + 
"_meta", {} + ) + + +# ============================================================ +# MCP server tool support +# ============================================================ + + +class TestMcpServerTools: + def test_detect_provider_mcp_server_tool(self) -> None: + tools = [ + { + "type": "url", + "url": "https://mcp.example.com", + "name": "my_server", + "tool_configuration": {"allowed_tools": ["get_data"]}, + } + ] + assert _detect_provider([], tools) == "anthropic" + + def test_normalize_anthropic_tools_mcp_passthrough(self) -> None: + tools = [ + { + "type": "url", + "url": "https://mcp.example.com", + "name": "my_server", + } + ] + result = _normalize_anthropic_tools(tools) + assert result == tools + + def test_to_anthropic_tools_mcp_passthrough(self) -> None: + tools = [ + { + "type": "url", + "url": "https://mcp.example.com", + "name": "my_server", + } + ] + result = to_anthropic_tools(tools) + assert result == tools + + +# ============================================================ +# OpenAI passthrough content types +# ============================================================ + + +class TestOpenAIPassthroughContentTypes: + @pytest.mark.parametrize( + "content_type", + [ + param("input_audio", id="input_audio"), + param("audio_url", id="audio_url"), + param("guarded_text", id="guarded_text"), + param("video_url", id="video_url"), + param("file", id="file"), + ], + ) + def test_passthrough_content_types_in_user_message(self, content_type: str) -> None: + """Content types valid in OpenAI but not Anthropic should pass through.""" + block = {"type": content_type, "data": "test_data"} + msg: dict[str, Any] = {"role": "user", "content": [block]} + result = _normalize_anthropic_user(msg) + assert len(result) == 1 + assert result[0]["role"] == "user" + content = result[0]["content"] + assert isinstance(content, list) + assert block in content + + def test_file_content_alongside_text(self) -> None: + msg: dict[str, Any] = { + "role": "user", + "content": [ + 
{"type": "text", "text": "Analyze this file"}, + {"type": "file", "file": {"file_id": "abc123"}}, + ], + } + result = _normalize_anthropic_user(msg) + assert len(result) == 1 + content = result[0]["content"] + assert isinstance(content, list) + assert len(content) == 2 + assert content[0]["type"] == "text" + assert content[1]["type"] == "file" + + +# ============================================================ +# _tool_call_to_anthropic_block +# ============================================================ + + +class TestToolCallToAnthropicBlock: + def test_regular_tool_call(self) -> None: + tc: dict[str, Any] = { + "id": "toolu_01", + "type": "function", + "function": {"name": "Read", "arguments": '{"path": "/tmp"}'}, + } + block = _tool_call_to_anthropic_block(tc) + assert block["type"] == "tool_use" + assert block["id"] == "toolu_01" + assert block["name"] == "Read" + assert block["input"] == {"path": "/tmp"} + + def test_server_tool_call(self) -> None: + tc: dict[str, Any] = { + "id": "srvtoolu_01", + "type": "function", + "function": {"name": "web_search", "arguments": '{"query": "test"}'}, + "_meta": {"server_tool": True}, + } + block = _tool_call_to_anthropic_block(tc) + assert block["type"] == "server_tool_use" + + def test_caller_preserved(self) -> None: + tc: dict[str, Any] = { + "id": "toolu_01", + "type": "function", + "function": {"name": "Read", "arguments": "{}"}, + "_meta": {"caller": {"type": "direct"}}, + } + block = _tool_call_to_anthropic_block(tc) + assert block["caller"] == {"type": "direct"} + + def test_sanitizes_invalid_id(self) -> None: + tc: dict[str, Any] = { + "id": "call.123!bad", + "type": "function", + "function": {"name": "test", "arguments": "{}"}, + } + block = _tool_call_to_anthropic_block(tc) + assert block["id"] == "call_123_bad" diff --git a/tests/unit/endpoints/test_openai_chat_completions.py b/tests/unit/endpoints/test_openai_chat_completions.py index 662d419d3..49f4e3d8f 100644 --- 
a/tests/unit/endpoints/test_openai_chat_completions.py +++ b/tests/unit/endpoints/test_openai_chat_completions.py @@ -2,16 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from aiperf.common.enums import ModelSelectionStrategy -from aiperf.common.models.model_endpoint_info import ( - EndpointInfo, - ModelEndpointInfo, - ModelInfo, - ModelListInfo, -) +from aiperf.common.models import Text, Turn from aiperf.endpoints.openai_chat import ChatEndpoint from aiperf.plugin.enums import EndpointType -from tests.unit.endpoints.conftest import create_request_info +from tests.unit.endpoints.conftest import ( + create_endpoint_with_mock_transport, + create_model_endpoint, + create_request_info, +) class TestChatEndpoint: @@ -19,21 +17,13 @@ class TestChatEndpoint: @pytest.fixture def model_endpoint(self): - return ModelEndpointInfo( - models=ModelListInfo( - models=[ModelInfo(name="test-model")], - model_selection_strategy=ModelSelectionStrategy.RANDOM, - ), - endpoint=EndpointInfo( - type=EndpointType.CHAT, - base_url="http://localhost:8000", - custom_endpoint="/v1/chat/completions", - api_key="test-api-key", - ), - ) + return create_model_endpoint(EndpointType.CHAT) - def test_format_payload_basic(self, model_endpoint, sample_conversations): - endpoint = ChatEndpoint(model_endpoint) + @pytest.fixture + def endpoint(self, model_endpoint): + return ChatEndpoint(model_endpoint) + + def test_format_payload_basic(self, endpoint, model_endpoint, sample_conversations): turn = sample_conversations["session_1"].turns[0] turns = [turn] request_info = create_request_info(model_endpoint=model_endpoint, turns=turns) @@ -52,9 +42,8 @@ def test_format_payload_basic(self, model_endpoint, sample_conversations): assert payload == expected_payload def test_format_payload_with_max_tokens_and_streaming( - self, model_endpoint, sample_conversations + self, endpoint, model_endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turn = 
sample_conversations["session_1"].turns[0] turns = [turn] turns[0].max_tokens = 42 @@ -76,9 +65,8 @@ def test_format_payload_with_max_tokens_and_streaming( assert payload == expected_payload def test_format_payload_with_extra_options( - self, model_endpoint, sample_conversations + self, endpoint, model_endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turn = sample_conversations["session_1"].turns[0] turns = [turn] model_endpoint.endpoint.extra = {"ignore_eos": True, "temperature": 0.7} @@ -100,9 +88,8 @@ def test_format_payload_with_extra_options( assert payload == expected_payload def test_format_payload_multiple_turns_with_text_and_image( - self, model_endpoint, sample_conversations + self, endpoint, model_endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) # Create a turn with both text and image turns = sample_conversations["session_1"].turns turns[0].images = type("ImageList", (), {})() @@ -134,8 +121,9 @@ def test_format_payload_multiple_turns_with_text_and_image( } assert payload == expected_payload - def test_format_payload_with_audio(self, model_endpoint, sample_conversations): - endpoint = ChatEndpoint(model_endpoint) + def test_format_payload_with_audio( + self, endpoint, model_endpoint, sample_conversations + ): turn = sample_conversations["session_1"].turns[0] turn.audios = [type("Audio", (), {"contents": ["mp3,ZmFrZV9hdWRpbw=="]})()] turns = [turn] @@ -162,8 +150,7 @@ def test_format_payload_with_audio(self, model_endpoint, sample_conversations): } assert payload == expected_payload - def test_create_messages_hotfix(self, model_endpoint, sample_conversations): - endpoint = ChatEndpoint(model_endpoint) + def test_create_messages_hotfix(self, endpoint, sample_conversations): turn = sample_conversations["session_1"].turns[0] turns = [turn] messages = endpoint._create_messages(turns, None, None) @@ -171,10 +158,7 @@ def test_create_messages_hotfix(self, model_endpoint, sample_conversations): assert 
messages[0]["name"] == turn.texts[0].name assert messages[0]["content"] == turn.texts[0].contents[0] - def test_create_messages_with_empty_content( - self, model_endpoint, sample_conversations - ): - endpoint = ChatEndpoint(model_endpoint) + def test_create_messages_with_empty_content(self, endpoint, sample_conversations): turn = sample_conversations["session_1"].turns[0] turn.texts[0].contents = [""] turns = [turn] @@ -183,10 +167,7 @@ def test_create_messages_with_empty_content( assert messages[0]["name"] == turn.texts[0].name assert messages[0]["content"] == "" - def test_create_messages_audio_format_error( - self, model_endpoint, sample_conversations - ): - endpoint = ChatEndpoint(model_endpoint) + def test_create_messages_audio_format_error(self, endpoint, sample_conversations): turn = sample_conversations["session_1"].turns[0] turn.audios = [type("Audio", (), {"contents": ["not_base64_audio"]})()] turns = [turn] @@ -212,6 +193,7 @@ def test_create_messages_audio_format_error( ) # fmt: skip def test_stream_options_auto_configuration( self, + endpoint, model_endpoint, sample_conversations, streaming, @@ -220,7 +202,6 @@ def test_stream_options_auto_configuration( expected_stream_options, ): """Verify stream_options.include_usage is automatically configured based on flags and user settings.""" - endpoint = ChatEndpoint(model_endpoint) turn = sample_conversations["session_1"].turns[0] turns = [turn] model_endpoint.endpoint.streaming = streaming @@ -237,10 +218,7 @@ def test_stream_options_auto_configuration( assert payload["stream_options"] == expected_stream_options endpoint._create_messages(turns, None, None) - def test_create_messages_with_system_message( - self, model_endpoint, sample_conversations - ): - endpoint = ChatEndpoint(model_endpoint) + def test_create_messages_with_system_message(self, endpoint, sample_conversations): turn = sample_conversations["session_1"].turns[0] turns = [turn] system_message = "You are a helpful AI assistant." 
@@ -254,9 +232,8 @@ def test_create_messages_with_system_message( assert messages[1]["content"] == turn.texts[0].contents[0] def test_create_messages_with_user_context_message( - self, model_endpoint, sample_conversations + self, endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turn = sample_conversations["session_1"].turns[0] turns = [turn] user_context = "The user is working on a Python project." @@ -270,9 +247,8 @@ def test_create_messages_with_user_context_message( assert messages[1]["content"] == turn.texts[0].contents[0] def test_create_messages_with_both_context_messages( - self, model_endpoint, sample_conversations + self, endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turn = sample_conversations["session_1"].turns[0] turns = [turn] system_message = "You are a helpful AI assistant." @@ -290,9 +266,8 @@ def test_create_messages_with_both_context_messages( assert messages[2]["content"] == turn.texts[0].contents[0] def test_create_messages_with_context_and_multiple_turns( - self, model_endpoint, sample_conversations + self, endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turns = sample_conversations["session_1"].turns system_message = "You are a helpful AI assistant." user_context = "The user is working on a Python project." @@ -311,9 +286,8 @@ def test_create_messages_with_context_and_multiple_turns( assert messages[3]["role"] == turns[1].role def test_format_payload_with_context_messages( - self, model_endpoint, sample_conversations + self, endpoint, model_endpoint, sample_conversations ): - endpoint = ChatEndpoint(model_endpoint) turn = sample_conversations["session_1"].turns[0] turns = [turn] system_message = "You are a helpful AI assistant." 
@@ -338,3 +312,205 @@ def test_format_payload_with_context_messages( assert payload["messages"][1]["content"] == user_context # Third message should be the turn assert payload["messages"][2]["role"] == (turn.role or "user") + + +class TestChatEndpointRawMessages: + """Tests for raw_messages verbatim replay in ChatEndpoint.""" + + @pytest.fixture + def model_endpoint(self): + return create_model_endpoint(EndpointType.CHAT) + + @pytest.fixture + def endpoint(self, model_endpoint): + return create_endpoint_with_mock_transport(ChatEndpoint, model_endpoint) + + def test_raw_messages_replaces_entire_message(self, endpoint, model_endpoint): + """Turn with raw_messages produces those exact dicts in the messages list.""" + raw = { + "role": "assistant", + "content": "hi", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "get_file", "arguments": '{"path": "a.py"}'}, + } + ], + } + turn = Turn(raw_messages=[raw], model="test-model") + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert payload["messages"] == [raw] + + def test_raw_messages_with_tool_role(self, endpoint, model_endpoint): + """Turn with raw_messages for a tool result message works.""" + raw = {"role": "tool", "tool_call_id": "call_1", "content": "file contents"} + turn = Turn(raw_messages=[raw], model="test-model") + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert payload["messages"] == [raw] + assert payload["messages"][0]["role"] == "tool" + assert payload["messages"][0]["tool_call_id"] == "call_1" + + def test_raw_messages_string_content(self, endpoint, model_endpoint): + """raw_messages with string content works.""" + turn = Turn( + raw_messages=[{"role": "user", "content": "verbatim content"}], + model="test-model", + ) + request_info = create_request_info(model_endpoint=model_endpoint, 
turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert payload["messages"][0]["role"] == "user" + assert payload["messages"][0]["content"] == "verbatim content" + + def test_raw_messages_mixed_with_normal_turns(self, endpoint, model_endpoint): + """raw_messages turns can be mixed with normal turns.""" + raw = {"role": "tool", "tool_call_id": "call_1", "content": "result"} + turns = [ + Turn(texts=[Text(contents=["Hello"])], role="user", model="test-model"), + Turn(raw_messages=[raw], model="test-model"), + Turn(texts=[Text(contents=["Thanks"])], role="user", model="test-model"), + ] + request_info = create_request_info(model_endpoint=model_endpoint, turns=turns) + + payload = endpoint.format_payload(request_info) + + assert len(payload["messages"]) == 3 + assert payload["messages"][0]["role"] == "user" + assert payload["messages"][0]["content"] == "Hello" + assert payload["messages"][1] == raw + assert payload["messages"][2]["role"] == "user" + assert payload["messages"][2]["content"] == "Thanks" + + def test_raw_tools_included_in_payload(self, endpoint, model_endpoint): + """Tool definitions from Turn.raw_tools are included in the payload.""" + tools = [ + { + "type": "function", + "function": { + "name": "get_file", + "description": "Read a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ] + turn = Turn( + texts=[Text(contents=["Read a.py"])], + role="user", + model="test-model", + raw_tools=tools, + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert payload["tools"] == tools + + def test_raw_tools_omitted_when_none(self, endpoint, model_endpoint): + """No tools key in payload when raw_tools is None.""" + turn = Turn(texts=[Text(contents=["Hello"])], role="user", model="test-model") + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + 
+ payload = endpoint.format_payload(request_info) + + assert "tools" not in payload + + def test_raw_messages_multi_message_extends_into_messages_list( + self, endpoint, model_endpoint + ): + """A single turn with 3 raw_messages expands to 3 entries in messages list.""" + turn = Turn( + raw_messages=[ + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "read_file", + "arguments": '{"path": "a.py"}', + }, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "file data"}, + ], + model="test-model", + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert len(payload["messages"]) == 3 + assert payload["messages"][0]["role"] == "user" + assert payload["messages"][1]["role"] == "assistant" + assert payload["messages"][2]["role"] == "tool" + + def test_raw_messages_empty_list_adds_nothing(self, endpoint, model_endpoint): + """Turn with raw_messages=[] contributes zero messages.""" + turns = [ + Turn(texts=[Text(contents=["Before"])], role="user", model="test-model"), + Turn(raw_messages=[], model="test-model"), + Turn(texts=[Text(contents=["After"])], role="user", model="test-model"), + ] + request_info = create_request_info(model_endpoint=model_endpoint, turns=turns) + + payload = endpoint.format_payload(request_info) + + assert len(payload["messages"]) == 2 + assert payload["messages"][0]["content"] == "Before" + assert payload["messages"][1]["content"] == "After" + + def test_raw_messages_with_system_and_user_context(self, endpoint, model_endpoint): + """System and user_context are prepended before raw_messages turns.""" + turn = Turn( + raw_messages=[ + {"role": "user", "content": "verbatim user"}, + {"role": "assistant", "content": "verbatim assistant"}, + ], + model="test-model", + ) + request_info = create_request_info( + 
model_endpoint=model_endpoint, + turns=[turn], + system_message="System prompt", + user_context_message="User context", + ) + + payload = endpoint.format_payload(request_info) + + assert len(payload["messages"]) == 4 + assert payload["messages"][0] == {"role": "system", "content": "System prompt"} + assert payload["messages"][1] == {"role": "user", "content": "User context"} + assert payload["messages"][2] == {"role": "user", "content": "verbatim user"} + assert payload["messages"][3] == { + "role": "assistant", + "content": "verbatim assistant", + } + + def test_raw_messages_takes_precedence_over_texts(self, endpoint, model_endpoint): + """When raw_messages is set, texts are ignored.""" + turn = Turn( + raw_messages=[{"role": "user", "content": "raw wins"}], + texts=[Text(contents=["should be ignored"])], + model="test-model", + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + + payload = endpoint.format_payload(request_info) + + assert payload["messages"][0]["content"] == "raw wins" diff --git a/tests/unit/timing/conftest.py b/tests/unit/timing/conftest.py index b57cf18e2..38bf6618c 100644 --- a/tests/unit/timing/conftest.py +++ b/tests/unit/timing/conftest.py @@ -143,6 +143,7 @@ def create( timing_mode: TimingMode | None = None, auto_offset_timestamps: bool = False, fixed_schedule_start_offset: int | None = None, + fixed_schedule_speedup: float | None = None, ) -> OrchestratorHarness: if schedule is not None: dataset = make_dataset_with_schedule(schedule, sampling_strategy) @@ -182,6 +183,7 @@ def create( random_seed=random_seed, auto_offset_timestamps=auto_offset_timestamps, fixed_schedule_start_offset=fixed_schedule_start_offset, + fixed_schedule_speedup=fixed_schedule_speedup, ) router = MockCreditRouter() pub = MagicMock() @@ -320,6 +322,7 @@ def make_phase_config( auto_offset_timestamps: bool = False, fixed_schedule_start_offset: int | None = None, fixed_schedule_end_offset: int | None = None, + fixed_schedule_speedup: 
     float | None = None,
     concurrency_ramp_duration_sec: float | None = None,
     prefill_concurrency_ramp_duration_sec: float | None = None,
     request_rate_ramp_duration_sec: float | None = None,
@@ -340,6 +343,7 @@ def make_phase_config(
         auto_offset_timestamps=auto_offset_timestamps,
         fixed_schedule_start_offset=fixed_schedule_start_offset,
         fixed_schedule_end_offset=fixed_schedule_end_offset,
+        fixed_schedule_speedup=fixed_schedule_speedup,
         concurrency_ramp_duration_sec=concurrency_ramp_duration_sec,
         prefill_concurrency_ramp_duration_sec=prefill_concurrency_ramp_duration_sec,
         request_rate_ramp_duration_sec=request_rate_ramp_duration_sec,
@@ -364,6 +368,7 @@ def make_timing_config(
     auto_offset_timestamps: bool = False,
     fixed_schedule_start_offset: int | None = None,
     fixed_schedule_end_offset: int | None = None,
+    fixed_schedule_speedup: float | None = None,
     phase_configs: list[CreditPhaseConfig] | None = None,
     concurrency_ramp_duration_sec: float | None = None,
     prefill_concurrency_ramp_duration_sec: float | None = None,
@@ -386,6 +391,7 @@ def make_timing_config(
         auto_offset_timestamps=auto_offset_timestamps,
         fixed_schedule_start_offset=fixed_schedule_start_offset,
         fixed_schedule_end_offset=fixed_schedule_end_offset,
+        fixed_schedule_speedup=fixed_schedule_speedup,
         concurrency_ramp_duration_sec=concurrency_ramp_duration_sec,
         prefill_concurrency_ramp_duration_sec=prefill_concurrency_ramp_duration_sec,
         request_rate_ramp_duration_sec=request_rate_ramp_duration_sec,
diff --git a/tests/unit/timing/phase/test_runner.py b/tests/unit/timing/phase/test_runner.py
index aa2561a8f..53079734d 100644
--- a/tests/unit/timing/phase/test_runner.py
+++ b/tests/unit/timing/phase/test_runner.py
@@ -59,6 +59,12 @@ async def execute_phase(self) -> None:
     async def handle_credit_return(self, credit: Credit) -> None:
         self.handle_credit_return_calls.append(credit)
 
+    def on_failed_credit(self, credit_return: object) -> None:
+        pass
+
+    def cleanup(self) -> None:
+        pass
+
 
 def mock_conc_mgr() -> MagicMock:
     m = MagicMock()
@@ -652,6 +658,134 @@ async def test_cleanup_runs_on_exception(
         mock_ramper.stop.assert_called_once()
 
 
+class TestStrategyCleanup:
+    """Verify strategy.cleanup() is called at every exit point of run()."""
+
+    async def test_cleanup_called_on_normal_completion(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        r = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+        strategy = MockStrategy()
+        with patch(
+            "aiperf.timing.phase.runner.plugins.get_class",
+            return_value=lambda **kw: strategy,
+        ):
+            r._progress.all_credits_sent_event.set()
+            r._progress.all_credits_returned_event.set()
+            await r.run(is_final_phase=True)
+        # MockStrategy.cleanup is a no-op; use a spy to verify it was called
+        strategy_spy = MagicMock(wraps=MockStrategy())
+        with patch(
+            "aiperf.timing.phase.runner.plugins.get_class",
+            return_value=lambda **kw: strategy_spy,
+        ):
+            r2 = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+            r2._progress.all_credits_sent_event.set()
+            r2._progress.all_credits_returned_event.set()
+            await r2.run(is_final_phase=True)
+        strategy_spy.cleanup.assert_called_once()
+
+    async def test_cleanup_called_on_cancellation(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        """When phase is cancelled, strategy.cleanup() is still called."""
+        strategy = MagicMock(wraps=MockStrategy())
+        r = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+        with patch(
+            "aiperf.timing.phase.runner.plugins.get_class",
+            return_value=lambda **kw: strategy,
+        ):
+            # Pre-set sent event, simulate cancellation before returns
+            r._progress.all_credits_sent_event.set()
+            r._was_cancelled = True
+            result = await r.run(is_final_phase=True)
+        strategy.cleanup.assert_called_once()
+        assert isinstance(result, CreditPhaseStats)
+
+    async def test_cleanup_called_on_exception(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        """When an exception occurs during run(), strategy.cleanup() is called before re-raise."""
+        strategy = MagicMock()
+        strategy.setup_phase = AsyncMock(side_effect=RuntimeError("boom"))
+        strategy.cleanup = MagicMock()
+        with patch(
+            "aiperf.timing.phase.runner.plugins.get_class",
+            return_value=lambda **kw: strategy,
+        ):
+            r = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+            with pytest.raises(RuntimeError, match="boom"):
+                await r.run(is_final_phase=True)
+        strategy.cleanup.assert_called_once()
+
+
+@pytest.mark.asyncio
+class TestAddConcurrencyRamper:
+    """Verify _add_concurrency_ramper creates and registers rampers correctly."""
+
+    async def test_adds_ramper_to_list(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        r = make_runner(cfg(conc=10), conv_src, pub, router, conc, cancel, cb)
+        assert len(r._rampers) == 0
+        r._add_concurrency_ramper("session", 5.0, 10, lambda v: None)
+        assert len(r._rampers) == 1
+
+    async def test_ramper_config_start_at_one(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        r = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+        r._add_concurrency_ramper("prefill", 3.0, 8, lambda v: None)
+        ramper = r._rampers[0]
+        assert ramper._config.start == 1
+        assert ramper._config.target == 8
+        assert ramper._config.duration_sec == 3.0
+
+    async def test_multiple_rampers_accumulate(
+        self,
+        conv_src: MagicMock,
+        pub: MagicMock,
+        router: MagicMock,
+        conc: MagicMock,
+        cancel: MagicMock,
+        cb: MagicMock,
+    ) -> None:
+        r = make_runner(cfg(), conv_src, pub, router, conc, cancel, cb)
+        r._add_concurrency_ramper("session", 5.0, 10, lambda v: None)
+        r._add_concurrency_ramper("prefill", 3.0, 5, lambda v: None)
+        assert len(r._rampers) == 2
+
+
 class TestFixedScheduleConfigCorrection:
     """Tests for FIXED_SCHEDULE mode config correction using actual dataset size."""
diff --git a/tests/unit/timing/strategies/test_fixed_schedule.py b/tests/unit/timing/strategies/test_fixed_schedule.py
index 16436ff8e..e0972bcc9 100644
--- a/tests/unit/timing/strategies/test_fixed_schedule.py
+++ b/tests/unit/timing/strategies/test_fixed_schedule.py
@@ -8,7 +8,7 @@ from aiperf.common.constants import MILLIS_PER_SECOND
 from aiperf.common.enums import CreditPhase
 from aiperf.common.models import ConversationMetadata, DatasetMetadata, TurnMetadata
-from aiperf.credit.structs import Credit
+from aiperf.credit.structs import Credit, TurnToSend
 from aiperf.plugin.enums import DatasetSamplingStrategy, TimingMode
 from aiperf.timing.config import CreditPhaseConfig
 from aiperf.timing.conversation_source import ConversationSource
@@ -42,6 +42,7 @@ def make_strategy(
     schedule: list[tuple[int, str]],
     auto_offset: bool = True,
     manual_offset: int | None = None,
+    speedup: float | None = None,
 ) -> tuple[FixedScheduleStrategy, MagicMock, MagicMock]:
     scheduler = MagicMock()
     scheduler.schedule_at_perf_ns = MagicMock()
@@ -66,6 +67,7 @@
         total_expected_requests=len(schedule),
         auto_offset_timestamps=auto_offset,
         fixed_schedule_start_offset=manual_offset,
+        fixed_schedule_speedup=speedup,
     )
     strategy = FixedScheduleStrategy(
         config=cfg,
@@ -330,3 +332,126 @@ async def test_single_conversation_works(self) -> None:
         await strategy.setup_phase()
         assert len(strategy._absolute_schedule) == 1
         assert strategy._absolute_schedule[0][0] == 0
+
+
+@pytest.mark.asyncio
+class TestFixedScheduleSpeedup:
+    @pytest.mark.parametrize(
+        "speedup,expected_time_scale",
+        [
+            (None, 1.0),
+            (1.0, 1.0),
+            (2.0, 0.5),
+            (0.5, 2.0),
+            (10.0, 0.1),
+        ],
+    )  # fmt: skip
+    async def test_time_scale_calculation(
+        self, speedup: float | None, expected_time_scale: float
+    ) -> None:
+        """Verify speedup correctly computes internal time scale."""
+        strategy, _, _ = make_strategy([(0, "c1")], speedup=speedup)
+        assert strategy._time_scale == pytest.approx(expected_time_scale)
+
+    async def test_speedup_scales_absolute_timestamps(self) -> None:
+        """Verify speedup scales the offset in _timestamp_to_perf_sec."""
+        strategy, _, lifecycle = make_strategy(
+            [(0, "c1"), (1000, "c2")], auto_offset=True, speedup=2.0
+        )
+        lifecycle.started_at_perf_sec = 1.0
+        await strategy.setup_phase()
+        # At 2x speed, 1000ms offset becomes 500ms offset
+        perf_sec = strategy._timestamp_to_perf_sec(1000)
+        assert perf_sec == pytest.approx(1.5)
+
+    async def test_speedup_scales_delay_ms(self) -> None:
+        """Verify speedup scales delay_ms in _dispatch_by_timing."""
+        strategy, scheduler, lifecycle = make_strategy([(0, "c1")], speedup=4.0)
+        lifecycle.started_at_perf_sec = 1.0
+        await strategy.setup_phase()
+        turn = TurnToSend(
+            conversation_id="c1",
+            x_correlation_id="corr-c1",
+            turn_index=1,
+            num_turns=2,
+        )
+        # Call _dispatch_by_timing directly with delay_ms only
+        strategy._dispatch_by_timing(turn, timestamp_ms=None, delay_ms=200)
+        # delay_ms=200, at 4x speed: 200 * 0.25 / 1000 = 0.05s
+        scheduler.schedule_later.assert_called_once()
+        actual_delay = scheduler.schedule_later.call_args[0][0]
+        assert actual_delay == pytest.approx(0.05)
+
+    async def test_speedup_with_auto_offset(self) -> None:
+        """Verify speedup works with auto offset (scales relative to shifted zero)."""
+        strategy, scheduler, lifecycle = make_strategy(
+            [(1000, "c1"), (1200, "c2")], auto_offset=True, speedup=2.0
+        )
+        lifecycle.started_at_perf_sec = 1.0
+        await strategy.setup_phase()
+        await strategy.execute_phase()
+        # First call: (1000 - 1000) * 0.5 / 1000 = 0s offset
+        # Second call: (1200 - 1000) * 0.5 / 1000 = 0.1s offset
+        calls = scheduler.schedule_at_perf_sec.call_args_list
+        assert len(calls) == 2
+        base = strategy._lifecycle.started_at_perf_sec
+        assert calls[0][0][0] == pytest.approx(base + 0.0)
+        assert calls[1][0][0] == pytest.approx(base + 0.1)
+
+    async def test_speedup_with_manual_offset(self) -> None:
+        """Verify speedup works with manual offset."""
+        strategy, scheduler, lifecycle = make_strategy(
+            [(1000, "c1"), (1200, "c2")],
+            auto_offset=False,
+            manual_offset=500,
+            speedup=2.0,
+        )
+        lifecycle.started_at_perf_sec = 1.0
+        await strategy.setup_phase()
+        await strategy.execute_phase()
+        # First call: (1000 - 500) * 0.5 / 1000 = 0.25s offset
+        # Second call: (1200 - 500) * 0.5 / 1000 = 0.35s offset
+        calls = scheduler.schedule_at_perf_sec.call_args_list
+        base = strategy._lifecycle.started_at_perf_sec
+        assert calls[0][0][0] == pytest.approx(base + 0.25)
+        assert calls[1][0][0] == pytest.approx(base + 0.35)
+
+    async def test_no_speedup_unchanged(self) -> None:
+        """Verify None speedup produces same results as no speedup."""
+        strategy_none, sched_none, lc_none = make_strategy(
+            [(0, "c1"), (100, "c2")], speedup=None
+        )
+        strategy_one, sched_one, lc_one = make_strategy(
+            [(0, "c1"), (100, "c2")], speedup=1.0
+        )
+        lc_none.started_at_perf_sec = 1.0
+        lc_one.started_at_perf_sec = 1.0
+        await strategy_none.setup_phase()
+        await strategy_one.setup_phase()
+        await strategy_none.execute_phase()
+        await strategy_one.execute_phase()
+        calls_none = [c[0][0] for c in sched_none.schedule_at_perf_sec.call_args_list]
+        calls_one = [c[0][0] for c in sched_one.schedule_at_perf_sec.call_args_list]
+        assert calls_none == pytest.approx(calls_one)
+
+    @pytest.mark.parametrize(
+        "speedup,expected_duration_sec",
+        [
+            (None, 0.2),  # No speedup: 200ms range = 0.2s
+            (1.0, 0.2),  # 1x: same as no speedup
+            (2.0, 0.1),  # 2x: 200ms / 2 = 0.1s
+            (4.0, 0.05),  # 4x: 200ms / 4 = 0.05s
+            (0.5, 0.4),  # 0.5x: 200ms / 0.5 = 0.4s
+        ],
+    )  # fmt: skip
+    async def test_execution_timing_with_speedup(
+        self, create_orchestrator_harness, time_traveler, speedup, expected_duration_sec
+    ) -> None:
+        """Verify speedup scales virtual execution time proportionally."""
+        schedule = [(0, "c1"), (100, "c2"), (200, "c3")]
+        h: OrchestratorHarness = create_orchestrator_harness(
+            schedule=schedule, fixed_schedule_speedup=speedup
+        )
+        with time_traveler.travels_for(expected_duration_sec, tolerance=0.05):
+            await h.run_with_auto_return()
+        assert len(h.sent_credits) == 3
diff --git a/tests/unit/timing/strategies/test_subagent_mixin.py b/tests/unit/timing/strategies/test_subagent_mixin.py
new file mode 100644
index 000000000..6911b905a
--- /dev/null
+++ b/tests/unit/timing/strategies/test_subagent_mixin.py
@@ -0,0 +1,152 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for SubagentMixin lifecycle delegation."""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from aiperf.common.enums import CreditPhase
+from aiperf.credit.messages import CreditReturn
+from aiperf.credit.structs import Credit
+from aiperf.timing.strategies.subagent_mixin import SubagentMixin
+
+
+def make_credit(
+    *,
+    agent_depth: int = 0,
+    turn_index: int = 0,
+    num_turns: int = 2,
+) -> Credit:
+    return Credit(
+        id=1,
+        phase=CreditPhase.PROFILING,
+        conversation_id="conv-1",
+        x_correlation_id="corr-1",
+        turn_index=turn_index,
+        num_turns=num_turns,
+        issued_at_ns=1000,
+        agent_depth=agent_depth,
+    )
+
+
+def make_credit_return(
+    credit: Credit,
+    *,
+    error: str | None = None,
+    cancelled: bool = False,
+) -> CreditReturn:
+    return CreditReturn(credit=credit, error=error, cancelled=cancelled)
+
+
+def make_orchestrator() -> MagicMock:
+    o = MagicMock()
+    o.set_dispatch = MagicMock()
+    o.terminate_child = MagicMock()
+    o.cleanup = MagicMock()
+    return o
+
+
+class ConcreteSubagentMixin(SubagentMixin):
+    """Minimal concrete class for testing the mixin."""
+
+    def _dispatch_turn(self, turn: object) -> None:
+        pass
+
+
+class TestSubagentMixinInit:
+    def test_init_subagents_none_sets_none(self) -> None:
+        obj = ConcreteSubagentMixin()
+        obj._init_subagents(None)
+        assert obj._subagents is None
+
+    def test_init_subagents_calls_set_dispatch(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        orch.set_dispatch.assert_called_once_with(obj._dispatch_turn)
+
+    def test_init_subagents_stores_orchestrator(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        assert obj._subagents is orch
+
+
+class TestSubagentMixinOnFailedCredit:
+    def test_no_subagents_does_nothing(self) -> None:
+        obj = ConcreteSubagentMixin()
+        obj._init_subagents(None)
+        credit = make_credit(agent_depth=1, turn_index=0, num_turns=2)
+        cr = make_credit_return(credit, error="boom")
+        obj.on_failed_credit(cr)  # should not raise
+
+    def test_depth_zero_does_not_terminate(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        credit = make_credit(agent_depth=0, turn_index=0, num_turns=2)
+        cr = make_credit_return(credit, error="boom")
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_not_called()
+
+    def test_final_turn_does_not_terminate(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        credit = make_credit(agent_depth=1, turn_index=1, num_turns=2)
+        cr = make_credit_return(credit, error="boom")
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_not_called()
+
+    def test_child_non_final_with_error_terminates(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        credit = make_credit(agent_depth=1, turn_index=0, num_turns=2)
+        cr = make_credit_return(credit, error="timeout")
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_called_once_with(credit)
+
+    def test_child_non_final_cancelled_terminates(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        credit = make_credit(agent_depth=2, turn_index=0, num_turns=3)
+        cr = make_credit_return(credit, cancelled=True)
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_called_once_with(credit)
+
+    def test_child_depth_greater_than_one_terminates(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        credit = make_credit(agent_depth=5, turn_index=1, num_turns=4)
+        cr = make_credit_return(credit, error="err")
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_called_once_with(credit)
+
+    @pytest.mark.parametrize("num_turns", [1, 3, 10])
+    def test_only_final_turn_skipped(self, num_turns: int) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        last = num_turns - 1
+        credit = make_credit(agent_depth=1, turn_index=last, num_turns=num_turns)
+        cr = make_credit_return(credit, error="err")
+        obj.on_failed_credit(cr)
+        orch.terminate_child.assert_not_called()
+
+
+class TestSubagentMixinCleanup:
+    def test_cleanup_no_subagents_does_nothing(self) -> None:
+        obj = ConcreteSubagentMixin()
+        obj._init_subagents(None)
+        obj.cleanup()  # should not raise
+
+    def test_cleanup_delegates_to_orchestrator(self) -> None:
+        obj = ConcreteSubagentMixin()
+        orch = make_orchestrator()
+        obj._init_subagents(orch)
+        obj.cleanup()
+        orch.cleanup.assert_called_once()
diff --git a/tests/unit/timing/test_subagent_manager.py b/tests/unit/timing/test_subagent_manager.py
new file mode 100644
index 000000000..78a75c8b8
--- /dev/null
+++ b/tests/unit/timing/test_subagent_manager.py
@@ -0,0 +1,530 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for SubagentOrchestrator: spawn/join, child dispatch, metrics, cleanup.
+
+Tests the orchestrator component directly and via strategy integration.
+Uses the composition API: intercept(), terminate_child(), cleanup(), get_stats().
+"""
+
+from unittest.mock import AsyncMock, MagicMock
+
+from aiperf.common.enums import CreditPhase, PrerequisiteKind
+from aiperf.common.models import (
+    ConversationMetadata,
+    DatasetMetadata,
+    SubagentSpawnInfo,
+    TurnMetadata,
+    TurnPrerequisite,
+)
+from aiperf.credit.structs import Credit, TurnToSend
+from aiperf.plugin.enums import DatasetSamplingStrategy
+from aiperf.timing.conversation_source import ConversationSource
+from aiperf.timing.subagent_orchestrator import SubagentOrchestrator
+from tests.unit.timing.conftest import make_sampler
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+def _make_credit(
+    *,
+    conv_id: str = "conv_0",
+    corr_id: str = "xcorr-1",
+    turn_index: int = 0,
+    num_turns: int = 5,
+    agent_depth: int = 0,
+) -> Credit:
+    return Credit(
+        id=1,
+        phase=CreditPhase.PROFILING,
+        conversation_id=conv_id,
+        x_correlation_id=corr_id,
+        turn_index=turn_index,
+        num_turns=num_turns,
+        issued_at_ns=0,
+        agent_depth=agent_depth,
+    )
+
+
+def _make_dataset_and_source(
+    *,
+    spawn_at: int = 2,
+    join_at: int | None = None,
+    num_children: int = 2,
+    child_turns: int = 3,
+    is_background: bool = False,
+) -> tuple[ConversationSource, list[str]]:
+    join_at = spawn_at + 1 if join_at is None else join_at
+    child_conv_ids = [f"conv_0_s0_c{ci}" for ci in range(num_children)]
+    spawn = SubagentSpawnInfo(
+        spawn_id="s0",
+        child_conversation_ids=child_conv_ids,
+        is_background=is_background,
+    )
+
+    parent_turns = []
+    for i in range(6):
+        spawn_ids = ["s0"] if i == spawn_at else []
+        prereqs = []
+        if i == join_at and not is_background:
+            prereqs = [
+                TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0")
+            ]
+        parent_turns.append(
+            TurnMetadata(
+                delay_ms=200.0 if i > 0 else None,
+                input_tokens=500 + i * 100,
+                subagent_spawn_ids=spawn_ids,
+                prerequisites=prereqs,
+            )
+        )
+
+    convs = [
+        ConversationMetadata(
+            conversation_id="conv_0",
+            turns=parent_turns,
+            subagent_spawns=[spawn],
+        )
+    ]
+    for child_id in child_conv_ids:
+        convs.append(
+            ConversationMetadata(
+                conversation_id=child_id,
+                turns=[
+                    TurnMetadata(input_tokens=300 + j * 50) for j in range(child_turns)
+                ],
+                agent_depth=1,
+            )
+        )
+
+    ds = DatasetMetadata(
+        conversations=convs,
+        sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL,
+    )
+    sampler = make_sampler(["conv_0"], ds.sampling_strategy)
+    src = ConversationSource(ds, sampler)
+    return src, child_conv_ids
+
+
+def _make_orchestrator(
+    *,
+    spawn_at: int = 2,
+    join_at: int | None = None,
+    num_children: int = 2,
+    is_background: bool = False,
+) -> tuple[SubagentOrchestrator, MagicMock, MagicMock, MagicMock, list[str]]:
+    src, child_conv_ids = _make_dataset_and_source(
+        spawn_at=spawn_at,
+        join_at=join_at,
+        num_children=num_children,
+        is_background=is_background,
+    )
+
+    issuer = MagicMock()
+    issuer.issue_credit = AsyncMock(return_value=True)
+
+    stop_checker = MagicMock()
+    stop_checker.can_send_any_turn = MagicMock(return_value=True)
+
+    scheduler = MagicMock()
+    scheduler.execute_async = MagicMock()
+    dispatched: list[TurnToSend] = []
+
+    orch = SubagentOrchestrator(
+        conversation_source=src,
+        credit_issuer=issuer,
+        stop_checker=stop_checker,
+        scheduler=scheduler,
+        dispatch_fn=lambda turn: dispatched.append(turn),
+    )
+    orch._test_dispatched = dispatched  # type: ignore[attr-defined]
+    return orch, issuer, stop_checker, scheduler, child_conv_ids
+
+
+def _spawn_children(orch):
+    """Spawn blocking children via intercept and return child corr_ids."""
+    credit = _make_credit(
+        conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+    )
+    orch.intercept(credit)
+    child_corr_ids = list(orch._child_to_gate.keys())
+    return credit, child_corr_ids
+
+
+# =============================================================================
+# Spawn/join state machine
+# =============================================================================
+
+
+class TestSubagentSpawnAndJoin:
+    """Core spawn -> child dispatch -> join lifecycle."""
+
+    def test_intercept_non_spawn_turn_returns_false(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        credit = _make_credit(turn_index=0, num_turns=6)
+        assert orch.intercept(credit) is False
+
+    def test_intercept_spawn_turn_returns_true(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        assert orch.intercept(credit) is True
+
+    def test_spawn_creates_pending_gate(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        orch.intercept(credit)
+
+        assert "parent-1" in orch._gated_turns
+        gate = orch._gated_turns["parent-1"]
+        assert gate.outstanding == {"spawn_join:s0": [2, 0]}
+        assert gate.gated_turn_index == 3
+
+    def test_spawn_dispatches_children_via_scheduler(self):
+        orch, _, _, scheduler, _ = _make_orchestrator()
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        orch.intercept(credit)
+
+        assert scheduler.execute_async.call_count == 2
+        assert orch._stats.children_spawned == 2
+
+    def test_child_non_final_dispatches_next_turn(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.intercept(child_credit)
+
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        assert len(dispatched) == 1
+        assert dispatched[0].conversation_id == child_ids[0]
+        assert dispatched[0].turn_index == 1
+
+    def test_child_final_completes_gate(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        # Complete all children
+        for i, corr_id in enumerate(child_corr_ids):
+            child_credit = _make_credit(
+                conv_id=child_ids[i],
+                corr_id=corr_id,
+                turn_index=2,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.intercept(child_credit)
+
+        assert orch._stats.children_completed == 2
+        assert orch._stats.parents_resumed == 1
+        assert "parent-1" not in orch._gated_turns
+
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 1
+        assert join_turns[0].turn_index == 3
+
+    def test_background_spawn_does_not_suspend_parent(self):
+        orch, _, _, scheduler, child_ids = _make_orchestrator(is_background=True)
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        handled = orch.intercept(credit)
+
+        assert handled is False
+        assert "parent-1" not in orch._gated_turns
+        assert len(orch._child_to_gate) == 0
+
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        bg_dispatches = [d for d in dispatched if d.agent_depth == 1]
+        assert len(bg_dispatches) == 2
+        assert scheduler.execute_async.call_count == 0
+
+    def test_delayed_join_does_not_suspend_until_gated_turn(self):
+        orch, _, _, scheduler, _ = _make_orchestrator(join_at=5)
+        spawn_credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+
+        assert orch.intercept(spawn_credit) is False
+        assert "parent-1" not in orch._gated_turns
+        assert orch._future_gates["parent-1"][5].outstanding == {
+            "spawn_join:s0": [2, 0]
+        }
+        assert orch._stats.parents_suspended == 0
+        assert scheduler.execute_async.call_count == 2
+
+        mid_credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=3, num_turns=6
+        )
+        assert orch.intercept(mid_credit) is False
+        assert "parent-1" not in orch._gated_turns
+
+        pre_join_credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=4, num_turns=6
+        )
+        assert orch.intercept(pre_join_credit) is True
+        assert orch._gated_turns["parent-1"].gated_turn_index == 5
+        assert orch._stats.parents_suspended == 1
+
+    def test_delayed_join_completion_before_gate_does_not_dispatch_early(self):
+        orch, _, _, _, child_ids = _make_orchestrator(join_at=5)
+        _, child_corr_ids = _spawn_children(orch)
+
+        for i, corr_id in enumerate(child_corr_ids):
+            child_credit = _make_credit(
+                conv_id=child_ids[i],
+                corr_id=corr_id,
+                turn_index=2,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.intercept(child_credit)
+
+        assert "parent-1" not in orch._gated_turns
+        assert "parent-1" not in orch._future_gates
+        assert orch._stats.parents_suspended == 0
+        assert orch._stats.parents_resumed == 0
+
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 0
+
+    def test_delayed_join_dispatches_when_children_finish_after_block(self):
+        orch, _, _, _, child_ids = _make_orchestrator(join_at=5)
+        _, child_corr_ids = _spawn_children(orch)
+
+        mid_credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=3, num_turns=6
+        )
+        orch.intercept(mid_credit)
+
+        pre_join_credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=4, num_turns=6
+        )
+        assert orch.intercept(pre_join_credit) is True
+
+        for i, corr_id in enumerate(child_corr_ids):
+            child_credit = _make_credit(
+                conv_id=child_ids[i],
+                corr_id=corr_id,
+                turn_index=2,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.intercept(child_credit)
+
+        assert "parent-1" not in orch._gated_turns
+        assert orch._stats.parents_resumed == 1
+
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 1
+        assert join_turns[0].turn_index == 5
+
+
+# =============================================================================
+# Metrics/stats collection
+# =============================================================================
+
+
+class TestSubagentStats:
+    """Observability counters."""
+
+    def test_get_stats_returns_all_counters(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        stats = orch.get_stats()
+        assert set(stats.keys()) == {
+            "subagent_children_spawned",
+            "subagent_children_completed",
+            "subagent_children_errored",
+            "subagent_parents_suspended",
+            "subagent_parents_resumed",
+            "subagent_joins_suppressed",
+        }
+
+    def test_spawn_increments_counters(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        _spawn_children(orch)
+        assert orch._stats.children_spawned == 2
+        assert orch._stats.parents_suspended == 1
+
+    def test_error_increments_errored(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.terminate_child(child_credit)
+        assert orch._stats.children_errored == 1
+
+
+# =============================================================================
+# Cleanup
+# =============================================================================
+
+
+class TestSubagentCleanup:
+    """Cleanup clears all tracking state."""
+
+    def test_cleanup_clears_state(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        _spawn_children(orch)
+
+        orch.cleanup()
+
+        assert len(orch._gated_turns) == 0
+        assert len(orch._child_to_gate) == 0
+        assert len(orch._terminated_children) == 0
+        assert orch._cleaning_up is True
+
+    def test_intercept_after_cleanup_returns_false(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        orch.cleanup()
+
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        assert orch.intercept(credit) is False
+
+    def test_cleanup_clears_future_gates(self):
+        orch, _, _, _, _ = _make_orchestrator(join_at=5)
+        _spawn_children(orch)
+
+        assert "parent-1" in orch._future_gates
+
+        orch.cleanup()
+
+        assert len(orch._future_gates) == 0
+
+
+# =============================================================================
+# Strategy integration: subagent hooks are one-liners
+# =============================================================================
+
+
+class TestStrategyIntegration:
+    """Strategy uses intercept() in handle_credit_return."""
+
+    def test_intercept_in_handle_credit_return_pattern(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+
+        # Strategy pattern: if intercept returns True, strategy returns early
+        handled = orch.intercept(child_credit)
+        assert handled is True
+
+    def test_root_credit_passes_through(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        credit = _make_credit(turn_index=1, num_turns=6)
+        assert orch.intercept(credit) is False
+
+    def test_terminate_child_dispatches_join_when_last(self):
+        orch, _, _, _, child_ids = _make_orchestrator(num_children=1)
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.terminate_child(child_credit)
+
+        assert orch._stats.children_errored == 1
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 1
+        assert join_turns[0].turn_index == 3
+
+
+# =============================================================================
+# Turn-0 background pre-dispatch
+# =============================================================================
+
+
+class TestTurn0BackgroundSpawns:
+    """dispatch_turn0_background_spawns pre-dispatches background children."""
+
+    def test_dispatch_turn0_background_spawns(self):
+        """Background children on turn 0 are pre-dispatched."""
+        child_conv_ids = ["conv_0_s0_c0", "conv_0_s0_c1"]
+        spawn = SubagentSpawnInfo(
+            spawn_id="s0",
+            child_conversation_ids=child_conv_ids,
+            is_background=True,
+        )
+
+        parent_turns = []
+        for i in range(4):
+            spawn_ids = ["s0"] if i == 0 else []
+            parent_turns.append(
+                TurnMetadata(
+                    input_tokens=500,
+                    subagent_spawn_ids=spawn_ids,
+                )
+            )
+
+        convs = [
+            ConversationMetadata(
+                conversation_id="conv_0",
+                turns=parent_turns,
+                subagent_spawns=[spawn],
+            )
+        ]
+        for child_id in child_conv_ids:
+            convs.append(
+                ConversationMetadata(
+                    conversation_id=child_id,
+                    turns=[TurnMetadata(input_tokens=300) for _ in range(2)],
+                    agent_depth=1,
+                )
+            )
+
+        ds = DatasetMetadata(
+            conversations=convs,
+            sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL,
+        )
+        sampler = make_sampler(["conv_0"], ds.sampling_strategy)
+        src = ConversationSource(ds, sampler)
+
+        dispatched: list[TurnToSend] = []
+        orch = SubagentOrchestrator(
+            conversation_source=src,
+            credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            scheduler=MagicMock(execute_async=MagicMock()),
+            dispatch_fn=lambda turn: dispatched.append(turn),
+        )
+
+        orch.dispatch_turn0_background_spawns()
+
+        assert len(dispatched) == 2
+        assert all(d.agent_depth == 1 for d in dispatched)
+        assert orch._stats.children_spawned == 2
diff --git a/tests/unit/timing/test_subagent_manager_adversary.py b/tests/unit/timing/test_subagent_manager_adversary.py
new file mode 100644
index 000000000..f080e766a
--- /dev/null
+++ b/tests/unit/timing/test_subagent_manager_adversary.py
@@ -0,0 +1,794 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Adversary tests for SubagentOrchestrator flaw fixes.
+
+Each test targets a specific attempt to break the terminated-children tracking,
+cleanup, depth-aware gating, or state machine invariants.
+
+Uses the composition API: intercept(), terminate_child(), cleanup(), get_stats().
+"""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from aiperf.common.enums import CreditPhase, PrerequisiteKind
+from aiperf.common.models import (
+    ConversationMetadata,
+    DatasetMetadata,
+    SubagentSpawnInfo,
+    TurnMetadata,
+    TurnPrerequisite,
+)
+from aiperf.credit.structs import Credit, TurnToSend
+from aiperf.plugin.enums import DatasetSamplingStrategy
+from aiperf.timing.conversation_source import ConversationSource
+from aiperf.timing.subagent_orchestrator import (
+    PendingTurnGate,
+    SubagentOrchestrator,
+)
+from tests.unit.timing.conftest import make_sampler
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+def _make_credit(
+    *,
+    conv_id: str = "conv_0",
+    corr_id: str = "xcorr-1",
+    turn_index: int = 0,
+    num_turns: int = 5,
+    agent_depth: int = 0,
+) -> Credit:
+    return Credit(
+        id=1,
+        phase=CreditPhase.PROFILING,
+        conversation_id=conv_id,
+        x_correlation_id=corr_id,
+        turn_index=turn_index,
+        num_turns=num_turns,
+        issued_at_ns=0,
+        agent_depth=agent_depth,
+    )
+
+
+def _make_dataset_and_source(
+    *,
+    spawn_at: int = 2,
+    num_children: int = 2,
+    child_turns: int = 3,
+    is_background: bool = False,
+) -> tuple[ConversationSource, list[str]]:
+    child_conv_ids = [f"conv_0_s0_c{ci}" for ci in range(num_children)]
+    spawn = SubagentSpawnInfo(
+        spawn_id="s0",
+        child_conversation_ids=child_conv_ids,
+        is_background=is_background,
+    )
+
+    parent_turns = []
+    for i in range(6):
+        spawn_ids = ["s0"] if i == spawn_at else []
+        prereqs = []
+        if i == spawn_at + 1 and not is_background:
+            prereqs = [
+                TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0")
+            ]
+        parent_turns.append(
+            TurnMetadata(
+                delay_ms=200.0 if i > 0 else None,
+                input_tokens=500 + i * 100,
+                subagent_spawn_ids=spawn_ids,
+                prerequisites=prereqs,
+            )
+        )
+
+    convs = [
+        ConversationMetadata(
+            conversation_id="conv_0",
+            turns=parent_turns,
+            subagent_spawns=[spawn],
+        )
+    ]
+    for child_id in child_conv_ids:
+        convs.append(
+            ConversationMetadata(
+                conversation_id=child_id,
+                turns=[
+                    TurnMetadata(input_tokens=300 + j * 50) for j in range(child_turns)
+                ],
+                agent_depth=1,
+            )
+        )
+
+    ds = DatasetMetadata(
+        conversations=convs,
+        sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL,
+    )
+    sampler = make_sampler(["conv_0"], ds.sampling_strategy)
+    src = ConversationSource(ds, sampler)
+    return src, child_conv_ids
+
+
+def _make_orchestrator(
+    *,
+    spawn_at: int = 2,
+    num_children: int = 2,
+    is_background: bool = False,
+) -> tuple[SubagentOrchestrator, MagicMock, MagicMock, MagicMock, list[str]]:
+    src, child_conv_ids = _make_dataset_and_source(
+        spawn_at=spawn_at,
+        num_children=num_children,
+        is_background=is_background,
+    )
+
+    issuer = MagicMock()
+    issuer.issue_credit = AsyncMock(return_value=True)
+
+    stop_checker = MagicMock()
+    stop_checker.can_send_any_turn = MagicMock(return_value=True)
+
+    scheduler = MagicMock()
+    scheduler.execute_async = MagicMock()
+    dispatched: list[TurnToSend] = []
+
+    orch = SubagentOrchestrator(
+        conversation_source=src,
+        credit_issuer=issuer,
+        stop_checker=stop_checker,
+        scheduler=scheduler,
+        dispatch_fn=lambda turn: dispatched.append(turn),
+    )
+    orch._test_dispatched = dispatched  # type: ignore[attr-defined]
+    return orch, issuer, stop_checker, scheduler, child_conv_ids
+
+
+def _spawn_children(orch):
+    """Spawn blocking children via intercept and return child corr_ids."""
+    credit = _make_credit(
+        conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+    )
+    orch.intercept(credit)
+    child_corr_ids = list(orch._child_to_gate.keys())
+    return credit, child_corr_ids
+
+
+# =============================================================================
+# Double-event attacks: error + cancel on same child
+# =============================================================================
+
+
+class TestDoubleTerminationRace:
+    """Try to corrupt state by erroring AND cancelling the same child."""
+
+    def test_error_then_cancel_same_child_idempotent(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+
+        orch.terminate_child(child_credit)
+        assert child_credit.x_correlation_id in orch._terminated_children
+
+        # Second terminate: child already popped from _child_to_gate
+        orch.terminate_child(child_credit)
+
+        assert orch._terminated_children == {child_credit.x_correlation_id}
+        assert orch._stats.children_errored == 1
+
+    def test_error_then_is_terminated_drains_once(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+
+        orch.terminate_child(child_credit)
+        assert orch._is_terminated(child_credit) is True
+        assert orch._is_terminated(child_credit) is False
+
+
+# =============================================================================
+# Cleanup edge cases
+# =============================================================================
+
+
+class TestCleanupBoundary:
+    """Attack cleanup with boundary conditions."""
+
+    def test_cleanup_clears_all_state(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        _spawn_children(orch)
+
+        orch.cleanup()
+
+        assert len(orch._gated_turns) == 0
+        assert len(orch._child_to_gate) == 0
+        assert len(orch._terminated_children) == 0
+
+    def test_double_cleanup_is_safe(self):
+        orch, _, _, _, _ = _make_orchestrator()
+        _spawn_children(orch)
+
+        orch.cleanup()
+        orch.cleanup()
+
+        assert len(orch._gated_turns) == 0
+        assert len(orch._child_to_gate) == 0
+
+    def test_cleanup_multiple_leaked_gates_all_abandoned(self):
+        orch, _, _, _, _ = _make_orchestrator()
+
+        for i in range(5):
+            orch._gated_turns[f"p-{i}"] = PendingTurnGate(
+                parent_conversation_id=f"conv_{i}",
+                parent_correlation_id=f"p-{i}",
+                gated_turn_index=2,
+                parent_num_turns=6,
+                parent_agent_depth=0,
+                created_at_ns=0,
+                outstanding={"spawn_join:s0": [3, i]},
+            )
+
+        orch.cleanup()
+        assert len(orch._gated_turns) == 0
+
+    def test_cleanup_with_created_at_ns_zero_does_not_crash(self):
+        orch, _, _, _, _ = _make_orchestrator()
+
+        orch._gated_turns["p-1"] = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=3,
+            parent_num_turns=6,
+            parent_agent_depth=0,
+            created_at_ns=0,
+            outstanding={"spawn_join:s0": [1, 0]},
+        )
+
+        orch.cleanup()
+
+
+# =============================================================================
+# Late arrivals after cleanup
+# =============================================================================
+
+
+class TestLateArrivalsAfterCleanup:
+    """Children completing after cleanup has cleared all tracking state."""
+
+    def test_intercept_child_after_cleanup_passes_through(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        orch.cleanup()
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=2,
+            num_turns=3,
+            agent_depth=1,
+        )
+        # After cleanup, intercept returns False (cleaning_up flag set)
+        handled = orch.intercept(child_credit)
+        assert handled is False
+
+    def test_terminate_child_after_cleanup_is_noop(self):
+        orch, _, _, _, child_ids = _make_orchestrator()
+        _, child_corr_ids = _spawn_children(orch)
+
+        orch.cleanup()
+
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=0,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.terminate_child(child_credit)
+        assert orch._stats.children_errored == 0
+
+    def
test_intercept_after_cleanup_returns_false(self): + orch, _, _, _, _ = _make_orchestrator() + orch.cleanup() + + credit = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=3, num_turns=6 + ) + handled = orch.intercept(credit) + assert handled is False + + +# ============================================================================= +# Root credits should never trigger terminated tracking +# ============================================================================= + + +class TestRootCreditTerminationGuard: + """Root credits (depth=0) must never enter terminated tracking.""" + + def test_terminate_root_credit_is_noop(self): + orch, _, _, _, _ = _make_orchestrator() + + root_credit = _make_credit( + conv_id="conv_0", + corr_id="root-1", + turn_index=1, + num_turns=5, + agent_depth=0, + ) + orch.terminate_child(root_credit) + assert "root-1" not in orch._terminated_children + assert orch._stats.children_errored == 0 + + +# ============================================================================= +# Final turn guards +# ============================================================================= + + +class TestFinalTurnTerminationGuard: + """Error/cancel on a FINAL child turn must not terminate-track.""" + + def test_terminate_final_child_is_noop(self): + orch, _, _, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + child_credit = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=2, # final turn (num_turns=3) + num_turns=3, + agent_depth=1, + ) + orch.terminate_child(child_credit) + assert child_credit.x_correlation_id not in orch._terminated_children + + +# ============================================================================= +# Untracked child: error on a child NOT in _child_to_gate +# ============================================================================= + + +class TestUntrackedChildTermination: + """Children not in tracking map (background or unknown) must not 
corrupt state.""" + + def test_terminate_untracked_child_is_noop(self): + orch, _, _, _, _ = _make_orchestrator() + + unknown_child = _make_credit( + conv_id="unknown-conv", + corr_id="unknown-corr", + turn_index=0, + num_turns=3, + agent_depth=1, + ) + orch.terminate_child(unknown_child) + assert "unknown-corr" not in orch._terminated_children + assert orch._stats.children_errored == 0 + + +# ============================================================================= +# Stop condition suppression +# ============================================================================= + + +class TestStopConditionSuppression: + """Gated turn dispatch suppressed when stop condition fires.""" + + def test_gate_suppressed_when_stop_fired(self): + orch, _, stop_checker, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + stop_checker.can_send_any_turn.return_value = False + + for i, child_corr_id in enumerate(child_corr_ids): + child_credit = _make_credit( + conv_id=child_ids[i], + corr_id=child_corr_id, + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit) + + assert orch._stats.joins_suppressed == 1 + + def test_terminate_all_children_with_stop_fired_suppresses_gate(self): + orch, _, stop_checker, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + stop_checker.can_send_any_turn.return_value = False + + for i, child_corr_id in enumerate(child_corr_ids): + child_credit = _make_credit( + conv_id=child_ids[i], + corr_id=child_corr_id, + turn_index=0, + num_turns=3, + agent_depth=1, + ) + orch.terminate_child(child_credit) + + assert orch._stats.joins_suppressed == 1 + + +# ============================================================================= +# Terminate-then-intercept races +# ============================================================================= + + +class TestTerminateThenIntercept: + """Race between terminate_child and subsequent intercept for the same child.""" + + 
def test_terminate_then_intercept_child_final(self): + """Error a child via terminate, then its final turn arrives. Gate accounting + should count the child once (from terminate), not double-count.""" + orch, _, _, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + child_credit_nonfinal = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + agent_depth=1, + ) + + # Terminate the child (non-final turn) + orch.terminate_child(child_credit_nonfinal) + assert orch._stats.children_errored == 1 + + # Now the final turn arrives for same child via intercept + child_credit_final = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit_final) + + # Child already popped from _child_to_gate by terminate, + # so final intercept does not double-count + parent_gates = list(orch._gated_turns.values()) + assert len(parent_gates) == 1 + assert ( + parent_gates[0].outstanding["spawn_join:s0"][1] == 1 + ) # only from terminate + + def test_terminate_then_non_final_intercept_suppressed(self): + """After terminate, the next non-final intercept for that child + should be suppressed (is_terminated consumed once).""" + orch, _, _, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + child_credit = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + agent_depth=1, + ) + + orch.terminate_child(child_credit) + dispatched_before = len(orch._test_dispatched) # type: ignore[attr-defined] + + # Non-final intercept: should be suppressed (is_terminated == True, consumed) + child_credit_next = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=1, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit_next) + dispatched_after = len(orch._test_dispatched) # type: ignore[attr-defined] + + # No new dispatch happened 
(is_terminated consumed the credit) + assert dispatched_after == dispatched_before + + # Second non-final intercept: is_terminated already consumed, so dispatch fires + child_credit_next2 = _make_credit( + conv_id=child_ids[0], + corr_id=child_corr_ids[0], + turn_index=1, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit_next2) + dispatched_final = len(orch._test_dispatched) # type: ignore[attr-defined] + assert dispatched_final == dispatched_after + 1 + + +# ============================================================================= +# Stop fires mid-spawn +# ============================================================================= + + +class TestStopMidSpawn: + """Stop condition fires between spawn resolution and child completion.""" + + def test_stop_fires_mid_spawn_resolution(self): + """Stop fires after children dispatched. When children complete, + gated turn suppressed because can_send_any_turn is False.""" + orch, _, stop_checker, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + dispatched_before = len(orch._test_dispatched) # type: ignore[attr-defined] + + # Stop fires after children are dispatched + stop_checker.can_send_any_turn.return_value = False + + # Children complete (final turns) + for i, child_corr_id in enumerate(child_corr_ids): + child_credit = _make_credit( + conv_id=child_ids[i], + corr_id=child_corr_id, + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit) + + # Gated turn should be suppressed (no new dispatch for parent) + assert orch._stats.joins_suppressed == 1 + parent_dispatches = [ + t + for t in orch._test_dispatched[dispatched_before:] # type: ignore[attr-defined] + if t.conversation_id == "conv_0" + ] + assert len(parent_dispatches) == 0 + + +# ============================================================================= +# Concurrent spawns on different parents +# ============================================================================= + + 
+class TestConcurrentSpawnsOnDifferentParents: + """Two parents spawn children simultaneously; each tracked independently.""" + + def test_concurrent_spawns_on_different_parents(self): + """Two parents each spawn children. Their gates tracked independently.""" + parent_ids = ["parent_A", "parent_B"] + convs = [] + all_child_ids: dict[str, list[str]] = {} + + for pid in parent_ids: + child_ids = [f"{pid}_s0_c0", f"{pid}_s0_c1"] + all_child_ids[pid] = child_ids + spawn = SubagentSpawnInfo( + spawn_id="s0", + child_conversation_ids=child_ids, + ) + turns = [] + for i in range(5): + spawn_ids = ["s0"] if i == 2 else [] + prereqs = [] + if i == 3: + prereqs = [ + TurnPrerequisite( + kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0" + ) + ] + turns.append( + TurnMetadata( + delay_ms=200.0 if i > 0 else None, + input_tokens=500, + subagent_spawn_ids=spawn_ids, + prerequisites=prereqs, + ) + ) + convs.append( + ConversationMetadata( + conversation_id=pid, + turns=turns, + subagent_spawns=[spawn], + ) + ) + for cid in child_ids: + convs.append( + ConversationMetadata( + conversation_id=cid, + turns=[TurnMetadata(input_tokens=300) for _ in range(3)], + agent_depth=1, + ) + ) + + ds = DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + sampler = make_sampler(parent_ids, ds.sampling_strategy) + src = ConversationSource(ds, sampler) + + issuer = MagicMock() + issuer.issue_credit = AsyncMock(return_value=True) + stop_checker = MagicMock() + stop_checker.can_send_any_turn = MagicMock(return_value=True) + scheduler = MagicMock() + scheduler.execute_async = MagicMock() + dispatched: list[TurnToSend] = [] + + orch = SubagentOrchestrator( + conversation_source=src, + credit_issuer=issuer, + stop_checker=stop_checker, + scheduler=scheduler, + dispatch_fn=lambda turn: dispatched.append(turn), + ) + + # Parent A spawns + credit_a = _make_credit( + conv_id="parent_A", corr_id="corr-A", turn_index=2, num_turns=5 + ) + assert 
orch.intercept(credit_a) is True + + # Parent B spawns + credit_b = _make_credit( + conv_id="parent_B", corr_id="corr-B", turn_index=2, num_turns=5 + ) + assert orch.intercept(credit_b) is True + + assert "corr-A" in orch._gated_turns + assert "corr-B" in orch._gated_turns + assert orch._gated_turns["corr-A"].outstanding == {"spawn_join:s0": [2, 0]} + assert orch._gated_turns["corr-B"].outstanding == {"spawn_join:s0": [2, 0]} + + # Complete parent A's children + a_child_corr_ids = [ + k for k, v in orch._child_to_gate.items() if v.parent_corr_id == "corr-A" + ] + for i, corr_id in enumerate(a_child_corr_ids): + child_credit = _make_credit( + conv_id=all_child_ids["parent_A"][i], + corr_id=corr_id, + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit) + + # A's gate satisfied, B still pending + assert "corr-A" not in orch._gated_turns + assert "corr-B" in orch._gated_turns + join_a = [t for t in dispatched if t.conversation_id == "parent_A"] + assert len(join_a) == 1 + assert join_a[0].turn_index == 3 + + # Complete parent B's children + b_child_corr_ids = [ + k for k, v in orch._child_to_gate.items() if v.parent_corr_id == "corr-B" + ] + for i, corr_id in enumerate(b_child_corr_ids): + child_credit = _make_credit( + conv_id=all_child_ids["parent_B"][i], + corr_id=corr_id, + turn_index=2, + num_turns=3, + agent_depth=1, + ) + orch.intercept(child_credit) + + assert "corr-B" not in orch._gated_turns + join_b = [t for t in dispatched if t.conversation_id == "parent_B"] + assert len(join_b) == 1 + + +# ============================================================================= +# Background child error isolation +# ============================================================================= + + +class TestBackgroundChildErrorIsolation: + """Background child errors must not corrupt gate tracking.""" + + def test_background_child_error_does_not_affect_gate(self): + """A background child that errors should not touch gated turns + or 
child_to_gate mappings.""" + orch, _, _, _, _ = _make_orchestrator() + _spawn_children(orch) + + # Create a background child credit (not tracked) + bg_credit = _make_credit( + conv_id="bg-child-conv", + corr_id="bg-corr-1", + turn_index=0, + num_turns=3, + agent_depth=1, + ) + + gates_before = dict(orch._gated_turns) + c2g_before = dict(orch._child_to_gate) + + # Terminate the background child -- it's not in _child_to_gate + orch.terminate_child(bg_credit) + + # No state change + assert orch._stats.children_errored == 0 + assert dict(orch._gated_turns) == gates_before + assert dict(orch._child_to_gate) == c2g_before + + +# ============================================================================= +# All blocking children fail to issue credit +# ============================================================================= + + +class TestAllBlockingChildrenFailToIssue: + """If ALL blocking children fail credit issuance, parent gate completes + with all errored.""" + + @pytest.mark.asyncio + async def test_issue_credit_failure_releases_all_blocking_children(self): + orch, issuer, _, _, child_ids = _make_orchestrator() + _, child_corr_ids = _spawn_children(orch) + + issuer.issue_credit.return_value = False + + # Simulate failed issuance for each blocking child + for i, child_corr_id in enumerate(child_corr_ids): + turn = TurnToSend( + conversation_id=child_ids[i], + x_correlation_id=child_corr_id, + turn_index=0, + num_turns=3, + ) + await orch._issue_child_credit_or_release(turn, child_corr_id) + + # All children errored + assert orch._stats.children_errored == 2 + + # Parent gate should have completed (all children released) + assert len(orch._gated_turns) == 0 + + # Parents resumed counter incremented + assert orch._stats.parents_resumed == 1 + + # Gated turn dispatched via dispatch_fn + dispatched = orch._test_dispatched # type: ignore[attr-defined] + join_turns = [t for t in dispatched if t.conversation_id == "conv_0"] + assert len(join_turns) == 1 + assert 
join_turns[0].turn_index == 3 # spawn_at(2) + 1 + + @pytest.mark.asyncio + async def test_issue_credit_exception_releases_child_from_gate(self): + """If issue_credit raises, the child is released from gate tracking + instead of leaking a pending gate forever.""" + orch, issuer, _, _, child_ids = _make_orchestrator(num_children=1) + _, child_corr_ids = _spawn_children(orch) + + issuer.issue_credit = AsyncMock(side_effect=RuntimeError("connection lost")) + + turn = TurnToSend( + conversation_id=child_ids[0], + x_correlation_id=child_corr_ids[0], + turn_index=0, + num_turns=3, + ) + await orch._issue_child_credit_or_release(turn, child_corr_ids[0]) + + assert orch._stats.children_errored == 1 + assert len(orch._gated_turns) == 0 + assert orch._stats.parents_resumed == 1 + + dispatched = orch._test_dispatched # type: ignore[attr-defined] + join_turns = [t for t in dispatched if t.conversation_id == "conv_0"] + assert len(join_turns) == 1 + assert join_turns[0].turn_index == 3 diff --git a/tests/unit/timing/test_subagent_orchestrator_gating.py b/tests/unit/timing/test_subagent_orchestrator_gating.py new file mode 100644 index 000000000..03ca3ed76 --- /dev/null +++ b/tests/unit/timing/test_subagent_orchestrator_gating.py @@ -0,0 +1,699 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for prerequisite-based gating in SubagentOrchestrator. 
+ +Focuses on: +- PendingTurnGate.is_satisfied property semantics +- ChildGateEntry field tracking +- Multi-prerequisite gating (multiple spawns gate the same turn) +- _satisfy_prerequisite on future (not yet blocked) gates +- _find_gated_turn_index with multiple spawn IDs +- set_dispatch late-binding +- Gate pointing past conversation end +- _maybe_suspend_parent when future gate already satisfied +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pytest import param + +from aiperf.common.enums import CreditPhase, PrerequisiteKind +from aiperf.common.models import ( + ConversationMetadata, + DatasetMetadata, + SubagentSpawnInfo, + TurnMetadata, + TurnPrerequisite, +) +from aiperf.credit.structs import Credit, TurnToSend +from aiperf.plugin.enums import DatasetSamplingStrategy +from aiperf.timing.conversation_source import ConversationSource +from aiperf.timing.subagent_orchestrator import ( + ChildGateEntry, + PendingTurnGate, + SubagentOrchestrator, +) +from tests.unit.timing.conftest import make_sampler + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _make_credit( + *, + conv_id: str = "conv_0", + corr_id: str = "xcorr-1", + turn_index: int = 0, + num_turns: int = 5, + agent_depth: int = 0, +) -> Credit: + return Credit( + id=1, + phase=CreditPhase.PROFILING, + conversation_id=conv_id, + x_correlation_id=corr_id, + turn_index=turn_index, + num_turns=num_turns, + issued_at_ns=0, + agent_depth=agent_depth, + ) + + +def _make_multi_spawn_source( + *, + spawn_at_s0: int = 1, + spawn_at_s1: int = 2, + join_at: int = 3, +) -> tuple[ConversationSource, list[str], list[str]]: + """Create a dataset with two blocking spawns that gate the same turn.""" + s0_children = ["conv_0_s0_c0"] + s1_children = ["conv_0_s1_c0", "conv_0_s1_c1"] + spawn_s0 = SubagentSpawnInfo(spawn_id="s0", child_conversation_ids=s0_children) 
+ spawn_s1 = SubagentSpawnInfo(spawn_id="s1", child_conversation_ids=s1_children) + + parent_turns = [] + for i in range(6): + spawn_ids = [] + if i == spawn_at_s0: + spawn_ids.append("s0") + if i == spawn_at_s1: + spawn_ids.append("s1") + prereqs = [] + if i == join_at: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0"), + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s1"), + ] + parent_turns.append( + TurnMetadata( + input_tokens=500 + i * 100, + subagent_spawn_ids=spawn_ids, + prerequisites=prereqs, + ) + ) + + convs = [ + ConversationMetadata( + conversation_id="conv_0", + turns=parent_turns, + subagent_spawns=[spawn_s0, spawn_s1], + ) + ] + for cid in s0_children + s1_children: + convs.append( + ConversationMetadata( + conversation_id=cid, + turns=[TurnMetadata(input_tokens=300) for _ in range(2)], + agent_depth=1, + ) + ) + + ds = DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + sampler = make_sampler(["conv_0"], ds.sampling_strategy) + src = ConversationSource(ds, sampler) + return src, s0_children, s1_children + + +def _make_simple_orchestrator( + *, + spawn_at: int = 2, + join_at: int | None = None, + num_children: int = 2, + is_background: bool = False, + num_parent_turns: int = 6, +) -> tuple[SubagentOrchestrator, MagicMock, MagicMock, MagicMock, list[str]]: + join_at = spawn_at + 1 if join_at is None else join_at + child_conv_ids = [f"conv_0_s0_c{ci}" for ci in range(num_children)] + spawn = SubagentSpawnInfo( + spawn_id="s0", + child_conversation_ids=child_conv_ids, + is_background=is_background, + ) + + parent_turns = [] + for i in range(num_parent_turns): + spawn_ids = ["s0"] if i == spawn_at else [] + prereqs = [] + if i == join_at and not is_background: + prereqs = [ + TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0") + ] + parent_turns.append( + TurnMetadata( + input_tokens=500 + i * 100, + subagent_spawn_ids=spawn_ids, + 
prerequisites=prereqs, + ) + ) + + convs = [ + ConversationMetadata( + conversation_id="conv_0", + turns=parent_turns, + subagent_spawns=[spawn], + ) + ] + for child_id in child_conv_ids: + convs.append( + ConversationMetadata( + conversation_id=child_id, + turns=[TurnMetadata(input_tokens=300) for _ in range(3)], + agent_depth=1, + ) + ) + + ds = DatasetMetadata( + conversations=convs, + sampling_strategy=DatasetSamplingStrategy.SEQUENTIAL, + ) + sampler = make_sampler(["conv_0"], ds.sampling_strategy) + src = ConversationSource(ds, sampler) + + issuer = MagicMock() + issuer.issue_credit = AsyncMock(return_value=True) + stop_checker = MagicMock() + stop_checker.can_send_any_turn = MagicMock(return_value=True) + scheduler = MagicMock() + scheduler.execute_async = MagicMock() + dispatched: list[TurnToSend] = [] + + orch = SubagentOrchestrator( + conversation_source=src, + credit_issuer=issuer, + stop_checker=stop_checker, + scheduler=scheduler, + dispatch_fn=lambda turn: dispatched.append(turn), + ) + orch._test_dispatched = dispatched # type: ignore[attr-defined] + return orch, issuer, stop_checker, scheduler, child_conv_ids + + +# ============================================================================= +# PendingTurnGate.is_satisfied +# ============================================================================= + + +class TestPendingTurnGateIsSatisfied: + """Verify is_satisfied semantics for various outstanding states.""" + + @pytest.mark.parametrize( + "outstanding,expected", + [ + ({}, True), + ({"spawn_join:s0": [2, 2]}, True), + ({"spawn_join:s0": [2, 1]}, False), + ({"spawn_join:s0": [2, 0]}, False), + param( + {"spawn_join:s0": [1, 1], "spawn_join:s1": [2, 2]}, + True, + id="multi-prereq-all-satisfied", + ), + param( + {"spawn_join:s0": [1, 1], "spawn_join:s1": [2, 1]}, + False, + id="multi-prereq-one-unsatisfied", + ), + param( + {"spawn_join:s0": [1, 5]}, + True, + id="completed-exceeds-expected", + ), + ], + ) # fmt: skip + def 
test_is_satisfied_returns_expected( + self, outstanding: dict[str, list[int]], expected: bool + ) -> None: + gate = PendingTurnGate( + parent_conversation_id="conv_0", + parent_correlation_id="p-1", + gated_turn_index=3, + outstanding=outstanding, + ) + assert gate.is_satisfied is expected + + +# ============================================================================= +# ChildGateEntry field tracking +# ============================================================================= + + +class TestChildGateEntryTracking: + """Verify ChildGateEntry fields are populated correctly after spawn.""" + + def test_child_gate_entry_fields_after_spawn(self) -> None: + orch, _, _, _, child_ids = _make_simple_orchestrator() + credit = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6 + ) + orch.intercept(credit) + + for _corr_id, entry in orch._child_to_gate.items(): + assert isinstance(entry, ChildGateEntry) + assert entry.parent_corr_id == "parent-1" + assert entry.gated_turn_index == 3 + assert entry.prereq_key == "spawn_join:s0" + + +# ============================================================================= +# Multi-prerequisite gating +# ============================================================================= + + +class TestMultiPrerequisiteGating: + """Two spawns gate the same turn; both must complete.""" + + def test_multi_spawn_both_must_complete(self) -> None: + src, s0_children, s1_children = _make_multi_spawn_source( + spawn_at_s0=1, spawn_at_s1=2, join_at=3 + ) + dispatched: list[TurnToSend] = [] + scheduler = MagicMock() + scheduler.execute_async = MagicMock() + + orch = SubagentOrchestrator( + conversation_source=src, + credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)), + stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)), + scheduler=scheduler, + dispatch_fn=lambda turn: dispatched.append(turn), + ) + + # Spawn s0 on turn 1 + credit_t1 = _make_credit( + conv_id="conv_0", 
corr_id="parent-1", turn_index=1, num_turns=6 + ) + orch.intercept(credit_t1) + + # Spawn s1 on turn 2 + credit_t2 = _make_credit( + conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6 + ) + result = orch.intercept(credit_t2) + assert result is True # parent suspended at gated turn 3 + + gate = orch._gated_turns["parent-1"] + assert "spawn_join:s0" in gate.outstanding + assert "spawn_join:s1" in gate.outstanding + + # Complete s0 child (1 child) + s0_child_corr_ids = [ + k for k, v in orch._child_to_gate.items() if v.prereq_key == "spawn_join:s0" + ] + for corr_id in s0_child_corr_ids: + child_credit = _make_credit( + conv_id=s0_children[0], + corr_id=corr_id, + turn_index=1, + num_turns=2, + agent_depth=1, + ) + orch.intercept(child_credit) + + # s0 done but s1 not done -- gate still active + assert "parent-1" in orch._gated_turns + join_turns = [t for t in dispatched if t.conversation_id == "conv_0"] + assert len(join_turns) == 0 + + # Complete s1 children (2 children) + s1_child_corr_ids = [ + k for k, v in orch._child_to_gate.items() if v.prereq_key == "spawn_join:s1" + ] + for i, corr_id in enumerate(s1_child_corr_ids): + child_credit = _make_credit( + conv_id=s1_children[i], + corr_id=corr_id, + turn_index=1, + num_turns=2, + agent_depth=1, + ) + orch.intercept(child_credit) + + # Now both prerequisites satisfied + assert "parent-1" not in orch._gated_turns + join_turns = [t for t in dispatched if t.conversation_id == "conv_0"] + assert len(join_turns) == 1 + assert join_turns[0].turn_index == 3 + + +# ============================================================================= +# Satisfy prerequisite on future (unblocked) gate +# ============================================================================= + + +class TestSatisfyPrerequisiteOnFutureGate: + """When children complete before parent reaches gated turn, the + future gate is cleaned up silently (no dispatch).""" + + def 
+    def test_children_complete_before_parent_blocks_cleans_future_gate(self) -> None:
+        orch, _, _, _, child_ids = _make_simple_orchestrator(spawn_at=1, join_at=5)
+        # Spawn on turn 1 (join at turn 5, so future gate created)
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=1, num_turns=6
+        )
+        orch.intercept(credit)
+
+        assert "parent-1" in orch._future_gates
+        child_corr_ids = list(orch._child_to_gate.keys())
+
+        # Complete all children
+        for i, corr_id in enumerate(child_corr_ids):
+            child_credit = _make_credit(
+                conv_id=child_ids[i],
+                corr_id=corr_id,
+                turn_index=2,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.intercept(child_credit)
+
+        # Future gate cleaned up
+        assert "parent-1" not in orch._future_gates
+        # No gated turn dispatched (parent not blocked yet)
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 0
+
+    def test_parent_not_suspended_when_future_gate_already_satisfied(self) -> None:
+        """If children complete before parent reaches gate, parent proceeds."""
+        orch, _, _, _, child_ids = _make_simple_orchestrator(spawn_at=1, join_at=3)
+        credit_t1 = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=1, num_turns=6
+        )
+        orch.intercept(credit_t1)
+
+        child_corr_ids = list(orch._child_to_gate.keys())
+
+        # Complete all children before parent reaches turn 2
+        for i, corr_id in enumerate(child_corr_ids):
+            child_credit = _make_credit(
+                conv_id=child_ids[i],
+                corr_id=corr_id,
+                turn_index=2,
+                num_turns=3,
+                agent_depth=1,
+            )
+            orch.intercept(child_credit)
+
+        # Parent turn 2 completes (next turn is 3, the gated turn)
+        credit_t2 = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        result = orch.intercept(credit_t2)
+
+        # Parent should NOT be suspended -- prerequisites already met
+        assert result is False
+        assert orch._stats.parents_suspended == 0
+
+
+# =============================================================================
+# _find_gated_turn_index with multiple spawn IDs
+# =============================================================================
+
+
+class TestFindGatedTurnIndex:
+    """_find_gated_turn_index returns first match from spawn_ids list."""
+
+    def test_find_gated_turn_index_returns_first_match(self) -> None:
+        src, _, _ = _make_multi_spawn_source(spawn_at_s0=1, spawn_at_s1=2, join_at=3)
+        orch = SubagentOrchestrator(
+            conversation_source=src,
+            credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            scheduler=MagicMock(execute_async=MagicMock()),
+            dispatch_fn=lambda t: None,
+        )
+
+        # Both s0 and s1 point to turn 3
+        assert orch._find_gated_turn_index("conv_0", ["s0"]) == 3
+        assert orch._find_gated_turn_index("conv_0", ["s1"]) == 3
+        assert orch._find_gated_turn_index("conv_0", ["s0", "s1"]) == 3
+
+    def test_find_gated_turn_index_unknown_spawn_returns_none(self) -> None:
+        src, _, _ = _make_multi_spawn_source()
+        orch = SubagentOrchestrator(
+            conversation_source=src,
+            credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            scheduler=MagicMock(execute_async=MagicMock()),
+            dispatch_fn=lambda t: None,
+        )
+        assert orch._find_gated_turn_index("conv_0", ["nonexistent"]) is None
+        assert orch._find_gated_turn_index("nonexistent_conv", ["s0"]) is None
+
+
+# =============================================================================
+# set_dispatch late-binding
+# =============================================================================
+
+
+class TestSetDispatch:
+    """set_dispatch allows late-binding the dispatch callback."""
+
+    def test_set_dispatch_replaces_callback(self) -> None:
+        orch, _, _, _, child_ids = _make_simple_orchestrator(num_children=1)
+        new_dispatched: list[TurnToSend] = []
+        orch.set_dispatch(lambda turn: new_dispatched.append(turn))
+
+        # Spawn and complete child to trigger dispatch of gated turn
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        orch.intercept(credit)
+
+        child_corr_ids = list(orch._child_to_gate.keys())
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=2,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.intercept(child_credit)
+
+        # Gated turn dispatched via the NEW callback
+        join_turns = [t for t in new_dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 1
+        assert join_turns[0].turn_index == 3
+
+
+# =============================================================================
+# Gate pointing past conversation end
+# =============================================================================
+
+
+class TestGatePastConversationEnd:
+    """Gate turn_index >= num_turns means gated turn is unreachable."""
+
+    def test_release_blocked_gate_past_end_returns_none(self) -> None:
+        """When gated_turn_index >= num_turns, _release_blocked_gate pops the gate but dispatches nothing."""
+        orch, _, _, _, child_ids = _make_simple_orchestrator(
+            num_children=1, num_parent_turns=4, spawn_at=2, join_at=3
+        )
+
+        # Spawn on turn 2 → creates gate at turn 3
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=4
+        )
+        orch.intercept(credit)
+
+        # Mutate the gate to point past the end before any child completes
+        gate = orch._gated_turns.get("parent-1")
+        assert gate is not None
+        gate.gated_turn_index = 10
+
+        # Directly call _release_blocked_gate (the unit under test)
+        result = orch._release_blocked_gate("parent-1")
+        assert result is None
+        assert "parent-1" not in orch._gated_turns
+
+
+# =============================================================================
+# Prerequisite index built correctly
+# =============================================================================
+
+
+class TestPrerequisiteIndexBuilding:
+    """Verify the prerequisite index and spawn_join index are built at init."""
+
+    def test_prerequisite_index_populated(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator(spawn_at=2, join_at=3)
+
+        # Turn 3 has a spawn_join prerequisite for s0
+        assert ("conv_0", 3) in orch._prerequisite_index
+        prereqs = orch._prerequisite_index[("conv_0", 3)]
+        assert len(prereqs) == 1
+        assert prereqs[0].kind == PrerequisiteKind.SPAWN_JOIN
+        assert prereqs[0].spawn_id == "s0"
+
+    def test_spawn_join_index_populated(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator(spawn_at=2, join_at=3)
+
+        assert ("conv_0", "s0") in orch._spawn_join_index
+        assert orch._spawn_join_index[("conv_0", "s0")] == 3
+
+    def test_multi_spawn_prerequisite_index(self) -> None:
+        src, _, _ = _make_multi_spawn_source(spawn_at_s0=1, spawn_at_s1=2, join_at=3)
+        orch = SubagentOrchestrator(
+            conversation_source=src,
+            credit_issuer=MagicMock(issue_credit=AsyncMock(return_value=True)),
+            stop_checker=MagicMock(can_send_any_turn=MagicMock(return_value=True)),
+            scheduler=MagicMock(execute_async=MagicMock()),
+            dispatch_fn=lambda t: None,
+        )
+
+        prereqs = orch._prerequisite_index[("conv_0", 3)]
+        assert len(prereqs) == 2
+        assert orch._spawn_join_index[("conv_0", "s0")] == 3
+        assert orch._spawn_join_index[("conv_0", "s1")] == 3
+
+
+# =============================================================================
+# _get_gate: active vs future lookup
+# =============================================================================
+
+
+class TestGetGateLookup:
+    """_get_gate finds active blocked gates and future gates."""
+
+    def test_get_gate_returns_active_gate(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        gate = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=3,
+            is_blocked=True,
+            outstanding={"spawn_join:s0": [2, 0]},
+        )
+        orch._gated_turns["p-1"] = gate
+
+        result = orch._get_gate("p-1", 3)
+        assert result is gate
+
+    def test_get_gate_returns_future_gate(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        gate = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=5,
+            outstanding={"spawn_join:s0": [2, 0]},
+        )
+        orch._future_gates["p-1"] = {5: gate}
+
+        result = orch._get_gate("p-1", 5)
+        assert result is gate
+
+    def test_get_gate_wrong_turn_index_returns_none(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        gate = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=3,
+            is_blocked=True,
+            outstanding={"spawn_join:s0": [2, 0]},
+        )
+        orch._gated_turns["p-1"] = gate
+
+        assert orch._get_gate("p-1", 999) is None
+
+    def test_get_gate_unknown_parent_returns_none(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        assert orch._get_gate("nonexistent", 3) is None
+
+
+# =============================================================================
+# _iter_future_gates
+# =============================================================================
+
+
+class TestIterFutureGates:
+    """_iter_future_gates flattens nested dict for cleanup."""
+
+    def test_iter_future_gates_flattens_correctly(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        gate_a = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=3,
+        )
+        gate_b = PendingTurnGate(
+            parent_conversation_id="conv_0",
+            parent_correlation_id="p-1",
+            gated_turn_index=5,
+        )
+        gate_c = PendingTurnGate(
+            parent_conversation_id="conv_1",
+            parent_correlation_id="p-2",
+            gated_turn_index=4,
+        )
+        orch._future_gates = {
+            "p-1": {3: gate_a, 5: gate_b},
+            "p-2": {4: gate_c},
+        }
+
+        result = orch._iter_future_gates()
+        assert len(result) == 3
+        # All gates present
+        gates_found = {(corr_id, gate.gated_turn_index) for corr_id, gate in result}
+        assert gates_found == {("p-1", 3), ("p-1", 5), ("p-2", 4)}
+
+    def test_iter_future_gates_empty_returns_empty(self) -> None:
+        orch, _, _, _, _ = _make_simple_orchestrator()
+        assert orch._iter_future_gates() == []
+
+
+# =============================================================================
+# TurnPrerequisite model validation
+# =============================================================================
+
+
+class TestTurnPrerequisiteModel:
+    """TurnPrerequisite model field behavior."""
+
+    def test_spawn_join_prerequisite_has_spawn_id(self) -> None:
+        prereq = TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN, spawn_id="s0")
+        assert prereq.kind == PrerequisiteKind.SPAWN_JOIN
+        assert prereq.spawn_id == "s0"
+
+    def test_spawn_join_prerequisite_spawn_id_defaults_to_none(self) -> None:
+        prereq = TurnPrerequisite(kind=PrerequisiteKind.SPAWN_JOIN)
+        assert prereq.spawn_id is None
+
+
+# =============================================================================
+# Stop condition narrowed to StopConditionChecker
+# =============================================================================
+
+
+class TestStopCheckerType:
+    """Stop checker uses can_send_any_turn method."""
+
+    def test_stop_checker_can_send_any_turn_gates_dispatch(self) -> None:
+        orch, _, stop_checker, _, child_ids = _make_simple_orchestrator(num_children=1)
+        credit = _make_credit(
+            conv_id="conv_0", corr_id="parent-1", turn_index=2, num_turns=6
+        )
+        orch.intercept(credit)
+
+        stop_checker.can_send_any_turn.return_value = False
+
+        child_corr_ids = list(orch._child_to_gate.keys())
+        child_credit = _make_credit(
+            conv_id=child_ids[0],
+            corr_id=child_corr_ids[0],
+            turn_index=2,
+            num_turns=3,
+            agent_depth=1,
+        )
+        orch.intercept(child_credit)
+
+        # Gated turn suppressed
+        dispatched = orch._test_dispatched  # type: ignore[attr-defined]
+        join_turns = [t for t in dispatched if t.conversation_id == "conv_0"]
+        assert len(join_turns) == 0
+        assert orch._stats.joins_suppressed == 1
+        assert stop_checker.can_send_any_turn.call_count == 1
diff --git a/tests/unit/workers/test_session_manager.py b/tests/unit/workers/test_session_manager.py
index f418e1293..b74201691 100644
--- a/tests/unit/workers/test_session_manager.py
+++ b/tests/unit/workers/test_session_manager.py
@@ -1,9 +1,10 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-"""Unit tests for UserSessionManager to ensure Credit.num_turns is respected.
+"""Unit tests for UserSessionManager.
 
-These tests ensure that the worker properly uses Credit.num_turns instead of
-len(conversation.turns), which is critical for ramp-up users who start mid-session.
+Tests ensure:
+- Credit.num_turns is respected (ramp-up users who start mid-session)
+- ConversationContextMode controls history accumulation and response storage
 """
 
 import pytest