diff --git a/docs/cli-options.md b/docs/cli-options.md index 79d54dd2b..e6ec2a4c5 100644 --- a/docs/cli-options.md +++ b/docs/cli-options.md @@ -223,6 +223,11 @@ Start offset in milliseconds for fixed schedule replay. Skips all requests befor End offset in milliseconds for fixed schedule replay. Stops issuing requests after this timestamp, allowing benchmark of specific trace subsets. Requests at exactly the end offset are included. Defaults to last timestamp in dataset. Must be ≥ `--fixed-schedule-start-offset` if both specified.
_Constraints: ≥ 0_ +#### `--fixed-schedule-speedup` `` + +Scaling factor for fixed schedule timestamps. A value of 2.0 replays the schedule twice as fast (halving inter-request delays), while 0.5 replays at half speed (doubling delays). Applied at the timing layer to any dataset using `--fixed-schedule`. +
_Constraints: > 0_ + #### `--public-dataset` `` Pre-configured public dataset to download and use for benchmarking (e.g., `sharegpt`). AIPerf automatically downloads and parses these datasets. Mutually exclusive with `--custom-dataset-type`. Run `aiperf plugins public_dataset_loader` to list available datasets. @@ -230,8 +235,8 @@ Pre-configured public dataset to download and use for benchmarking (e.g., `share #### `--custom-dataset-type` `` -Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts; when using `random_pool`, `--conversation-num` defaults to 100 if not specified; batch sizes > 1 sample each modality independently from a flat pool and do not preserve per-entry associations — use `single_turn` if paired modalities must stay together). Requires `--input-file`. Mutually exclusive with `--public-dataset`. -
_Choices: [`bailian_trace`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ +Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace`/`bailian_trace` (timestamped trace files), `conflux` (Conflux proxy capture with agent_id grouping and timestamp-based replay), `random_pool` (directory of reusable prompts; when using `random_pool`, `--conversation-num` defaults to 100 if not specified; batch sizes > 1 sample each modality independently from a flat pool and do not preserve per-entry associations — use `single_turn` if paired modalities must stay together). Requires `--input-file`. Mutually exclusive with `--public-dataset`. +
_Choices: [`bailian_trace`, `conflux`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ #### `--dataset-sampling-strategy` `` @@ -246,6 +251,11 @@ Random seed for deterministic data generation. When set, makes synthetic prompts Specify service level objectives (SLOs) for goodput as space-separated 'KEY:VALUE' pairs, where KEY is a metric tag and VALUE is a number in the metric's display unit (falls back to its base unit if no display unit is defined). Examples: 'request_latency:250' (ms), 'inter_token_latency:10' (ms), `output_token_throughput_per_user:600` (tokens/s). Only metrics applicable to the current endpoint/config are considered. For more context on the definition of goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 and the blog: https://hao-ai-lab.github.io/blogs/distserve. +#### `--conflux-include-utility-calls` + +Include unattributed utility calls when loading Conflux proxy captures. These are lightweight model calls made by the client for housekeeping tasks (topic detection, title generation) that lack an agent_id and may fall outside the main session timeline. Applies when Conflux format is specified via --custom-dataset-type conflux or auto-detected from file contents. +
_Flag (no value required)_ + ### Audio Input #### `--audio-batch-size`, `--batch-size-audio` `` diff --git a/src/aiperf/common/config/input_config.py b/src/aiperf/common/config/input_config.py index e1488e170..4bbf65a37 100644 --- a/src/aiperf/common/config/input_config.py +++ b/src/aiperf/common/config/input_config.py @@ -95,6 +95,19 @@ def validate_fixed_schedule_start_and_end_offset(self) -> Self: ) return self + @model_validator(mode="after") + def validate_no_double_speedup(self) -> Self: + """Reject combining synthesis speedup with fixed-schedule speedup.""" + if ( + self.synthesis.speedup_ratio != 1.0 + and self.fixed_schedule_speedup is not None + ): + raise ValueError( + "--synthesis-speedup-ratio and --fixed-schedule-speedup cannot be used together. " + "Use --fixed-schedule-speedup to control replay speed at the timing layer." + ) + return self + @model_validator(mode="after") def validate_dataset_type(self) -> Self: """Validate the different dataset type configuration.""" @@ -267,6 +280,20 @@ def validate_goodput(self) -> Self: ), ] = InputDefaults.FIXED_SCHEDULE_END_OFFSET + fixed_schedule_speedup: Annotated[ + float | None, + Field( + gt=0, + description="Scaling factor for fixed schedule timestamps. A value of 2.0 replays the schedule twice as fast " + "(halving inter-request delays), while 0.5 replays at half speed (doubling delays). " + "Applied at the timing layer to any dataset using `--fixed-schedule`.", + ), + CLIParameter( + name=("--fixed-schedule-speedup",), + group=_CLI_GROUP, + ), + ] = None + public_dataset: Annotated[ PublicDatasetType | None, Field( @@ -285,7 +312,8 @@ def validate_goodput(self) -> Self: Field( description="Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. 
" "Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), " - "`mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts; " + "`mooncake_trace`/`bailian_trace` (timestamped trace files), `conflux` (Conflux proxy capture with " + "agent_id grouping and timestamp-based replay), `random_pool` (directory of reusable prompts; " "when using `random_pool`, `--conversation-num` defaults to 100 if not specified; " "batch sizes > 1 sample each modality independently from a flat pool and do not preserve " "per-entry associations — use `single_turn` if paired modalities must stay together). " @@ -348,6 +376,21 @@ def validate_goodput(self) -> Self: ), ] = InputDefaults.GOODPUT + conflux_include_utility_calls: Annotated[ + bool, + Field( + description="Include unattributed utility calls when loading Conflux proxy captures. " + "These are lightweight model calls made by the client for housekeeping tasks " + "(topic detection, title generation) that lack an agent_id and may fall outside " + "the main session timeline. Applies when Conflux format is specified via " + "--custom-dataset-type conflux or auto-detected from file contents.", + ), + CLIParameter( + name=("--conflux-include-utility-calls",), + group=_CLI_GROUP, + ), + ] = False + audio: AudioConfig = AudioConfig() image: ImageConfig = ImageConfig() video: VideoConfig = VideoConfig() diff --git a/src/aiperf/common/models/dataset_models.py b/src/aiperf/common/models/dataset_models.py index 42b6ad4ca..0c1b72edc 100644 --- a/src/aiperf/common/models/dataset_models.py +++ b/src/aiperf/common/models/dataset_models.py @@ -158,6 +158,15 @@ class Turn(AIPerfBaseModel): videos: list[Video] = Field( default=[], description="Collection of video data in each turn." 
) + input_tokens: int | None = Field( + default=None, + description="Expected input token count for this turn (from trace data).", + ) + extra_params: dict[str, Any] | None = Field( + default=None, + description="Per-turn hyperparameter overrides merged into the API payload at the top level. " + "Populated from dataset capture metadata.", + ) def metadata(self) -> TurnMetadata: """Get the metadata of the turn.""" @@ -209,6 +218,8 @@ def copy_with_stripped_media(self) -> "Turn": ) for vid in self.videos ], + input_tokens=self.input_tokens, + extra_params=dict(self.extra_params) if self.extra_params else None, ) diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 3dd25ec25..a1ec382d1 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -5,11 +5,12 @@ from pathlib import Path from typing import Any +import orjson + from aiperf.common.config import UserConfig from aiperf.common.enums import ConversationContextMode from aiperf.common.models import Conversation from aiperf.common.tokenizer import Tokenizer -from aiperf.common.utils import load_json_str from aiperf.dataset.composer.base import BaseDatasetComposer from aiperf.dataset.loader.base_loader import BaseLoader from aiperf.dataset.utils import check_file_exists @@ -83,12 +84,20 @@ def _infer_dataset_type(self, file_path: str) -> CustomDatasetType: if path.is_dir(): return self._infer_type(data=None, filename=file_path) - # For files, read first non-empty line and use both content and path detection + # For files, read first non-empty line and use both content and path detection. + # If the first line isn't valid JSON (e.g. pretty-printed JSON arrays start + # with "["), fall through to filename-only detection so file-probing loaders + # like ConfluxLoader can inspect the file directly. 
with open(file_path) as f: for line in f: if not (line := line.strip()): continue - data = load_json_str(line) + try: + data = orjson.loads(line) + except orjson.JSONDecodeError: + return self._infer_type(data=None, filename=file_path) + if not isinstance(data, dict): + return self._infer_type(data=None, filename=file_path) return self._infer_type(data=data, filename=file_path) except ValueError as e: diff --git a/src/aiperf/dataset/loader/__init__.py b/src/aiperf/dataset/loader/__init__.py index b637b4c69..fc892bfb1 100644 --- a/src/aiperf/dataset/loader/__init__.py +++ b/src/aiperf/dataset/loader/__init__.py @@ -6,9 +6,11 @@ from aiperf.dataset.loader.base_loader import BaseFileLoader, BaseLoader from aiperf.dataset.loader.base_public_dataset import BasePublicDatasetLoader from aiperf.dataset.loader.base_trace_loader import BaseTraceDatasetLoader +from aiperf.dataset.loader.conflux import ConfluxLoader from aiperf.dataset.loader.mixins import MediaConversionMixin from aiperf.dataset.loader.models import ( BailianTrace, + ConfluxRecord, MooncakeTrace, MultiTurn, RandomPool, @@ -27,6 +29,8 @@ "BaseLoader", "BasePublicDatasetLoader", "BaseTraceDatasetLoader", + "ConfluxLoader", + "ConfluxRecord", "MediaConversionMixin", "MooncakeTrace", "MooncakeTraceDatasetLoader", diff --git a/src/aiperf/dataset/loader/conflux.py b/src/aiperf/dataset/loader/conflux.py new file mode 100644 index 000000000..582e2c2a3 --- /dev/null +++ b/src/aiperf/dataset/loader/conflux.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Conflux dataset loader for timestamp-based replay of proxy captures. + +Loads JSON files containing arrays of API request records. Groups records +by agent_id into independent Conversations with timestamp-based inter-turn +delays for fixed-schedule replay. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import orjson +from pydantic import ValidationError + +from aiperf.common.aiperf_logger import AIPerfLogger +from aiperf.common.enums import ConversationContextMode +from aiperf.common.models import Conversation, Turn +from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.loader.models import ConfluxRecord +from aiperf.plugin.enums import DatasetSamplingStrategy + +_EXTRA_PARAMS_SKIP = frozenset( + { + "max_tokens", + "max_completion_tokens", + "max_output_tokens", + } +) + + +_logger = AIPerfLogger(__name__) + + +class ConfluxLoader(BaseFileLoader): + """Dataset loader for Conflux proxy capture JSON files. + + Each agent_id group becomes an independent Conversation with + zero-aligned timestamps for fixed-schedule replay. + """ + + @classmethod + def get_default_context_mode(cls) -> ConversationContextMode: + return ConversationContextMode.MESSAGE_ARRAY_WITH_RESPONSES + + @classmethod + def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: + return DatasetSamplingStrategy.SEQUENTIAL + + @classmethod + def can_load( + cls, data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> bool: + """Return True if filename is a Conflux JSON file or directory.""" + if filename is None: + return False + path = Path(filename) + if path.is_dir(): + first = next(path.glob("*.json"), None) + return first is not None and cls._probe_file(first) + return cls._probe_file(path) + + @classmethod + def _probe_file(cls, path: Path) -> bool: + """Return True if the file loads as a valid Conflux JSON array.""" + if not path.is_file() or path.suffix != ".json": + return False + try: + raw_records: list[dict[str, Any]] = orjson.loads(path.read_bytes()) + if not raw_records or not isinstance(raw_records, list): + return False + ConfluxRecord.model_validate(raw_records[0]) + return True + except (orjson.JSONDecodeError, 
ValidationError) as e: + _logger.debug(f"Failed to validate Conflux JSON array: {e!r}") + return False + + def load_dataset(self) -> dict[str, list[ConfluxRecord]]: + """Load and group Conflux records by agent_id.""" + path = Path(self.filename) + if path.is_dir(): + return self._load_directory(path) + return self._load_single_file(self.filename) + + def _load_directory(self, path: Path) -> dict[str, list[ConfluxRecord]]: + """Load all JSON files in a directory as independent sessions.""" + json_files = sorted(path.glob("*.json")) + if not json_files: + raise FileNotFoundError( + f"No .json files found in directory: {self.filename}" + ) + + all_groups: dict[str, list[ConfluxRecord]] = {} + for file_idx, json_file in enumerate(json_files): + file_groups = self._load_single_file(str(json_file), prefix=f"f{file_idx}_") + all_groups.update(file_groups) + + total_records = sum(len(recs) for recs in all_groups.values()) + self.info( + f"Loaded {len(all_groups)} agent threads from " + f"{len(json_files)} files ({total_records} total records) in {path.name}/" + ) + return all_groups + + def _load_single_file( + self, filename: str, prefix: str = "" + ) -> dict[str, list[ConfluxRecord]]: + """Load and group records from a single JSON file.""" + raw_records: list[dict[str, Any]] = orjson.loads(Path(filename).read_bytes()) + + include_utility = self.user_config.input.conflux_include_utility_calls + + groups: dict[str, list[ConfluxRecord]] = {} + utility_count = 0 + + for raw in raw_records: + record = ConfluxRecord.model_validate(raw) + if record.agent_id is not None: + key = f"{prefix}{record.agent_id}" + groups.setdefault(key, []).append(record) + else: + if include_utility: + groups[f"{prefix}_utility_{utility_count}"] = [record] + utility_count += 1 + + for records in groups.values(): + records.sort(key=lambda r: r.timestamp) + + if not prefix: + total_records = sum(len(recs) for recs in groups.values()) + action = "included" if include_utility else "skipped" + 
utility_label = f"{utility_count} utility calls {action}" + self.info( + f"Loaded {len(groups)} agent threads + {utility_label} " + f"({total_records} total records)" + ) + + return groups + + def convert_to_conversations( + self, data: dict[str, list[ConfluxRecord]] + ) -> list[Conversation]: + """Convert grouped Conflux records to Conversation objects.""" + conversations = [ + self._build_conversation(agent_id, records) + for agent_id, records in data.items() + ] + + total_turns = sum(len(c.turns) for c in conversations) + self.info( + f"Converted {len(conversations)} conversations ({total_turns} total turns)" + ) + return conversations + + def _build_conversation( + self, + agent_id: str, + records: list[ConfluxRecord], + ) -> Conversation: + """Build a Conversation from a list of ConfluxRecords for one agent.""" + conversation = Conversation(session_id=f"conflux_{agent_id}") + + for record in records: + input_tokens = record.tokens.input if record.tokens else None + + max_tokens = None + if record.tokens is not None: + total_output = record.tokens.output + record.tokens.output_reasoning + max_tokens = total_output or None + + turn = Turn( + timestamp=record.timestamp, + max_tokens=max_tokens, + input_tokens=input_tokens, + raw_messages=record.messages, + raw_tools=record.tools or None, + extra_params=self._extract_extra_params(record), + ) + conversation.turns.append(turn) + + return conversation + + @staticmethod + def _extract_extra_params(record: ConfluxRecord) -> dict[str, Any] | None: + """Extract per-turn hyperparameter overrides from a ConfluxRecord.""" + if not record.hyperparameters: + return None + params = { + k: v + for k, v in record.hyperparameters.items() + if k not in _EXTRA_PARAMS_SKIP and v is not None + } + return params or None diff --git a/src/aiperf/dataset/loader/models.py b/src/aiperf/dataset/loader/models.py index 675b3c597..04cea11a1 100644 --- a/src/aiperf/dataset/loader/models.py +++ b/src/aiperf/dataset/loader/models.py @@ -336,8 
+336,87 @@ class BailianTrace(AIPerfBaseModel): ) +class ConfluxTokens(AIPerfBaseModel): + """Normalized token counts across providers.""" + + input: int = Field( + default=0, + description="Total input tokens processed (all input the model saw). " + "For Anthropic: input_tokens + cache_creation_input_tokens + cache_read_input_tokens. " + "For OpenAI: prompt_tokens (already includes cached).", + ) + input_cached: int = Field( + default=0, + description="Tokens read from cache. " + "For Anthropic: cache_read_input_tokens. For OpenAI: equivalent cached tokens.", + ) + input_cache_write: int = Field( + default=0, + description="Tokens written to cache. For Anthropic: cache_creation_input_tokens.", + ) + output: int = Field(default=0, description="Total output tokens generated.") + output_reasoning: int = Field( + default=0, + description="Output tokens used for reasoning/thinking, when available from the provider.", + ) + + +class ConfluxRecord(AIPerfBaseModel): + """A single unified API call from a Conflux proxy capture. + + Minimal schema: only the fields needed for timestamp-based replay. + Extra fields in the JSON are silently ignored. 
+ """ + + model_config = ConfigDict(extra="ignore") + + session_id: str = Field( + description="Session identifier grouping related API calls.", + ) + agent_id: str | None = Field( + default=None, + description="Identifier of the agent/persona that made this call.", + ) + is_subagent: bool | None = Field( + default=None, + description="Whether this call was made by a sub-agent.", + ) + timestamp: float = Field( + description="Request timestamp in milliseconds since epoch.", + ) + duration_ms: int | float = Field( + default=0, + description="Time in milliseconds from request start to response completion.", + ) + completed_at: str | None = Field( + default=None, + description="ISO 8601 timestamp when the API response completed.", + ) + tokens: ConfluxTokens | None = Field( + default=None, + description="Normalized token counts across providers.", + ) + messages: list[dict[str, Any]] = Field( + default_factory=list, + description="Input messages from the request.", + ) + tools: list[dict[str, Any]] = Field( + default_factory=list, + description="Tool definitions available to the model for this API call.", + ) + hyperparameters: dict[str, Any] | None = Field( + default=None, + description="Normalized generation hyperparameters.", + ) + + CustomDatasetT = TypeVar( "CustomDatasetT", - bound=SingleTurn | MultiTurn | RandomPool | MooncakeTrace | BailianTrace, + bound=SingleTurn + | MultiTurn + | RandomPool + | MooncakeTrace + | BailianTrace + | ConfluxRecord, ) """A union type of all custom data types.""" diff --git a/src/aiperf/endpoints/openai_chat.py b/src/aiperf/endpoints/openai_chat.py index 6fca2ff74..234fbd03e 100644 --- a/src/aiperf/endpoints/openai_chat.py +++ b/src/aiperf/endpoints/openai_chat.py @@ -39,47 +39,52 @@ def format_payload(self, request_info: RequestInfo) -> dict[str, Any]: raise ValueError("Chat endpoint requires at least one turn.") turns = request_info.turns + current_turn = turns[-1] model_endpoint = request_info.model_endpoint - if 
turns[-1].raw_messages is not None: - messages = turns[-1].raw_messages + if current_turn.raw_messages is not None: + messages = current_turn.raw_messages else: messages = self._create_messages( turns, request_info.system_message, request_info.user_context_message ) - payload = { - "messages": messages, - "model": turns[-1].model or model_endpoint.primary_model_name, - "stream": model_endpoint.endpoint.streaming, - } + payload = {"messages": messages} - if turns[-1].raw_tools is not None: - payload["tools"] = turns[-1].raw_tools + if current_turn.raw_tools is not None: + payload["tools"] = current_turn.raw_tools - if turns[-1].max_tokens is not None: + if current_turn.extra_params: + payload.update(current_turn.extra_params) + + if model_endpoint.endpoint.extra: + payload.update(model_endpoint.endpoint.extra) + + # Set max tokens, model, and stream after all other payload fields are set + # to avoid them being overwritten + if current_turn.max_tokens is not None: token_field = ( "max_tokens" if model_endpoint.endpoint.use_legacy_max_tokens else "max_completion_tokens" ) - payload[token_field] = turns[-1].max_tokens + payload[token_field] = current_turn.max_tokens - if model_endpoint.endpoint.extra: - payload.update(model_endpoint.endpoint.extra) + payload["model"] = current_turn.model or model_endpoint.primary_model_name + payload["stream"] = model_endpoint.endpoint.streaming if ( model_endpoint.endpoint.streaming and model_endpoint.endpoint.use_server_token_count ): - # Automatically set stream_options to include usage when using server token counts - if "stream_options" not in payload: + existing = payload.get("stream_options") + if existing is None: payload["stream_options"] = {"include_usage": True} - elif ( - isinstance(payload["stream_options"], dict) - and "include_usage" not in payload["stream_options"] - ): - payload["stream_options"]["include_usage"] = True + elif isinstance(existing, dict): + payload["stream_options"] = { + **existing, + "include_usage": 
existing.get("include_usage", True), + } self.trace(lambda: f"Formatted payload: {payload}") return payload diff --git a/src/aiperf/plugin/enums.py b/src/aiperf/plugin/enums.py index b05147f56..8aa750de9 100644 --- a/src/aiperf/plugin/enums.py +++ b/src/aiperf/plugin/enums.py @@ -59,7 +59,7 @@ CustomDatasetTypeStr: TypeAlias = str CustomDatasetType = plugins.create_enum(PluginType.CUSTOM_DATASET_LOADER, "CustomDatasetType", module=__name__) -"""Dynamic enum for custom dataset loader. Example: CustomDatasetType.BAILIAN_TRACE, CustomDatasetType.MOONCAKE_TRACE, CustomDatasetType.MULTI_TURN""" +"""Dynamic enum for custom dataset loader. Example: CustomDatasetType.BAILIAN_TRACE, CustomDatasetType.CONFLUX, CustomDatasetType.MOONCAKE_TRACE""" PublicDatasetTypeStr: TypeAlias = str PublicDatasetType = plugins.create_enum(PluginType.PUBLIC_DATASET_LOADER, "PublicDatasetType", module=__name__) diff --git a/src/aiperf/plugin/plugins.yaml b/src/aiperf/plugin/plugins.yaml index bd15b4515..9e8699cc3 100644 --- a/src/aiperf/plugin/plugins.yaml +++ b/src/aiperf/plugin/plugins.yaml @@ -444,8 +444,19 @@ custom_dataset_loader: conversation chains and fixed_schedule timing mode. metadata: is_trace: true + supports_timing: true default_block_size: 16 + conflux: + class: aiperf.dataset.loader.conflux:ConfluxLoader + description: | + Conflux proxy capture loader for verbatim replay of Claude Code and Codex sessions. + Loads JSON arrays of API request records with agent_id grouping and + timestamp-based delays for fixed-schedule replay. + metadata: + is_trace: false + supports_timing: true + mooncake_trace: class: aiperf.dataset.loader.mooncake_trace:MooncakeTraceDatasetLoader description: | @@ -453,6 +464,7 @@ custom_dataset_loader: timestamp-based replay support. Designed for fixed_schedule timing mode. 
metadata: is_trace: true + supports_timing: true default_block_size: 512 multi_turn: diff --git a/src/aiperf/timing/config.py b/src/aiperf/timing/config.py index 281302764..0982698eb 100644 --- a/src/aiperf/timing/config.py +++ b/src/aiperf/timing/config.py @@ -182,6 +182,12 @@ class CreditPhaseConfig(AIPerfBaseModel): ge=0, description="The fixed schedule end offset of the timing manager.", ) + fixed_schedule_speedup: float | None = Field( + default=None, + gt=0, + description="Scaling factor for fixed schedule timestamps. " + "2.0 = twice as fast, 0.5 = half speed. Only applicable when using fixed schedule timing mode.", + ) def _build_warmup_config(user_config: UserConfig) -> CreditPhaseConfig | None: @@ -267,4 +273,5 @@ def _build_profiling_config(user_config: UserConfig) -> CreditPhaseConfig: auto_offset_timestamps=input.fixed_schedule_auto_offset, fixed_schedule_start_offset=input.fixed_schedule_start_offset, fixed_schedule_end_offset=input.fixed_schedule_end_offset, + fixed_schedule_speedup=input.fixed_schedule_speedup, ) # fmt: skip diff --git a/src/aiperf/timing/strategies/fixed_schedule.py b/src/aiperf/timing/strategies/fixed_schedule.py index a55713055..a5830195f 100644 --- a/src/aiperf/timing/strategies/fixed_schedule.py +++ b/src/aiperf/timing/strategies/fixed_schedule.py @@ -53,7 +53,6 @@ def __init__( stop_checker: StopConditionChecker, **kwargs, ): - """Initialize fixed schedule timing strategy with all dependencies.""" super().__init__(logger_name="FixedScheduleTiming") self._config = config self._conversation_source = conversation_source @@ -61,16 +60,25 @@ def __init__( self._credit_issuer = credit_issuer self._lifecycle = lifecycle - # Computed in setup_phase + # Speedup is expressed as a multiplier (2x = twice as fast), but the + # timing math needs a divisor: a 2x speedup means each interval is + # *half* as long. Storing the reciprocal once avoids a division on + # every timestamp conversion and delay calculation. 
+ self._time_scale = 1.0 / (config.fixed_schedule_speedup or 1.0) + self._absolute_schedule: list[ScheduleEntry] = [] self._schedule_zero_ms: float = 0.0 def _timestamp_to_perf_sec(self, timestamp_ms: int | float) -> float: - """Convert trace timestamp in milliseconds to perf counter seconds. + """Convert a trace timestamp to a perf-counter target. + + Subtracts the schedule zero point, scales by the speedup factor, + then anchors to the phase start time: - Uses the offset from the schedule zero to calculate the target performance seconds. + wall_target = started_at + (ts - zero) / speedup """ - target_offset_sec = (timestamp_ms - self._schedule_zero_ms) / MILLIS_PER_SECOND + scaled_offset_ms = (timestamp_ms - self._schedule_zero_ms) * self._time_scale + target_offset_sec = scaled_offset_ms / MILLIS_PER_SECOND return self._lifecycle.started_at_perf_sec + target_offset_sec async def setup_phase(self) -> None: @@ -116,10 +124,16 @@ async def setup_phase(self) -> None: else: self._schedule_zero_ms = 0.0 + speedup_msg = ( + f", speedup={self._config.fixed_schedule_speedup}x" + if self._config.fixed_schedule_speedup + else "" + ) self.info( f"Built schedule with {len(self._absolute_schedule)} timestamps, " f"zero_ms={self._schedule_zero_ms:.0f}, " f"auto_offset={self._config.auto_offset_timestamps}" + f"{speedup_msg}" ) async def execute_phase(self) -> None: @@ -162,7 +176,7 @@ async def handle_credit_return( ) elif next_meta.delay_ms is not None: self._scheduler.schedule_later( - next_meta.delay_ms / MILLIS_PER_SECOND, + next_meta.delay_ms * self._time_scale / MILLIS_PER_SECOND, self._credit_issuer.issue_credit(turn), ) else: diff --git a/tests/integration/test_conflux_loader.py b/tests/integration/test_conflux_loader.py new file mode 100644 index 000000000..fc786338d --- /dev/null +++ b/tests/integration/test_conflux_loader.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for Conflux proxy capture dataset loader.""" + +from pathlib import Path + +import orjson +import pytest + +from tests.harness.utils import AIPerfCLI, AIPerfMockServer +from tests.integration.conftest import IntegrationTestDefaults as defaults + + +def create_conflux_file( + tmp_path: Path, + records: list[dict], + filename: str = "conflux_session.json", +) -> Path: + """Create a pretty-printed Conflux JSON file for testing.""" + path = tmp_path / filename + path.write_bytes(orjson.dumps(records, option=orjson.OPT_INDENT_2)) + return path + + +def make_conflux_records() -> list[dict]: + """Generate a realistic multi-agent Conflux capture with staggered timestamps.""" + return [ + { + "session_id": "sess-1", + "agent_id": "planner", + "timestamp": 0.0, + "duration_ms": 800, + "messages": [ + {"role": "system", "content": "You are a planning assistant."}, + {"role": "user", "content": "Plan a web scraper for news articles."}, + ], + "tokens": {"input": 50, "output": 120}, + "hyperparameters": {"temperature": 0.7}, + }, + { + "session_id": "sess-1", + "agent_id": "coder", + "timestamp": 1000.0, + "duration_ms": 1200, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Implement the scraper using BeautifulSoup."}, + ], + "tokens": {"input": 80, "output": 250}, + "hyperparameters": {"temperature": 0.3}, + }, + { + "session_id": "sess-1", + "agent_id": "planner", + "timestamp": 3000.0, + "duration_ms": 600, + "messages": [ + {"role": "system", "content": "You are a planning assistant."}, + {"role": "user", "content": "Plan a web scraper for news articles."}, + {"role": "assistant", "content": "Here is the plan..."}, + {"role": "user", "content": "Now add error handling to the plan."}, + ], + "tokens": {"input": 150, "output": 100, "output_reasoning": 20}, + }, + { + "session_id": "sess-1", + "agent_id": "coder", + "timestamp": 5000.0, + "duration_ms": 
1500, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Implement the scraper using BeautifulSoup."}, + {"role": "assistant", "content": "Here is the code..."}, + {"role": "user", "content": "Add retry logic and rate limiting."}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write content to a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}, "content": {"type": "string"}}, + }, + }, + } + ], + "tokens": {"input": 200, "output": 300}, + "hyperparameters": {"temperature": 0.2, "top_p": 0.95}, + }, + { + "session_id": "sess-1", + "agent_id": "reviewer", + "timestamp": 8000.0, + "duration_ms": 900, + "messages": [ + {"role": "user", "content": "Review the scraper code for best practices."}, + ], + "tokens": {"input": 100, "output": 180}, + }, + ] # fmt: skip + + +@pytest.mark.integration +@pytest.mark.asyncio +class TestConfluxLoaderIntegration: + """Integration tests for Conflux proxy capture dataset loader.""" + + async def test_auto_detect_and_replay_with_speedup( + self, + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, + ): + """Auto-detect a pretty-printed Conflux file and replay with --fixed-schedule-speedup 10.""" + records = make_conflux_records() + conflux_file = create_conflux_file(tmp_path, records) + request_count = len(records) + + result = await cli.run( + f""" + aiperf profile \ + --model {defaults.model} \ + --url {aiperf_mock_server.url} \ + --endpoint-type chat \ + --input-file {conflux_file} \ + --request-count {request_count} \ + --fixed-schedule \ + --fixed-schedule-speedup 10 \ + --workers-max {defaults.workers_max} \ + --ui {defaults.ui} + """ + ) + + assert result.request_count == request_count + assert result.has_all_outputs + + async def test_explicit_type_with_speedup( + self, + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, + ): + 
"""Explicit --custom-dataset-type conflux with speedup.""" + records = make_conflux_records() + conflux_file = create_conflux_file(tmp_path, records) + request_count = len(records) + + result = await cli.run( + f""" + aiperf profile \ + --model {defaults.model} \ + --url {aiperf_mock_server.url} \ + --endpoint-type chat \ + --input-file {conflux_file} \ + --custom-dataset-type conflux \ + --request-count {request_count} \ + --fixed-schedule \ + --fixed-schedule-speedup 10 \ + --workers-max {defaults.workers_max} \ + --ui {defaults.ui} + """ + ) + + assert result.request_count == request_count + assert result.has_all_outputs + + async def test_directory_of_conflux_files( + self, + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, + ): + """Load a directory of Conflux JSON files with auto-detection.""" + input_dir = tmp_path / "sessions" + input_dir.mkdir() + + records = make_conflux_records() + # Split into two files by agent + planner_records = [r for r in records if r.get("agent_id") == "planner"] + other_records = [r for r in records if r.get("agent_id") != "planner"] + create_conflux_file(input_dir, planner_records, "session_planner.json") + create_conflux_file(input_dir, other_records, "session_others.json") + request_count = len(records) + + result = await cli.run( + f""" + aiperf profile \ + --model {defaults.model} \ + --url {aiperf_mock_server.url} \ + --endpoint-type chat \ + --input-file {input_dir} \ + --request-count {request_count} \ + --fixed-schedule \ + --fixed-schedule-speedup 10 \ + --workers-max {defaults.workers_max} \ + --ui {defaults.ui} + """ + ) + + assert result.request_count == request_count + assert result.has_all_outputs + + async def test_conflux_with_extra_params_in_payload( + self, + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, + ): + """Verify hyperparameters from Conflux records propagate into request payloads.""" + records = [ + { + "session_id": "sess-1", + "agent_id": 
"agent-A", + "timestamp": 0.0, + "duration_ms": 500, + "messages": [{"role": "user", "content": "Hello"}], + "tokens": {"input": 10, "output": 20}, + "hyperparameters": {"temperature": 0.42, "top_p": 0.88}, + }, + ] + conflux_file = create_conflux_file(tmp_path, records) + + result = await cli.run( + f""" + aiperf profile \ + --model {defaults.model} \ + --url {aiperf_mock_server.url} \ + --endpoint-type chat \ + --input-file {conflux_file} \ + --custom-dataset-type conflux \ + --request-count 1 \ + --fixed-schedule \ + --fixed-schedule-speedup 10 \ + --workers-max {defaults.workers_max} \ + --ui {defaults.ui} + """ + ) + + assert result.request_count == 1 + assert result.has_all_outputs + + # Verify hyperparameters made it into the actual payloads + assert result.inputs is not None + payloads = [p for session in result.inputs.data for p in session.payloads] + assert len(payloads) >= 1 + payload = payloads[0] + assert payload["temperature"] == 0.42 + assert payload["top_p"] == 0.88 diff --git a/tests/unit/dataset/composer/test_custom_composer.py b/tests/unit/dataset/composer/test_custom_composer.py index 75ec844b5..71c7ced76 100644 --- a/tests/unit/dataset/composer/test_custom_composer.py +++ b/tests/unit/dataset/composer/test_custom_composer.py @@ -1,11 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from unittest.mock import Mock, mock_open, patch +import orjson import pytest -from aiperf.common.config import SynthesisConfig +from aiperf.common.config import ( + EndpointConfig, + InputConfig, + SynthesisConfig, + UserConfig, +) from aiperf.common.models import Conversation, Turn from aiperf.dataset.composer.custom import CustomDatasetComposer from aiperf.dataset.loader import ( @@ -285,3 +292,155 @@ def test_max_isl_alone_allowed_with_any_type(self, custom_config, mock_tokenizer # Should not raise - max_isl doesn't trigger should_synthesize() composer._validate_synthesis_config(CustomDatasetType.SINGLE_TURN) + + +class TestConfluxAutoDetection: + """Integration tests for Conflux dataset auto-detection and loading.""" + + @staticmethod + def _make_conflux_record( + *, + agent_id: str = "agent-A", + timestamp: float = 1000.0, + messages: list | None = None, + tokens: dict | None = None, + hyperparameters: dict | None = None, + ) -> dict: + rec = { + "session_id": "sess-1", + "agent_id": agent_id, + "timestamp": timestamp, + "duration_ms": 500, + } + if messages is not None: + rec["messages"] = messages + if tokens is not None: + rec["tokens"] = tokens + if hyperparameters is not None: + rec["hyperparameters"] = hyperparameters + return rec + + @staticmethod + def _write_pretty_json(path: Path, data: list) -> None: + path.write_bytes(orjson.dumps(data, option=orjson.OPT_INDENT_2)) + + @staticmethod + def _make_config( + file_path: str, dataset_type: CustomDatasetType | None = None + ) -> UserConfig: + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + file=Path(file_path), + custom_dataset_type=dataset_type, + conflux_include_utility_calls=False, + ), + ) + + def test_auto_detect_pretty_printed_conflux(self, tmp_path): + """Full pipeline: pretty-printed JSON auto-detects as Conflux and loads.""" + records = [ + self._make_conflux_record( + 
agent_id="coder", + timestamp=1000.0, + messages=[{"role": "user", "content": "Write hello world"}], + tokens={"input": 50, "output": 100}, + hyperparameters={"temperature": 0.3}, + ), + self._make_conflux_record( + agent_id="coder", + timestamp=3000.0, + messages=[ + {"role": "user", "content": "Write hello world"}, + {"role": "assistant", "content": "print('hello world')"}, + {"role": "user", "content": "Add error handling"}, + ], + tokens={"input": 150, "output": 200, "output_reasoning": 30}, + ), + ] + path = tmp_path / "session.json" + self._write_pretty_json(path, records) + + config = self._make_config(str(path)) + composer = CustomDatasetComposer(config, None) + conversations = composer.create_dataset() + + assert len(conversations) == 1 + convo = conversations[0] + assert convo.session_id == "conflux_coder" + assert len(convo.turns) == 2 + assert convo.turns[0].max_tokens == 100 + assert convo.turns[0].extra_params == {"temperature": 0.3} + assert convo.turns[1].max_tokens == 230 + assert len(convo.turns[1].raw_messages) == 3 + + def test_auto_detect_compact_conflux(self, tmp_path): + """Compact (single-line) JSON also auto-detects as Conflux.""" + records = [ + self._make_conflux_record( + messages=[{"role": "user", "content": "Hello"}], + tokens={"input": 10, "output": 20}, + ), + ] + path = tmp_path / "compact.json" + path.write_bytes(orjson.dumps(records)) + + config = self._make_config(str(path)) + composer = CustomDatasetComposer(config, None) + conversations = composer.create_dataset() + + assert len(conversations) == 1 + assert len(conversations[0].turns) == 1 + + def test_auto_detect_conflux_directory(self, tmp_path): + """Directory of pretty-printed JSON files auto-detects as Conflux.""" + for i, agent in enumerate(["planner", "executor"]): + self._write_pretty_json( + tmp_path / f"session_{i}.json", + [ + self._make_conflux_record( + agent_id=agent, timestamp=1000.0 + i * 1000 + ) + ], + ) + + config = self._make_config(str(tmp_path)) + composer 
= CustomDatasetComposer(config, None) + conversations = composer.create_dataset() + + assert len(conversations) == 2 + + def test_explicit_type_skips_auto_detection(self, tmp_path): + """With --custom-dataset-type conflux, auto-detection is bypassed.""" + records = [self._make_conflux_record()] + path = tmp_path / "data.json" + self._write_pretty_json(path, records) + + config = self._make_config(str(path), dataset_type=CustomDatasetType.CONFLUX) + composer = CustomDatasetComposer(config, None) + conversations = composer.create_dataset() + + assert len(conversations) == 1 + + def test_multi_agent_with_utility_filtering(self, tmp_path): + """Multi-agent session with utility calls filtered by default.""" + records = [ + self._make_conflux_record(agent_id="planner", timestamp=1000.0), + self._make_conflux_record(agent_id=None, timestamp=2000.0), + self._make_conflux_record(agent_id="executor", timestamp=3000.0), + self._make_conflux_record(agent_id="planner", timestamp=4000.0), + ] + path = tmp_path / "session.json" + self._write_pretty_json(path, records) + + config = self._make_config(str(path)) + composer = CustomDatasetComposer(config, None) + conversations = composer.create_dataset() + + assert len(conversations) == 2 + session_ids = {c.session_id for c in conversations} + assert "conflux_planner" in session_ids + assert "conflux_executor" in session_ids + + planner = next(c for c in conversations if c.session_id == "conflux_planner") + assert len(planner.turns) == 2 diff --git a/tests/unit/dataset/loader/test_conflux.py b/tests/unit/dataset/loader/test_conflux.py new file mode 100644 index 000000000..f521235ff --- /dev/null +++ b/tests/unit/dataset/loader/test_conflux.py @@ -0,0 +1,1124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Tests for ConfluxLoader and ConfluxRecord models.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import orjson +import pytest +from pydantic import ValidationError + +from aiperf.common.config import EndpointConfig, InputConfig, UserConfig +from aiperf.common.enums import ConversationContextMode +from aiperf.dataset.loader.conflux import ConfluxLoader +from aiperf.dataset.loader.models import ConfluxRecord, ConfluxTokens +from aiperf.plugin.enums import DatasetSamplingStrategy + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_record( + *, + session_id: str = "sess-1", + agent_id: str | None = "agent-A", + timestamp: float = 1000.0, + duration_ms: int | float = 500, + messages: list[dict[str, Any]] | None = None, + tools: list[dict[str, Any]] | None = None, + tokens: dict[str, Any] | None = None, + hyperparameters: dict[str, Any] | None = None, + is_subagent: bool | None = None, + completed_at: str | None = None, + **extra_fields: Any, +) -> dict[str, Any]: + """Build a raw Conflux record dict with sensible defaults.""" + rec: dict[str, Any] = { + "session_id": session_id, + "agent_id": agent_id, + "timestamp": timestamp, + "duration_ms": duration_ms, + } + if messages is not None: + rec["messages"] = messages + if tools is not None: + rec["tools"] = tools + if tokens is not None: + rec["tokens"] = tokens + if hyperparameters is not None: + rec["hyperparameters"] = hyperparameters + if is_subagent is not None: + rec["is_subagent"] = is_subagent + if completed_at is not None: + rec["completed_at"] = completed_at + rec.update(extra_fields) + return rec + + +def _write_json(path: Path, data: Any, *, pretty: bool = True) -> str: + """Write data as JSON and return the string path. 
+ + Uses pretty-printed format by default to match Conflux CLI export behavior. + """ + if pretty: + path.write_text(orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8")) + else: + path.write_bytes(orjson.dumps(data)) + return str(path) + + +def _make_user_config(*, include_utility: bool = False) -> UserConfig: + """Create a minimal UserConfig for Conflux tests.""" + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig.model_construct( + conflux_include_utility_calls=include_utility, + ), + ) + + +# --------------------------------------------------------------------------- +# ConfluxTokens model tests +# --------------------------------------------------------------------------- + + +class TestConfluxTokens: + """Tests for ConfluxTokens model.""" + + def test_defaults(self): + tokens = ConfluxTokens() + assert tokens.input == 0 + assert tokens.input_cached == 0 + assert tokens.input_cache_write == 0 + assert tokens.output == 0 + assert tokens.output_reasoning == 0 + + def test_all_fields_populated(self): + tokens = ConfluxTokens( + input=1000, + input_cached=200, + input_cache_write=300, + output=500, + output_reasoning=50, + ) + assert tokens.input == 1000 + assert tokens.input_cached == 200 + assert tokens.input_cache_write == 300 + assert tokens.output == 500 + assert tokens.output_reasoning == 50 + + def test_partial_fields(self): + tokens = ConfluxTokens(input=42, output=10) + assert tokens.input == 42 + assert tokens.input_cached == 0 + assert tokens.output == 10 + assert tokens.output_reasoning == 0 + + +# --------------------------------------------------------------------------- +# ConfluxRecord model tests +# --------------------------------------------------------------------------- + + +class TestConfluxRecord: + """Tests for ConfluxRecord model validation.""" + + def test_minimal_valid_record(self): + record = ConfluxRecord(session_id="s1", timestamp=1000.0) + assert record.session_id == "s1" + assert 
record.agent_id is None + assert record.messages == [] + assert record.tools == [] + assert record.tokens is None + assert record.hyperparameters is None + assert record.duration_ms == 0 + + def test_full_record(self): + record = ConfluxRecord( + session_id="s1", + agent_id="agent-X", + is_subagent=True, + timestamp=2000.0, + duration_ms=1234, + completed_at="2025-01-15T10:30:01.234Z", + tokens=ConfluxTokens(input=100, output=50), + messages=[{"role": "user", "content": "Hello"}], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + hyperparameters={"temperature": 0.7, "top_p": 0.9}, + ) + assert record.agent_id == "agent-X" + assert record.is_subagent is True + assert record.tokens.input == 100 + assert record.tokens.output == 50 + assert len(record.messages) == 1 + assert len(record.tools) == 1 + assert record.hyperparameters["temperature"] == 0.7 + + def test_extra_fields_ignored(self): + record = ConfluxRecord.model_validate( + { + "session_id": "s1", + "timestamp": 1000.0, + "unknown_field": "should be ignored", + "another_extra": 42, + } + ) + assert record.session_id == "s1" + assert not hasattr(record, "unknown_field") + + def test_missing_required_session_id_raises(self): + with pytest.raises(ValidationError): + ConfluxRecord(timestamp=1000.0) + + def test_missing_required_timestamp_raises(self): + with pytest.raises(ValidationError): + ConfluxRecord(session_id="s1") + + def test_duration_ms_accepts_float(self): + record = ConfluxRecord(session_id="s1", timestamp=1000.0, duration_ms=123.456) + assert record.duration_ms == 123.456 + + def test_tokens_nested_validation(self): + record = ConfluxRecord.model_validate( + { + "session_id": "s1", + "timestamp": 1000.0, + "tokens": {"input": 999, "output": 100, "output_reasoning": 10}, + } + ) + assert record.tokens.input == 999 + assert record.tokens.output_reasoning == 10 + + +# --------------------------------------------------------------------------- +# ConfluxLoader class-level tests +# 
--------------------------------------------------------------------------- + + +class TestConfluxLoaderClassMethods: + """Tests for ConfluxLoader class methods.""" + + def test_default_context_mode(self): + assert ( + ConfluxLoader.get_default_context_mode() + == ConversationContextMode.MESSAGE_ARRAY_WITH_RESPONSES + ) + + def test_preferred_sampling_strategy(self): + assert ( + ConfluxLoader.get_preferred_sampling_strategy() + == DatasetSamplingStrategy.SEQUENTIAL + ) + + +# --------------------------------------------------------------------------- +# can_load tests +# --------------------------------------------------------------------------- + + +class TestConfluxCanLoad: + """Tests for ConfluxLoader.can_load auto-detection.""" + + def test_pretty_printed_single_file(self, tmp_path): + records = [_make_record()] + path = tmp_path / "session.json" + _write_json(path, records) + assert ConfluxLoader.can_load(filename=str(path)) is True + + def test_pretty_printed_directory(self, tmp_path): + records = [_make_record()] + _write_json(tmp_path / "a.json", records) + assert ConfluxLoader.can_load(filename=str(tmp_path)) is True + + def test_compact_json_detected(self, tmp_path): + """Compact (single-line) JSON is also auto-detected.""" + path = tmp_path / "compact.json" + _write_json(path, [_make_record()], pretty=False) + assert ConfluxLoader.can_load(filename=str(path)) is True + + def test_empty_array_returns_false(self, tmp_path): + path = tmp_path / "empty.json" + _write_json(path, []) + assert ConfluxLoader.can_load(filename=str(path)) is False + + def test_non_json_file_returns_false(self, tmp_path): + path = tmp_path / "data.txt" + path.write_text("not json") + assert ConfluxLoader.can_load(filename=str(path)) is False + + def test_json_object_not_array_returns_false(self, tmp_path): + path = tmp_path / "obj.json" + _write_json(path, {"key": "value"}) + assert ConfluxLoader.can_load(filename=str(path)) is False + + def 
test_wrong_json_extension_returns_false(self, tmp_path): + path = tmp_path / "data.jsonl" + path.write_bytes(orjson.dumps([_make_record()])) + assert ConfluxLoader.can_load(filename=str(path)) is False + + def test_none_filename_returns_false(self): + assert ConfluxLoader.can_load(filename=None) is False + + def test_nonexistent_file_returns_false(self): + assert ConfluxLoader.can_load(filename="/nonexistent/path/file.json") is False + + def test_array_of_non_conflux_records_returns_false(self, tmp_path): + path = tmp_path / "other.json" + _write_json(path, [{"unrelated": "data"}]) + assert ConfluxLoader.can_load(filename=str(path)) is False + + def test_directory_with_no_json_files(self, tmp_path): + (tmp_path / "readme.txt").write_text("hello") + assert ConfluxLoader.can_load(filename=str(tmp_path)) is False + + def test_directory_with_invalid_json(self, tmp_path): + (tmp_path / "bad.json").write_text("{not valid json{") + assert ConfluxLoader.can_load(filename=str(tmp_path)) is False + + def test_directory_with_single_valid_json(self, tmp_path): + """A directory whose only JSON file is a valid Conflux capture is always detected.""" + _write_json(tmp_path / "valid.json", [_make_record()]) + result = ConfluxLoader.can_load(filename=str(tmp_path)) + assert result is True + + def test_multiline_whitespace_variants(self, tmp_path): + """Array bracket on its own line with leading whitespace.""" + path = tmp_path / "spaced.json" + path.write_text( + '[\n {\n "session_id": "s1",\n "timestamp": 1000.0\n }\n]\n' + ) + assert ConfluxLoader.can_load(filename=str(path)) is True + + def test_large_first_record_within_probe_limit(self, tmp_path): + """A record with a large messages array still probes correctly.""" + big_messages = [ + {"role": "user", "content": f"msg-{'x' * 500}-{i}"} for i in range(200) + ] + records = [_make_record(messages=big_messages)] + path = tmp_path / "big_record.json" + _write_json(path, records) + assert 
ConfluxLoader.can_load(filename=str(path)) is True + + def test_multiple_records_probe_only_first(self, tmp_path): + """Probe validates only the first record, even with many in the array.""" + records = [_make_record(agent_id=f"agent-{i}") for i in range(100)] + path = tmp_path / "many.json" + _write_json(path, records) + assert ConfluxLoader.can_load(filename=str(path)) is True + + +# --------------------------------------------------------------------------- +# load_dataset tests +# --------------------------------------------------------------------------- + + +class TestConfluxLoadDataset: + """Tests for ConfluxLoader.load_dataset.""" + + def test_single_agent_group(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id="A", timestamp=2000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + groups = loader.load_dataset() + + assert len(groups) == 1 + assert "A" in groups + assert len(groups["A"]) == 2 + + def test_multiple_agent_groups(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id="B", timestamp=2000.0), + _make_record(agent_id="A", timestamp=3000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + groups = loader.load_dataset() + + assert len(groups) == 2 + assert len(groups["A"]) == 2 + assert len(groups["B"]) == 1 + + def test_records_sorted_by_timestamp_within_group(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=6000.0), + _make_record(agent_id="A", timestamp=2000.0), + _make_record(agent_id="A", timestamp=4000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + groups = loader.load_dataset() + + timestamps = [r.timestamp for r in 
groups["A"]] + assert timestamps == sorted(timestamps) + assert len(timestamps) == 3 + + def test_utility_calls_skipped_by_default(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id=None, timestamp=2000.0), + _make_record(agent_id=None, timestamp=3000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=False) + ) + groups = loader.load_dataset() + + assert len(groups) == 1 + assert "A" in groups + + def test_utility_calls_included_when_enabled(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id=None, timestamp=2000.0), + _make_record(agent_id=None, timestamp=3000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=True) + ) + groups = loader.load_dataset() + + assert len(groups) == 3 + assert "A" in groups + assert "_utility_0" in groups + assert "_utility_1" in groups + + def test_utility_calls_each_get_own_group(self, tmp_path): + records = [ + _make_record(agent_id=None, timestamp=1000.0), + _make_record(agent_id=None, timestamp=2000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=True) + ) + groups = loader.load_dataset() + + assert len(groups) == 2 + for key in groups: + assert len(groups[key]) == 1 + + def test_all_utility_records_skipped_yields_empty(self, tmp_path): + records = [ + _make_record(agent_id=None, timestamp=1000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=False) + ) + groups = loader.load_dataset() + + assert len(groups) == 0 + + +# 
--------------------------------------------------------------------------- +# load_dataset directory tests +# --------------------------------------------------------------------------- + + +class TestConfluxLoadDirectory: + """Tests for loading a directory of JSON files.""" + + def test_multiple_files_merged(self, tmp_path): + _write_json( + tmp_path / "file1.json", + [_make_record(agent_id="A", timestamp=1000.0)], + ) + _write_json( + tmp_path / "file2.json", + [_make_record(agent_id="B", timestamp=100000.0)], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=_make_user_config()) + groups = loader.load_dataset() + + assert len(groups) == 2 + assert "f0_A" in groups + assert "f1_B" in groups + + def test_directory_prefixes_prevent_key_collisions(self, tmp_path): + _write_json( + tmp_path / "a.json", + [_make_record(agent_id="X", timestamp=1000.0)], + ) + _write_json( + tmp_path / "b.json", + [_make_record(agent_id="X", timestamp=100000.0)], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=_make_user_config()) + groups = loader.load_dataset() + + assert "f0_X" in groups + assert "f1_X" in groups + assert len(groups) == 2 + + def test_empty_directory_raises(self, tmp_path): + loader = ConfluxLoader(filename=str(tmp_path), user_config=_make_user_config()) + with pytest.raises(FileNotFoundError, match="No .json files found"): + loader.load_dataset() + + def test_files_loaded_in_sorted_order(self, tmp_path): + _write_json( + tmp_path / "z.json", + [_make_record(agent_id="Z", timestamp=1000.0)], + ) + _write_json( + tmp_path / "a.json", + [_make_record(agent_id="A", timestamp=1000.0)], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=_make_user_config()) + groups = loader.load_dataset() + + keys = list(groups.keys()) + assert keys[0] == "f0_A" + assert keys[1] == "f1_Z" + + def test_directory_utility_calls_with_prefix(self, tmp_path): + _write_json( + tmp_path / "file.json", + [_make_record(agent_id=None, 
timestamp=1000.0)], + ) + + loader = ConfluxLoader( + filename=str(tmp_path), user_config=_make_user_config(include_utility=True) + ) + groups = loader.load_dataset() + + assert len(groups) == 1 + assert "f0__utility_0" in groups + + +# --------------------------------------------------------------------------- +# convert_to_conversations tests +# --------------------------------------------------------------------------- + + +class TestConfluxConvertToConversations: + """Tests for ConfluxLoader.convert_to_conversations.""" + + def _load_and_convert( + self, + tmp_path: Path, + records: list[dict[str, Any]], + *, + include_utility: bool = False, + ) -> list: + path = tmp_path / "data.json" + _write_json(path, records) + loader = ConfluxLoader( + filename=str(path), + user_config=_make_user_config(include_utility=include_utility), + ) + data = loader.load_dataset() + return loader.convert_to_conversations(data) + + def test_basic_conversion(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + messages=[{"role": "user", "content": "Hello"}], + ), + ] + convos = self._load_and_convert(tmp_path, records) + + assert len(convos) == 1 + assert convos[0].session_id == "conflux_A" + assert len(convos[0].turns) == 1 + + def test_turn_has_raw_messages(self, tmp_path): + msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + records = [ + _make_record(agent_id="A", timestamp=1000.0, messages=msgs), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].raw_messages == msgs + + def test_turn_has_raw_tools(self, tmp_path): + tools = [{"type": "function", "function": {"name": "search"}}] + records = [ + _make_record(agent_id="A", timestamp=1000.0, tools=tools), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].raw_tools == tools + + def test_empty_tools_becomes_none(self, tmp_path): + records = [ + _make_record(agent_id="A", 
timestamp=1000.0, tools=[]), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].raw_tools is None + + def test_turn_timestamp_is_milliseconds(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + ] + convos = self._load_and_convert(tmp_path, records) + + ts = convos[0].turns[0].timestamp + assert isinstance(ts, float) + assert ts == 1000.0 + + def test_multi_turn_timestamps_preserved(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id="A", timestamp=6000.0), + ] + convos = self._load_and_convert(tmp_path, records) + + t0 = convos[0].turns[0].timestamp + t1 = convos[0].turns[1].timestamp + assert t1 > t0 + delta_ms = t1 - t0 + assert abs(delta_ms - 5000.0) < 1.0 + + def test_max_tokens_from_output_plus_reasoning(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + tokens={"input": 100, "output": 200, "output_reasoning": 50}, + ), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].max_tokens == 250 + + def test_max_tokens_none_when_no_tokens(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].max_tokens is None + + def test_max_tokens_none_when_zero_output(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + tokens={"input": 100, "output": 0, "output_reasoning": 0}, + ), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].max_tokens is None + + def test_input_tokens_from_token_data(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + tokens={"input": 500, "output": 100}, + ), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].input_tokens == 500 + + def test_input_tokens_none_when_no_token_data(self, tmp_path): + records = [ + 
_make_record(agent_id="A", timestamp=1000.0), + ] + convos = self._load_and_convert(tmp_path, records) + + assert convos[0].turns[0].input_tokens is None + + def test_multiple_conversations_from_agents(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id="B", timestamp=2000.0), + _make_record(agent_id="A", timestamp=3000.0), + ] + convos = self._load_and_convert(tmp_path, records) + + assert len(convos) == 2 + session_ids = {c.session_id for c in convos} + assert "conflux_A" in session_ids + assert "conflux_B" in session_ids + + convo_a = next(c for c in convos if c.session_id == "conflux_A") + assert len(convo_a.turns) == 2 + + def test_empty_data_produces_no_conversations(self, tmp_path): + path = tmp_path / "data.json" + _write_json(path, [_make_record(agent_id=None)]) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=False) + ) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 0 + + +# --------------------------------------------------------------------------- +# _extract_extra_params tests +# --------------------------------------------------------------------------- + + +class TestExtractExtraParams: + """Tests for ConfluxLoader._extract_extra_params.""" + + def test_no_hyperparameters_returns_none(self): + record = ConfluxRecord(session_id="s1", timestamp=1000.0) + assert ConfluxLoader._extract_extra_params(record) is None + + def test_empty_hyperparameters_returns_none(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={}, + ) + assert ConfluxLoader._extract_extra_params(record) is None + + def test_basic_hyperparameters_extracted(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": 0.7, "top_p": 0.9}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"temperature": 0.7, "top_p": 0.9} + + 
def test_max_tokens_filtered_out(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": 0.5, "max_tokens": 1000}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"temperature": 0.5} + assert "max_tokens" not in params + + def test_max_output_tokens_filtered_out(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": 0.5, "max_output_tokens": 2000}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"temperature": 0.5} + + def test_none_values_filtered_out(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": 0.7, "top_k": None, "stop": None}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"temperature": 0.7} + + def test_all_filtered_returns_none(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"max_tokens": 100, "max_output_tokens": 200}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params is None + + def test_all_none_values_returns_none(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": None, "top_p": None}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params is None + + def test_zero_value_preserved(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"temperature": 0, "frequency_penalty": 0.0}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"temperature": 0, "frequency_penalty": 0.0} + + def test_false_value_preserved(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={"logprobs": False}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"logprobs": False} + + def test_empty_string_value_preserved(self): + record = ConfluxRecord( + 
session_id="s1", + timestamp=1000.0, + hyperparameters={"stop": ""}, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"stop": ""} + + def test_nested_dict_value_preserved(self): + record = ConfluxRecord( + session_id="s1", + timestamp=1000.0, + hyperparameters={ + "response_format": {"type": "json_object"}, + }, + ) + params = ConfluxLoader._extract_extra_params(record) + assert params == {"response_format": {"type": "json_object"}} + + +# --------------------------------------------------------------------------- +# End-to-end / integration-style tests +# --------------------------------------------------------------------------- + + +class TestConfluxEndToEnd: + """End-to-end tests combining load + convert.""" + + def test_full_pipeline_single_agent_session(self, tmp_path): + records = [ + _make_record( + agent_id="coder", + timestamp=1000.0, + messages=[{"role": "user", "content": "Write hello world"}], + tokens={"input": 50, "output": 100}, + hyperparameters={"temperature": 0.3}, + ), + _make_record( + agent_id="coder", + timestamp=3000.0, + messages=[ + {"role": "user", "content": "Write hello world"}, + {"role": "assistant", "content": "print('hello world')"}, + {"role": "user", "content": "Add error handling"}, + ], + tokens={"input": 150, "output": 200, "output_reasoning": 30}, + tools=[{"type": "function", "function": {"name": "write_file"}}], + ), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 1 + convo = convos[0] + assert convo.session_id == "conflux_coder" + assert len(convo.turns) == 2 + + turn0 = convo.turns[0] + assert turn0.raw_messages == [{"role": "user", "content": "Write hello world"}] + assert turn0.max_tokens == 100 + assert turn0.input_tokens == 50 + assert turn0.extra_params == {"temperature": 0.3} + assert 
turn0.raw_tools is None + + turn1 = convo.turns[1] + assert len(turn1.raw_messages) == 3 + assert turn1.max_tokens == 230 + assert turn1.input_tokens == 150 + assert turn1.raw_tools is not None + assert len(turn1.raw_tools) == 1 + + def test_full_pipeline_multi_agent_with_utility(self, tmp_path): + records = [ + _make_record(agent_id="planner", timestamp=1000.0), + _make_record(agent_id=None, timestamp=2000.0), + _make_record(agent_id="executor", timestamp=3000.0), + _make_record(agent_id="planner", timestamp=4000.0), + ] + path = tmp_path / "session.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=True) + ) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 3 + session_ids = {c.session_id for c in convos} + assert "conflux_planner" in session_ids + assert "conflux_executor" in session_ids + + planner = next(c for c in convos if c.session_id == "conflux_planner") + assert len(planner.turns) == 2 + + def test_full_pipeline_directory_with_mixed_agents(self, tmp_path): + _write_json( + tmp_path / "session1.json", + [ + _make_record(agent_id="A", timestamp=1000.0), + _make_record(agent_id="B", timestamp=2000.0), + ], + ) + _write_json( + tmp_path / "session2.json", + [ + _make_record(agent_id="A", timestamp=100000.0), + ], + ) + + loader = ConfluxLoader(filename=str(tmp_path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 3 + + +# --------------------------------------------------------------------------- +# Boundary / pathological tests +# --------------------------------------------------------------------------- + + +class TestConfluxBoundaryConditions: + """Edge cases, boundary conditions, and pathological inputs.""" + + def test_single_record_file(self, tmp_path): + records = [_make_record(agent_id="solo")] + path = tmp_path / 
"single.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 1 + assert len(convos[0].turns) == 1 + + def test_very_large_messages_array(self, tmp_path): + big_messages = [{"role": "user", "content": f"msg-{i}"} for i in range(500)] + records = [ + _make_record(agent_id="A", timestamp=1000.0, messages=big_messages), + ] + path = tmp_path / "big.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos[0].turns[0].raw_messages) == 500 + + def test_many_agents_in_one_file(self, tmp_path): + records = [ + _make_record( + agent_id=f"agent-{i}", + timestamp=float(1000 + i * 1000), + ) + for i in range(50) + ] + path = tmp_path / "many.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 50 + + def test_identical_timestamps_stable_within_group(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + messages=[{"role": "user", "content": f"msg-{i}"}], + ) + for i in range(5) + ] + path = tmp_path / "dupes.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 1 + assert len(convos[0].turns) == 5 + + def test_timestamps_with_microsecond_precision(self, tmp_path): + records = [ + _make_record(agent_id="A", timestamp=1000.000001), + _make_record(agent_id="A", timestamp=1000.000002), + ] + path = tmp_path / "micro.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), 
user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + t0 = convos[0].turns[0].timestamp + t1 = convos[0].turns[1].timestamp + assert t1 >= t0 + + def test_record_with_all_optional_fields_none(self, tmp_path): + records = [ + { + "session_id": "s1", + "timestamp": 1000.0, + } + ] + path = tmp_path / "minimal.json" + _write_json(path, records) + + loader = ConfluxLoader( + filename=str(path), user_config=_make_user_config(include_utility=True) + ) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert len(convos) == 1 + turn = convos[0].turns[0] + assert turn.raw_messages == [] + assert turn.raw_tools is None + assert turn.max_tokens is None + assert turn.input_tokens is None + assert turn.extra_params is None + + def test_hyperparameters_only_skip_fields_returns_no_extra_params(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + hyperparameters={"max_tokens": 1024, "max_output_tokens": 512}, + ), + ] + path = tmp_path / "skip.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert convos[0].turns[0].extra_params is None + + def test_tokens_with_only_reasoning_output(self, tmp_path): + records = [ + _make_record( + agent_id="A", + timestamp=1000.0, + tokens={"input": 100, "output": 0, "output_reasoning": 500}, + ), + ] + path = tmp_path / "reasoning.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + convos = loader.convert_to_conversations(data) + + assert convos[0].turns[0].max_tokens == 500 + + def test_duration_ms_zero(self, tmp_path): + records = [_make_record(agent_id="A", duration_ms=0)] + path = tmp_path / "zero_dur.json" + _write_json(path, records) + + loader = 
ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + + assert data["A"][0].duration_ms == 0 + + def test_extra_json_fields_silently_dropped(self, tmp_path): + records = [ + { + "session_id": "s1", + "agent_id": "A", + "timestamp": 1000.0, + "provider": "anthropic", + "model_name": "claude-3.5-sonnet", + "response_text": "Hello!", + "metadata": {"version": 2}, + } + ] + path = tmp_path / "extra.json" + _write_json(path, records) + + loader = ConfluxLoader(filename=str(path), user_config=_make_user_config()) + data = loader.load_dataset() + + assert len(data["A"]) == 1 + record = data["A"][0] + assert not hasattr(record, "provider") + assert not hasattr(record, "model_name") + + +# --------------------------------------------------------------------------- +# Turn.copy_with_stripped_media with new fields +# --------------------------------------------------------------------------- + + +class TestTurnCopyWithStrippedMediaNewFields: + """Verify Turn.copy_with_stripped_media preserves new fields.""" + + def test_copy_with_stripped_media_preserves_input_tokens(self): + from aiperf.common.models import Turn + + turn = Turn( + texts=[], + input_tokens=42, + ) + copy = turn.copy_with_stripped_media() + assert copy.input_tokens == 42 + + def test_copy_with_stripped_media_preserves_extra_params(self): + from aiperf.common.models import Turn + + turn = Turn( + texts=[], + extra_params={"temperature": 0.7, "top_p": 0.9}, + ) + copy = turn.copy_with_stripped_media() + assert copy.extra_params == {"temperature": 0.7, "top_p": 0.9} + + def test_copy_with_stripped_media_extra_params_is_independent(self): + from aiperf.common.models import Turn + + original_params = {"temperature": 0.7} + turn = Turn(texts=[], extra_params=original_params) + copy = turn.copy_with_stripped_media() + + copy.extra_params["temperature"] = 999 + assert turn.extra_params["temperature"] == 0.7 + + def test_copy_with_stripped_media_none_fields(self): + from 
aiperf.common.models import Turn + + turn = Turn(texts=[], input_tokens=None, extra_params=None) + copy = turn.copy_with_stripped_media() + assert copy.input_tokens is None + assert copy.extra_params is None diff --git a/tests/unit/endpoints/test_openai_chat_completions.py b/tests/unit/endpoints/test_openai_chat_completions.py index e5a75b321..f960c832d 100644 --- a/tests/unit/endpoints/test_openai_chat_completions.py +++ b/tests/unit/endpoints/test_openai_chat_completions.py @@ -374,3 +374,153 @@ def test_name_field_excluded_from_multimodal_payload(self, model_endpoint): message = payload["messages"][0] assert "name" not in message assert isinstance(message["content"], list) + + # ----------------------------------------------------------------------- + # extra_params integration tests + # ----------------------------------------------------------------------- + + def test_extra_params_merged_into_payload(self, model_endpoint): + """Turn.extra_params are included in the final payload.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + extra_params={"temperature": 0.3, "top_p": 0.95}, + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert payload["temperature"] == 0.3 + assert payload["top_p"] == 0.95 + + def test_extra_params_none_adds_nothing(self, model_endpoint): + """No extra keys when extra_params is None.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn(texts=[Text(contents=["Hello"])], extra_params=None) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert "temperature" not in payload + assert set(payload.keys()) == {"messages", "model", "stream"} + + def test_extra_params_empty_dict_adds_nothing(self, model_endpoint): + """Empty dict is falsy, so no extra keys.""" + endpoint = ChatEndpoint(model_endpoint) + turn = 
Turn(texts=[Text(contents=["Hello"])], extra_params={}) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert set(payload.keys()) == {"messages", "model", "stream"} + + def test_endpoint_extra_overrides_extra_params(self, model_endpoint): + """Endpoint-level extra takes precedence over per-turn extra_params.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + extra_params={"temperature": 0.3}, + ) + model_endpoint.endpoint.extra = {"temperature": 0.9} + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert payload["temperature"] == 0.9 + + def test_max_tokens_not_overridden_by_extra_params(self, model_endpoint): + """max_completion_tokens set from Turn.max_tokens wins over extra_params.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + max_tokens=100, + extra_params={"max_completion_tokens": 9999}, + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert payload["max_completion_tokens"] == 100 + + def test_model_not_overridden_by_extra_params(self, model_endpoint): + """The model field is always set last and cannot be overridden.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + extra_params={"model": "evil-model"}, + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert payload["model"] == "test-model" + + def test_stream_not_overridden_by_extra_params(self, model_endpoint): + """The stream field is always set last and cannot be overridden.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + extra_params={"stream": True}, + ) + 
model_endpoint.endpoint.streaming = False + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + assert payload["stream"] is False + + def test_extra_params_with_tools_and_stream_options(self, model_endpoint): + """extra_params, tools, and stream_options coexist correctly.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + raw_tools=[{"type": "function", "function": {"name": "search"}}], + extra_params={"temperature": 0.5}, + ) + model_endpoint.endpoint.streaming = True + model_endpoint.endpoint.use_server_token_count = True + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + + assert payload["temperature"] == 0.5 + assert payload["tools"] == [ + {"type": "function", "function": {"name": "search"}} + ] + assert payload["stream"] is True + assert payload["stream_options"] == {"include_usage": True} + + def test_extra_params_stream_options_not_clobbered(self, model_endpoint): + """stream_options from extra_params gets include_usage merged in.""" + endpoint = ChatEndpoint(model_endpoint) + turn = Turn( + texts=[Text(contents=["Hello"])], + extra_params={"stream_options": {"continuous_updates": True}}, + ) + model_endpoint.endpoint.streaming = True + model_endpoint.endpoint.use_server_token_count = True + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + + assert payload["stream_options"] == { + "continuous_updates": True, + "include_usage": True, + } + + def test_extra_params_stream_options_not_mutated(self, model_endpoint): + """format_payload must not mutate the original extra_params dict.""" + endpoint = ChatEndpoint(model_endpoint) + original_stream_opts = {"continuous_updates": True} + extra = {"stream_options": original_stream_opts} + turn = Turn( + 
texts=[Text(contents=["Hello"])], + extra_params=extra, + ) + model_endpoint.endpoint.streaming = True + model_endpoint.endpoint.use_server_token_count = True + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + endpoint.format_payload(request_info) + + assert original_stream_opts == {"continuous_updates": True} + assert "include_usage" not in original_stream_opts + + def test_extra_params_with_raw_messages(self, model_endpoint): + """extra_params work with raw_messages path (Conflux replay).""" + endpoint = ChatEndpoint(model_endpoint) + raw_msgs = [{"role": "user", "content": "Hello from replay"}] + turn = Turn( + texts=[], + raw_messages=raw_msgs, + extra_params={"temperature": 0.2}, + ) + request_info = create_request_info(model_endpoint=model_endpoint, turns=[turn]) + payload = endpoint.format_payload(request_info) + + assert payload["messages"] == raw_msgs + assert payload["temperature"] == 0.2 + assert payload["model"] == "test-model"