From 69b753280eb180278bec04a13be4b98b8901e58d Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Thu, 26 Feb 2026 13:17:58 -0800 Subject: [PATCH 1/7] feat: Add Bailian trace loader and extract BaseTraceDatasetLoader Add support for Alibaba Bailian trace format (bailian_trace) with multi-turn conversation linking via chat_id/parent_chat_id and 16-token SipHash blocks. Extract common trace loading logic into BaseTraceDatasetLoader to share infrastructure between Mooncake and Bailian loaders. Make block_size configurable per-loader via plugin metadata default_block_size, and generalize mooncake-specific validation to work with any trace dataset type. Signed-off-by: Anthony Casagrande --- docs/cli_options.md | 5 +- src/aiperf/common/config/input_config.py | 11 +- src/aiperf/common/config/prompt_config.py | 11 +- src/aiperf/common/config/user_config.py | 18 +- src/aiperf/dataset/composer/custom.py | 20 +- src/aiperf/dataset/generator/prompt.py | 6 +- src/aiperf/dataset/loader/__init__.py | 6 + src/aiperf/dataset/loader/bailian_trace.py | 155 ++++ .../dataset/loader/base_trace_loader.py | 349 +++++++++ src/aiperf/dataset/loader/models.py | 53 +- src/aiperf/dataset/loader/mooncake_trace.py | 256 +----- src/aiperf/plugin/categories.yaml | 1 + src/aiperf/plugin/enums.py | 2 +- src/aiperf/plugin/plugins.py | 28 + src/aiperf/plugin/plugins.yaml | 16 + src/aiperf/plugin/schema/plugins.schema.json | 32 +- src/aiperf/plugin/schema/schemas.py | 31 + tests/unit/common/config/test_input_config.py | 39 +- .../unit/common/config/test_prompt_config.py | 2 +- .../config/test_user_config_mooncake_trace.py | 30 +- .../dataset/composer/test_custom_composer.py | 29 +- .../unit/dataset/loader/test_bailian_trace.py | 739 ++++++++++++++++++ tests/unit/dataset/loader/test_trace.py | 6 +- 23 files changed, 1524 insertions(+), 321 deletions(-) create mode 100644 src/aiperf/dataset/loader/bailian_trace.py create mode 100644 src/aiperf/dataset/loader/base_trace_loader.py create mode 100644 tests/unit/dataset/loader/test_bailian_trace.py diff --git a/docs/cli_options.md b/docs/cli_options.md index 254a5f78c..1c7c547ef 100644 --- a/docs/cli_options.md +++ b/docs/cli_options.md @@ -293,7 +293,7 @@ Pre-configured public dataset to download and use for benchmarking (e.g., `share #### `--custom-dataset-type` `` Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace` (timestamped trace files), `random_pool` (directory of reusable prompts). Requires `--input-file`. Mutually exclusive with `--public-dataset`. -
_Choices: [`mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ +
_Choices: [`bailian_trace`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ #### `--dataset-sampling-strategy` `` @@ -512,8 +512,7 @@ Standard deviation for synthetic input prompt token lengths. Creates variability #### `--prompt-input-tokens-block-size`, `--synthetic-input-tokens-block-size`, `--isl-block-size` `` -Token block size for hash-based prompt caching in `mooncake_trace` datasets. When `hash_ids` are provided in trace entries, prompts are divided into blocks of this size. Each `hash_id` maps to a cached block of `block_size` tokens, enabling simulation of KV-cache sharing patterns from production workloads. The total prompt length equals `(num_hash_ids - 1) * block_size + final_block_size`. -
_Default: `512`_ +Token block size for hash-based prompt caching in trace datasets (`mooncake_trace`, `bailian_trace`). When `hash_ids` are provided in trace entries, prompts are divided into blocks of this size. Each `hash_id` maps to a cached block of `block_size` tokens, enabling simulation of KV-cache sharing patterns from production workloads. The total prompt length equals `(num_hash_ids - 1) * block_size + final_block_size`. When not set, the trace loader's `default_block_size` from plugin metadata is used (e.g. 16 for `bailian_trace`, 512 for `mooncake_trace`). #### `--seq-dist`, `--sequence-distribution` `` diff --git a/src/aiperf/common/config/input_config.py b/src/aiperf/common/config/input_config.py index 586f8ed74..162ecf846 100644 --- a/src/aiperf/common/config/input_config.py +++ b/src/aiperf/common/config/input_config.py @@ -26,6 +26,7 @@ from aiperf.common.config.video_config import VideoConfig from aiperf.common.enums import PublicDatasetType from aiperf.common.exceptions import InvalidStateError, MetricTypeError +from aiperf.plugin import plugins from aiperf.plugin.enums import ( CustomDatasetType, DatasetSamplingStrategy, @@ -104,10 +105,10 @@ def validate_custom_dataset_file(self) -> Self: return self @model_validator(mode="after") - def validate_synthesis_requires_mooncake_trace(self) -> Self: - """Validate that synthesis options require mooncake_trace dataset type. + def validate_synthesis_requires_trace_dataset(self) -> Self: + """Validate that synthesis options require a trace dataset type. - Only validates when custom_dataset_type is explicitly set to a non-mooncake + Only validates when custom_dataset_type is explicitly set to a non-trace type. If custom_dataset_type is None (auto-detect), we allow synthesis options and defer validation to runtime when the actual type is determined. """ @@ -118,13 +119,13 @@ def validate_synthesis_requires_mooncake_trace(self) -> Self: or self.synthesis.max_osl is not None ) and self.custom_dataset_type is not None - and self.custom_dataset_type != CustomDatasetType.MOONCAKE_TRACE + and not plugins.is_trace_dataset(self.custom_dataset_type) ): raise ValueError( "Synthesis options (--synthesis-speedup-ratio, --synthesis-prefix-len-multiplier, " "--synthesis-prefix-root-multiplier, --synthesis-prompt-len-multiplier, " "--synthesis-max-isl, --synthesis-max-osl) " - "require --custom-dataset-type mooncake_trace" + "require a trace dataset type (e.g., mooncake_trace, bailian_trace)" ) return self diff --git a/src/aiperf/common/config/prompt_config.py b/src/aiperf/common/config/prompt_config.py index f78e731fb..88f5a2efe 100644 --- a/src/aiperf/common/config/prompt_config.py +++ b/src/aiperf/common/config/prompt_config.py @@ -61,12 +61,13 @@ class InputTokensConfig(BaseConfig): ] = InputTokensDefaults.STDDEV block_size: Annotated[ - int, + int | None, Field( - default=512, - description="Token block size for hash-based prompt caching in `mooncake_trace` datasets. When `hash_ids` are provided in trace entries, " + default=None, + description="Token block size for hash-based prompt caching in trace datasets (`mooncake_trace`, `bailian_trace`). When `hash_ids` are provided in trace entries, " "prompts are divided into blocks of this size. Each `hash_id` maps to a cached block of `block_size` tokens, enabling simulation " - "of KV-cache sharing patterns from production workloads. The total prompt length equals `(num_hash_ids - 1) * block_size + final_block_size`.", + "of KV-cache sharing patterns from production workloads. 
The total prompt length equals `(num_hash_ids - 1) * block_size + final_block_size`. " + "When not set, the trace loader's `default_block_size` from plugin metadata is used (e.g. 16 for `bailian_trace`, 512 for `mooncake_trace`).", ), CLIParameter( name=( @@ -76,7 +77,7 @@ class InputTokensConfig(BaseConfig): ), group=_CLI_GROUP, ), - ] = InputTokensDefaults.BLOCK_SIZE + ] = None class OutputTokensConfig(BaseConfig): diff --git a/src/aiperf/common/config/user_config.py b/src/aiperf/common/config/user_config.py index 751c6d829..ea3b11d35 100644 --- a/src/aiperf/common/config/user_config.py +++ b/src/aiperf/common/config/user_config.py @@ -26,9 +26,9 @@ from aiperf.common.config.tokenizer_config import TokenizerConfig from aiperf.common.enums import GPUTelemetryMode, ServerMetricsFormat from aiperf.common.utils import load_json_str +from aiperf.plugin import plugins from aiperf.plugin.enums import ( ArrivalPattern, - CustomDatasetType, EndpointType, GPUTelemetryCollectorType, TimingMode, @@ -104,10 +104,10 @@ def validate_timing_mode(self) -> Self: _logger.info( f"No request count value provided for fixed schedule mode, setting to dataset entry count: {self.loadgen.request_count}" ) - elif self._should_use_fixed_schedule_for_mooncake_trace(): + elif self._should_use_fixed_schedule_for_trace_dataset(): self._timing_mode = TimingMode.FIXED_SCHEDULE _logger.info( - "Automatically enabling fixed schedule mode for mooncake_trace dataset with timestamps" + f"Automatically enabling fixed schedule mode for {self.input.custom_dataset_type} dataset with timestamps" ) if ( self.loadgen.request_count is None @@ -115,7 +115,7 @@ def validate_timing_mode(self) -> Self: ): self.loadgen.request_count = self._count_dataset_entries() _logger.info( - f"No request count value provided for mooncake trace dataset, setting to dataset entry count: {self.loadgen.request_count}" + f"No request count value provided for trace dataset, setting to dataset entry count: {self.loadgen.request_count}" ) elif self.loadgen.user_centric_rate is not None: # User-centric rate mode: per-user rate limiting (LMBenchmark parity) @@ -334,13 +334,15 @@ def validate_unused_options(self) -> Self: return self - def _should_use_fixed_schedule_for_mooncake_trace(self) -> bool: - """Check if mooncake_trace dataset has timestamps and should use fixed schedule. + def _should_use_fixed_schedule_for_trace_dataset(self) -> bool: + """Check if a trace dataset has timestamps and should use fixed schedule. Returns: - bool: True if fixed schedule should be enabled for this mooncake trace + True if fixed schedule should be enabled for this trace dataset. """ - if self.input.custom_dataset_type != CustomDatasetType.MOONCAKE_TRACE: + if self.input.custom_dataset_type is None or not plugins.is_trace_dataset( + self.input.custom_dataset_type + ): return False if not self.input.file: diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 422f13e84..33fb2657e 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -167,23 +167,24 @@ def _set_sampling_strategy(self, dataset_type: CustomDatasetType) -> None: ) def _validate_synthesis_config(self, dataset_type: CustomDatasetType) -> None: - """Validate that synthesis options are only used with mooncake_trace. + """Validate that synthesis options are only used with trace datasets. Args: dataset_type: The determined dataset type. Raises: - ValueError: If synthesis options are set but dataset type is not mooncake_trace. 
+ ValueError: If synthesis options are set but dataset type is not a trace format. """ if ( self.config.input.synthesis.should_synthesize() - and dataset_type != CustomDatasetType.MOONCAKE_TRACE + and not plugins.is_trace_dataset(dataset_type) ): raise ValueError( f"Synthesis options (--synthesis-speedup-ratio, --synthesis-prefix-len-multiplier, " f"--synthesis-prefix-root-multiplier, --synthesis-prompt-len-multiplier) " - f"are only supported with mooncake_trace datasets, but got {dataset_type.value}. " - f"Either remove synthesis options or use --custom-dataset-type mooncake_trace." + f"are only supported with trace datasets, " + f"but got {dataset_type.value}. " + f"Either remove synthesis options or use a trace dataset type." ) def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None: @@ -192,9 +193,14 @@ def _create_loader_instance(self, dataset_type: CustomDatasetType) -> None: Args: dataset_type: The type of custom dataset to create. """ - kwargs = {} - if dataset_type == CustomDatasetType.MOONCAKE_TRACE: + kwargs: dict[str, Any] = {} + loader_metadata = plugins.get_dataset_loader_metadata(dataset_type) + if loader_metadata.is_trace: kwargs["prompt_generator"] = self.prompt_generator + + if loader_metadata.default_block_size is not None: + kwargs["default_block_size"] = loader_metadata.default_block_size + elif dataset_type == CustomDatasetType.RANDOM_POOL: kwargs["num_conversations"] = self.config.input.conversation.num diff --git a/src/aiperf/dataset/generator/prompt.py b/src/aiperf/dataset/generator/prompt.py index 5f52d3a8e..78da6bf70 100644 --- a/src/aiperf/dataset/generator/prompt.py +++ b/src/aiperf/dataset/generator/prompt.py @@ -7,6 +7,7 @@ from aiperf.common import random_generator as rng from aiperf.common.config import PromptConfig +from aiperf.common.config.config_defaults import InputTokensDefaults from aiperf.common.exceptions import ( ConfigurationError, InvalidStateError, @@ -166,9 +167,10 @@ def generate( A synthetic prompt as a string. 
""" if hash_ids: - return self._generate_cached_prompt( - mean, hash_ids, self.config.input_tokens.block_size + block_size = ( + self.config.input_tokens.block_size or InputTokensDefaults.BLOCK_SIZE ) + return self._generate_cached_prompt(mean, hash_ids, block_size) num_tokens = self.calculate_num_tokens(mean, stddev) return self.generate_prompt(num_tokens) diff --git a/src/aiperf/dataset/loader/__init__.py b/src/aiperf/dataset/loader/__init__.py index e586d0af4..b637b4c69 100644 --- a/src/aiperf/dataset/loader/__init__.py +++ b/src/aiperf/dataset/loader/__init__.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 """Dataset loader package for AIPerf.""" +from aiperf.dataset.loader.bailian_trace import BailianTraceDatasetLoader from aiperf.dataset.loader.base_loader import BaseFileLoader, BaseLoader from aiperf.dataset.loader.base_public_dataset import BasePublicDatasetLoader +from aiperf.dataset.loader.base_trace_loader import BaseTraceDatasetLoader from aiperf.dataset.loader.mixins import MediaConversionMixin from aiperf.dataset.loader.models import ( + BailianTrace, MooncakeTrace, MultiTurn, RandomPool, @@ -18,9 +21,12 @@ from aiperf.dataset.loader.single_turn import SingleTurnDatasetLoader __all__ = [ + "BailianTrace", + "BailianTraceDatasetLoader", "BaseFileLoader", "BaseLoader", "BasePublicDatasetLoader", + "BaseTraceDatasetLoader", "MediaConversionMixin", "MooncakeTrace", "MooncakeTraceDatasetLoader", diff --git a/src/aiperf/dataset/loader/bailian_trace.py b/src/aiperf/dataset/loader/bailian_trace.py new file mode 100644 index 000000000..b18f57d24 --- /dev/null +++ b/src/aiperf/dataset/loader/bailian_trace.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from collections import defaultdict +from pathlib import Path +from typing import Any + +from pydantic import ValidationError + +from aiperf.dataset.loader.base_trace_loader import BaseTraceDatasetLoader +from aiperf.dataset.loader.models import BailianTrace + + +class BailianTraceDatasetLoader(BaseTraceDatasetLoader[BailianTrace]): + """A dataset loader for Alibaba Bailian trace data. + + See https://github.com/alibaba-edu/qwen-bailian-usagetraces-anon + + Loads Bailian trace data from a JSONL file and converts it into + conversations for the dataset manager. Multi-turn conversations are + linked via `chat_id` / `parent_chat_id` (`-1` = root) and + ordered by `turn`. + + Timestamps are **seconds since request arrival** and are converted to + **milliseconds** internally. + + The 16-token SipHash block size is declared in `plugins.yaml` metadata + and applied automatically—no need to pass `--isl-block-size`. + + Example JSONL entry:: + + {"chat_id": 159, "parent_chat_id": -1, "timestamp": 61.114, + "input_length": 521, "output_length": 132, "type": "text", + "turn": 1, "hash_ids": [1089, 1090, 1091, 6326, 13148]} + """ + + @classmethod + def can_load( + cls, data: dict[str, Any] | None = None, filename: str | Path | None = None + ) -> bool: + """Check if this loader can handle the given data format. + + Detects Bailian traces by the presence of `chat_id`, + `parent_chat_id`, and `turn` fields. 
+ """ + if data is None: + return False + + try: + BailianTrace.model_validate(data) + return True + except ValidationError: + return False + + # ------------------------------------------------------------------ + # Template-method hooks (see BaseTraceDatasetLoader.load_dataset) + # ------------------------------------------------------------------ + + def _parse_trace(self, line: str) -> BailianTrace: + return BailianTrace.model_validate_json(line) + + def _preprocess_trace(self, trace: BailianTrace) -> None: + """Convert timestamp from seconds to milliseconds.""" + trace.timestamp = trace.timestamp * 1000.0 + + def _group_traces(self, items: list[BailianTrace]) -> dict[str, list[BailianTrace]]: + return self._group_into_sessions(items) + + def _group_into_sessions( + self, items: list[BailianTrace] + ) -> dict[str, list[BailianTrace]]: + """Group flat trace entries into sessions using parent-child links. + + Builds a union-find over `chat_id` → `parent_chat_id` to identify + session roots, then groups entries by root and sorts each session by + `turn`. Root requests have `parent_chat_id == -1`. + """ + if not items: + return {} + + # Build lookup: chat_id → trace + by_chat_id: dict[int, BailianTrace] = {t.chat_id: t for t in items} + + # Find root chat_id for each entry by walking parent links + root_cache: dict[int, int] = {} + + def find_root(chat_id: int) -> int: + if chat_id in root_cache: + return root_cache[chat_id] + + path: list[int] = [] + current = chat_id + while current in by_chat_id and by_chat_id[current].parent_chat_id != -1: + parent = by_chat_id[current].parent_chat_id + if parent == current or parent not in by_chat_id: + break + path.append(current) + current = parent + + # Path compression + for node in path: + root_cache[node] = current + root_cache[chat_id] = current + return current + + groups: dict[str, list[BailianTrace]] = defaultdict(list) + for trace in items: + root = find_root(trace.chat_id) + session_id = str(root) + groups[session_id].append(trace) + + # Sort each session by turn number + for traces in groups.values(): + traces.sort(key=lambda t: t.turn) + + return dict(groups) + + # ------------------------------------------------------------------ + # Synthesis hooks + # ------------------------------------------------------------------ + + _BAILIAN_ONLY_FIELDS = frozenset( + { + "chat_id", + "parent_chat_id", + "request_type", + "turn", + } + ) + + def _synthesis_exclude_fields(self) -> frozenset[str]: + return self._BAILIAN_ONLY_FIELDS + + def _synthesis_dump_kwargs(self) -> dict[str, Any]: + return {"by_alias": True} + + def _reconstruct_traces( + self, originals: list[BailianTrace], synth_dicts: list[dict[str, Any]] + ) -> list[BailianTrace]: + result: list[BailianTrace] = [] + for i, synth_dict in enumerate(synth_dicts): + original = originals[i] if i < len(originals) else originals[-1] + result.append( + BailianTrace( + chat_id=original.chat_id, + parent_chat_id=original.parent_chat_id, + timestamp=synth_dict.get("timestamp", original.timestamp), + input_length=synth_dict["input_length"], + output_length=synth_dict["output_length"], + request_type=original.request_type, + turn=original.turn, + hash_ids=synth_dict.get("hash_ids", original.hash_ids), + ) + ) + return result diff --git a/src/aiperf/dataset/loader/base_trace_loader.py b/src/aiperf/dataset/loader/base_trace_loader.py new file mode 100644 index 000000000..58efb03a7 --- /dev/null +++ b/src/aiperf/dataset/loader/base_trace_loader.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright 
(c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from typing import Any, Generic, TypeVar + +from aiperf.common.config.config_defaults import InputTokensDefaults +from aiperf.common.config.user_config import UserConfig +from aiperf.common.models import Conversation, Text, Turn +from aiperf.dataset.generator.parallel_decode import parallel_decode +from aiperf.dataset.generator.prompt import PromptGenerator +from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.synthesis.models import SynthesisParams +from aiperf.dataset.synthesis.synthesizer import Synthesizer +from aiperf.plugin.enums import DatasetSamplingStrategy + +TraceT = TypeVar("TraceT") + + +class BaseTraceDatasetLoader(BaseFileLoader, Generic[TraceT]): + """Base class for trace dataset loaders with hash_ids-based prompt generation. + + Provides common infrastructure for loading trace-format datasets + (Mooncake, Bailian, etc.) including shared initialization, timestamp + filtering, 3-phase prompt generation with parallel decode, and + synthesis integration. + + Subclasses must implement: + - `can_load`: data format detection + - `load_dataset`: JSONL parsing and session grouping + - `_synthesis_exclude_fields`: fields to strip before synthesis + - `_reconstruct_traces`: rebuild typed traces from synthesized dicts + """ + + def __init__( + self, + *, + filename: str, + prompt_generator: PromptGenerator, + user_config: UserConfig, + default_block_size: int | None = None, + **kwargs: Any, + ) -> None: + super().__init__(filename=filename, user_config=user_config, **kwargs) + self.prompt_generator = prompt_generator + self._skipped_traces = 0 + self._skipped_max_isl = 0 + self._capped_max_osl = 0 + self._start_offset = user_config.input.fixed_schedule_start_offset + self._end_offset = user_config.input.fixed_schedule_end_offset + self._max_isl = user_config.input.synthesis.max_isl + self._max_osl = user_config.input.synthesis.max_osl + + # Use the resolved tokenizer name so worker processes can load from cache + # without needing alias resolution or network access. + self._tokenizer_name = ( + prompt_generator.tokenizer.resolved_name + or user_config.tokenizer.name + or user_config.endpoint.model_names[0] + ) + + # Precedence: user CLI --isl-block-size > plugin metadata default > hardcoded fallback + user_block_size = user_config.input.prompt.input_tokens.block_size + if user_block_size is not None: + self._block_size = user_block_size + elif default_block_size is not None: + self._block_size = default_block_size + else: + self._block_size = InputTokensDefaults.BLOCK_SIZE + + # ------------------------------------------------------------------ + # Shared class methods + # ------------------------------------------------------------------ + + @classmethod + def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: + """Trace datasets use sequential sampling to preserve timestamp order.""" + return DatasetSamplingStrategy.SEQUENTIAL + + # ------------------------------------------------------------------ + # Subclass hooks + # ------------------------------------------------------------------ + + @abstractmethod + def _parse_trace(self, line: str) -> TraceT: + """Parse a single JSONL line into a typed trace object.""" + ... + + def _preprocess_trace(self, trace: TraceT) -> None: + """Optional hook for per-trace pre-processing (e.g. unit conversion). + + Called after parsing but before filtering. 
Default is a no-op. + """ + + @abstractmethod + def _group_traces(self, items: list[TraceT]) -> dict[str, list[TraceT]]: + """Group flat trace entries into sessions keyed by session ID.""" + ... + + # ------------------------------------------------------------------ + # Timestamp / filtering helpers + # ------------------------------------------------------------------ + + def _timestamp_within_offsets(self, timestamp: int | float) -> bool: + """Check if a timestamp falls within configured offsets.""" + return (self._start_offset is None or timestamp >= self._start_offset) and ( + self._end_offset is None or timestamp <= self._end_offset + ) + + def _filter_and_cap_trace(self, trace: TraceT) -> bool: + """Apply timestamp-window, max_isl, and max_osl filters. + + Returns `True` if the trace should be kept, `False` to skip. + """ + timestamp = getattr(trace, "timestamp", None) + if timestamp is not None and not self._timestamp_within_offsets(timestamp): + self._skipped_traces += 1 + return False + + input_length = getattr(trace, "input_length", None) + if ( + self._max_isl is not None + and input_length is not None + and input_length > self._max_isl + ): + self._skipped_max_isl += 1 + return False + + output_length = getattr(trace, "output_length", None) + if ( + self._max_osl is not None + and output_length is not None + and output_length > self._max_osl + ): + self._capped_max_osl += 1 + trace.output_length = self._max_osl # type: ignore[attr-defined] + + return True + + def _log_filtering_summary(self) -> None: + """Emit info-level messages for any skipped or capped traces.""" + if self._skipped_traces > 0: + self.info( + f"Skipped {self._skipped_traces:,} traces because they were " + f"before the start offset of {self._start_offset} or " + f"after the end offset of {self._end_offset}" + ) + if self._skipped_max_isl > 0: + self.info( + f"Skipped {self._skipped_max_isl:,} traces because input_length " + f"exceeded max_isl of {self._max_isl}" + ) + if self._capped_max_osl > 0: + self.info( + f"{self._capped_max_osl:,} traces exceeded max_osl of " + f"{self._max_osl} and were capped to {self._max_osl}" + ) + + # ------------------------------------------------------------------ + # load_dataset — template method + # ------------------------------------------------------------------ + + def load_dataset(self) -> dict[str, list[TraceT]]: + """Load, filter, group, and optionally synthesize trace data. + + Template method that delegates format-specific work to subclass hooks: + :meth:`_parse_trace`, :meth:`_preprocess_trace`, and + :meth:`_group_traces`. + """ + items: list[TraceT] = [] + + with open(self.filename) as f: + for line in f: + if (line := line.strip()) == "": + continue + + trace = self._parse_trace(line) + self._preprocess_trace(trace) + + if not self._filter_and_cap_trace(trace): + continue + + items.append(trace) + + self._log_filtering_summary() + + data = self._group_traces(items) + self.debug( + lambda: ( + f"Loaded {sum(len(v) for v in data.values()):,} traces " + f"across {len(data):,} sessions from {self.filename}" + ) + ) + + if self.user_config.input.synthesis.should_synthesize(): + data = self._apply_synthesis(data) + + return data + + # ------------------------------------------------------------------ + # convert_to_conversations — 3-phase prompt generation + # ------------------------------------------------------------------ + + def _get_text_input(self, trace: TraceT) -> str | None: + """Return pre-existing text input, or `None` to use hash_ids generation. 
+ + Override for traces that carry literal prompts (e.g. `MooncakeTrace.text_input`). + Default: checks for a `text_input` attribute via getattr. + """ + return getattr(trace, "text_input", None) + + def _build_turn(self, trace: TraceT, prompt: str) -> Turn: + """Build a :class:`Turn` from trace data and a generated prompt. + + Default implementation extracts `timestamp`, `delay`, `output_length` + via getattr, which works for both Mooncake and Bailian traces. + """ + return Turn( + timestamp=getattr(trace, "timestamp", None), + delay=getattr(trace, "delay", None), + texts=[Text(name="text", contents=[prompt])], + max_tokens=getattr(trace, "output_length", None), + ) + + def convert_to_conversations( + self, data: dict[str, list[TraceT]] + ) -> list[Conversation]: + """Convert trace sessions to :class:`Conversation` objects. + + Uses a three-phase approach for optimal performance: + + 1. Build token sequences, checking the string cache first. + 2. Batch parallel decode for all cache misses. + 3. Assemble final :class:`Conversation` objects. + """ + # Phase 1: Build token sequences and identify cache misses + pending_decodes: list[tuple[str, int, list[int], tuple]] = [] + conversations_data: dict[str, list[tuple[TraceT, str | None]]] = {} + + for session_id, traces in data.items(): + conversations_data[session_id] = [] + for idx, trace in enumerate(traces): + text_input = self._get_text_input(trace) + if text_input is not None: + conversations_data[session_id].append((trace, text_input)) + continue + + hash_ids: list[int] = getattr(trace, "hash_ids", None) or [] + input_length: int = getattr(trace, "input_length", 0) + + if hash_ids: + cache_key = ( + tuple(hash_ids), + input_length, + self._block_size, + ) + if cache_key in self.prompt_generator._decoded_cache: + prompt = self.prompt_generator._decoded_cache[cache_key] + conversations_data[session_id].append((trace, prompt)) + else: + tokens = self.prompt_generator._build_token_sequence( + input_length, hash_ids, self._block_size + ) + pending_decodes.append((session_id, idx, tokens, cache_key)) + conversations_data[session_id].append((trace, None)) + else: + prompt = self.prompt_generator.generate( + mean=input_length, stddev=0, hash_ids=[] + ) + conversations_data[session_id].append((trace, prompt)) + + # Phase 2: Batch parallel decode for all cache misses + if pending_decodes: + self.debug( + lambda: f"Parallel decoding {len(pending_decodes)} prompts " + f"({len(data)} conversations)" + ) + token_sequences = [p[2] for p in pending_decodes] + decoded_prompts = parallel_decode(token_sequences, self._tokenizer_name) + + for (session_id, idx, _, cache_key), prompt in zip( + pending_decodes, decoded_prompts, strict=True + ): + self.prompt_generator._decoded_cache[cache_key] = prompt + trace, _ = conversations_data[session_id][idx] + conversations_data[session_id][idx] = (trace, prompt) + + # Phase 3: Build final conversation objects + conversations: list[Conversation] = [] + for session_id, trace_prompt_pairs in conversations_data.items(): + conversation = Conversation(session_id=session_id) + for trace, prompt in trace_prompt_pairs: + conversation.turns.append(self._build_turn(trace, prompt)) + conversations.append(conversation) + + return conversations + + # ------------------------------------------------------------------ + # Synthesis — shared orchestration with subclass hooks + # ------------------------------------------------------------------ + + @abstractmethod + def _synthesis_exclude_fields(self) -> frozenset[str]: + """Fields to 
exclude when serializing traces for the Synthesizer.""" + ... + + def _synthesis_dump_kwargs(self) -> dict[str, Any]: + """Extra kwargs for `model_dump` during synthesis serialization. + + Override to add e.g. `by_alias=True` for aliased fields. + """ + return {} + + @abstractmethod + def _reconstruct_traces( + self, originals: list[TraceT], synth_dicts: list[dict[str, Any]] + ) -> list[TraceT]: + """Rebuild typed trace objects from synthesized dicts. + + Args: + originals: The original traces for this session (for metadata recovery). + synth_dicts: The synthesized dicts from the Synthesizer. + """ + ... + + def _apply_synthesis( + self, data: dict[str, list[TraceT]] + ) -> dict[str, list[TraceT]]: + """Apply synthesis transformations to traces in-memory.""" + params = SynthesisParams.from_synthesis_config( + self.user_config.input.synthesis, block_size=self._block_size + ) + + exclude = self._synthesis_exclude_fields() + dump_kwargs = self._synthesis_dump_kwargs() + dict_data = { + sid: [ + t.model_dump(exclude=exclude, exclude_none=True, **dump_kwargs) # type: ignore[union-attr] + for t in traces + ] + for sid, traces in data.items() + } + + synthesized = Synthesizer(params=params).synthesize_grouped_traces(dict_data) + + return { + sid: self._reconstruct_traces(data.get(sid, []), synth_traces) + for sid, synth_traces in synthesized.items() + } diff --git a/src/aiperf/dataset/loader/models.py b/src/aiperf/dataset/loader/models.py index 405bf5461..e3e8d6b4f 100644 --- a/src/aiperf/dataset/loader/models.py +++ b/src/aiperf/dataset/loader/models.py @@ -3,7 +3,7 @@ from typing import Literal, TypeVar -from pydantic import Field, model_validator +from pydantic import ConfigDict, Field, model_validator from aiperf.common.models import AIPerfBaseModel, Audio, Image, Text, Video from aiperf.plugin.enums import CustomDatasetType @@ -251,7 +251,56 @@ def validate_input(self) -> "MooncakeTrace": return self +class BailianTrace(AIPerfBaseModel): + """Defines the schema for Alibaba Bailian trace data. + + See https://github.com/alibaba-edu/qwen-bailian-usagetraces-anon for the + upstream dataset and full documentation. + + Each entry represents a single request in a conversation chain. Multi-turn + conversations are linked via ``chat_id`` and ``parent_chat_id``: entries + sharing the same root ``chat_id`` (reachable through ``parent_chat_id``) + belong to the same session and are ordered by ``turn``. + + Important: Bailian traces use a block size of 16 tokens per salted SipHash + block. Set ``--isl-block-size 16`` when using this format. + + Examples: + - Root request: ``{"chat_id": 159, "parent_chat_id": -1, "timestamp": 61.114, "input_length": 521, "output_length": 132, "type": "text", "turn": 1, "hash_ids": [1089, 1090, 1091]}`` + - Follow-up: ``{"chat_id": 160, "parent_chat_id": 159, "timestamp": 62.5, "input_length": 400, "output_length": 80, "type": "text", "turn": 2, "hash_ids": [1089, 1090]}`` + + Note: + The ``type`` field in Bailian JSONL is the request type (text/search/image/file), + not the dataset type. Use ``--custom-dataset-type bailian_trace`` when loading + this format. + """ + + model_config = ConfigDict(populate_by_name=True) + + chat_id: int = Field(description="Randomized chat identifier") + parent_chat_id: int = Field( + default=-1, + description="Parent chat ID for multi-turn conversation chains. -1 indicates a root request.", + ) + timestamp: float = Field( + description="Seconds since request arrival. 
Converted to milliseconds internally.", + ) + input_length: int = Field(description="Input token count") + output_length: int = Field(description="Output token count") + request_type: str = Field( + default="", + alias="type", + description="Request type from the trace (text/search/image/file). Aliased from 'type' in JSONL.", + ) + turn: int = Field(default=1, description="Conversation turn number") + hash_ids: list[int] = Field( + default_factory=list, + description="Salted SipHash block IDs (16 tokens per block)", + ) + + CustomDatasetT = TypeVar( - "CustomDatasetT", bound=SingleTurn | MultiTurn | RandomPool | MooncakeTrace + "CustomDatasetT", + bound=SingleTurn | MultiTurn | RandomPool | MooncakeTrace | BailianTrace, ) """A union type of all custom data types.""" diff --git a/src/aiperf/dataset/loader/mooncake_trace.py b/src/aiperf/dataset/loader/mooncake_trace.py index f46703f4d..55280567e 100644 --- a/src/aiperf/dataset/loader/mooncake_trace.py +++ b/src/aiperf/dataset/loader/mooncake_trace.py @@ -7,18 +7,11 @@ from pydantic import ValidationError -from aiperf.common.config.user_config import UserConfig -from aiperf.common.models import Conversation, Text, Turn -from aiperf.dataset.generator.parallel_decode import parallel_decode -from aiperf.dataset.generator.prompt import PromptGenerator -from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.loader.base_trace_loader import BaseTraceDatasetLoader from aiperf.dataset.loader.models import MooncakeTrace -from aiperf.dataset.synthesis.models import SynthesisParams -from aiperf.dataset.synthesis.synthesizer import Synthesizer -from aiperf.plugin.enums import DatasetSamplingStrategy -class MooncakeTraceDatasetLoader(BaseFileLoader): +class MooncakeTraceDatasetLoader(BaseTraceDatasetLoader[MooncakeTrace]): """A dataset loader that loads Mooncake trace data from a file. Loads Mooncake trace data from a file and converts the data into @@ -40,33 +33,6 @@ class MooncakeTraceDatasetLoader(BaseFileLoader): ``` """ - def __init__( - self, - *, - filename: str, - prompt_generator: PromptGenerator, - user_config: UserConfig, - **kwargs, - ): - super().__init__(filename=filename, user_config=user_config, **kwargs) - self.prompt_generator = prompt_generator - self._skipped_traces = 0 - self._skipped_max_isl = 0 - self._capped_max_osl = 0 - self._start_offset = user_config.input.fixed_schedule_start_offset - self._end_offset = user_config.input.fixed_schedule_end_offset - self._max_isl = user_config.input.synthesis.max_isl - self._max_osl = user_config.input.synthesis.max_osl - - # Use the resolved tokenizer name so worker processes can load from cache - # without needing alias resolution or network access. - self._tokenizer_name = ( - prompt_generator.tokenizer.resolved_name - or user_config.tokenizer.name - or user_config.endpoint.model_names[0] - ) - self._block_size = user_config.input.prompt.input_tokens.block_size - @classmethod def can_load( cls, data: dict[str, Any] | None = None, filename: str | Path | None = None @@ -85,200 +51,30 @@ def can_load( except ValidationError: return False - @classmethod - def get_preferred_sampling_strategy(cls) -> DatasetSamplingStrategy: - """Get the preferred dataset sampling strategy for MooncakeTrace.""" - return DatasetSamplingStrategy.SEQUENTIAL - - def load_dataset(self) -> dict[str, list[MooncakeTrace]]: - """Load Mooncake trace data from a file. - - Returns: - A dictionary of session_id and list of Mooncake trace data. 
- """ - data: dict[str, list[MooncakeTrace]] = defaultdict(list) - - with open(self.filename) as f: - for line in f: - if (line := line.strip()) == "": - continue # Skip empty lines - - trace_data = MooncakeTrace.model_validate_json(line) - - # Skip traces before or after the fixed schedule offset - if ( - trace_data.timestamp is not None - and not self._timestamp_within_offsets(trace_data.timestamp) - ): - self._skipped_traces += 1 - continue - - # Filter by max_isl if configured - if ( - self._max_isl is not None - and trace_data.input_length is not None - and trace_data.input_length > self._max_isl - ): - self._skipped_max_isl += 1 - continue - - # Cap by max_osl if configured - if ( - self._max_osl is not None - and trace_data.output_length is not None - and trace_data.output_length > self._max_osl - ): - self._capped_max_osl += 1 - # Only cap it, do not skip the trace - trace_data.output_length = self._max_osl - - session_id = trace_data.session_id or self.session_id_generator.next() - data[session_id].append(trace_data) - - if self._skipped_traces > 0: - self.info( - f"Skipped {self._skipped_traces:,} traces because they were " - f"before the start offset of {self._start_offset} or " - f"after the end offset of {self._end_offset}" - ) - if self._skipped_max_isl > 0: - self.info( - f"Skipped {self._skipped_max_isl:,} traces because input_length " - f"exceeded max_isl of {self._max_isl}" - ) - if self._capped_max_osl > 0: - self.info( - f"{self._capped_max_osl:,} traces exceeded max_osl of {self._max_osl} and were capped to {self._max_osl}" - ) - self.debug(lambda: f"Loaded {len(data):,} traces from {self.filename}") - - # Apply synthesis if needed - synthesis_config = self.user_config.input.synthesis - if synthesis_config.should_synthesize(): - data = self._apply_synthesis(data) - - return data + # ------------------------------------------------------------------ + # Template-method hooks (see BaseTraceDatasetLoader.load_dataset) + # ------------------------------------------------------------------ - def _timestamp_within_offsets(self, timestamp: int) -> bool: - return (self._start_offset is None or timestamp >= self._start_offset) and ( - self._end_offset is None or timestamp <= self._end_offset - ) + def _parse_trace(self, line: str) -> MooncakeTrace: + return MooncakeTrace.model_validate_json(line) - def convert_to_conversations( - self, data: dict[str, list[MooncakeTrace]] - ) -> list[Conversation]: - """Convert all the Mooncake trace data to conversation objects. - - Uses a three-phase approach for optimal performance: - 1. Phase 1: Build token sequences, checking string cache first - 2. Phase 2: Batch parallel decode for cache misses - 3. Phase 3: Assemble final conversation objects - - Args: - data: A dictionary of session_id and list of Mooncake trace data. - - Returns: - A list of conversations. 
- """ - # Phase 1: Build token sequences and identify cache misses - # pending_decodes: list of (session_id, trace_idx, tokens, cache_key) - pending_decodes: list[tuple[str, int, list[int], tuple]] = [] - # conversations_data: session_id -> list of (trace, prompt or None) - conversations_data: dict[str, list[tuple[MooncakeTrace, str | None]]] = {} - - for session_id, traces in data.items(): - conversations_data[session_id] = [] - for idx, trace in enumerate(traces): - if trace.text_input is not None: - # Already a string, no decode needed - conversations_data[session_id].append((trace, trace.text_input)) - else: - hash_ids = trace.hash_ids or [] - if hash_ids: - # Check string cache first - cache_key = ( - tuple(hash_ids), - trace.input_length, - self._block_size, - ) - if cache_key in self.prompt_generator._decoded_cache: - # Cache hit - use cached prompt - prompt = self.prompt_generator._decoded_cache[cache_key] - conversations_data[session_id].append((trace, prompt)) - else: - # Cache miss - build tokens for batch decode - tokens = self.prompt_generator._build_token_sequence( - trace.input_length, hash_ids, self._block_size - ) - pending_decodes.append((session_id, idx, tokens, cache_key)) - conversations_data[session_id].append( - (trace, None) - ) # Placeholder - else: - # No hash_ids - use normal generation (already optimized) - prompt = self.prompt_generator.generate( - mean=trace.input_length, stddev=0, hash_ids=[] - ) - conversations_data[session_id].append((trace, prompt)) - - # Phase 2: Batch parallel decode for all cache misses - if pending_decodes: - self.debug( - lambda: f"Parallel decoding {len(pending_decodes)} prompts " - f"({len(data)} conversations)" - ) - token_sequences = [p[2] for p in pending_decodes] - decoded_prompts = parallel_decode(token_sequences, self._tokenizer_name) - - # Fill in placeholders and update cache - for (session_id, idx, _, cache_key), prompt in zip( - pending_decodes, decoded_prompts, strict=True - ): - # Update decoded cache for future reuse - self.prompt_generator._decoded_cache[cache_key] = prompt - # Update placeholder in conversations_data - trace, _ = conversations_data[session_id][idx] - conversations_data[session_id][idx] = (trace, prompt) - - # Phase 3: Build final conversation objects - conversations = [] - for session_id, trace_prompt_pairs in conversations_data.items(): - conversation = Conversation(session_id=session_id) - for trace, prompt in trace_prompt_pairs: - turn = Turn( - timestamp=trace.timestamp, - delay=trace.delay, - texts=[Text(name="text", contents=[prompt])], - max_tokens=trace.output_length, - ) - conversation.turns.append(turn) - conversations.append(conversation) - - return conversations - - def _apply_synthesis( - self, data: dict[str, list[MooncakeTrace]] + def _group_traces( + self, items: list[MooncakeTrace] ) -> dict[str, list[MooncakeTrace]]: - """Apply synthesis transformations to mooncake traces in-memory. - - Args: - data: Dictionary of session_id to list of MooncakeTrace objects. - - Returns: - Dictionary of session_id to list of synthesized MooncakeTrace objects. 
- """ - params = SynthesisParams.from_synthesis_config( - self.user_config.input.synthesis, block_size=self._block_size - ) - - # Convert to dicts for synthesizer (exclude discriminator field "type") - dict_data = { - sid: [t.model_dump(exclude={"type"}, exclude_none=True) for t in traces] - for sid, traces in data.items() - } - synthesized = Synthesizer(params=params).synthesize_grouped_traces(dict_data) - - return { - sid: [MooncakeTrace.model_validate(t) for t in traces] - for sid, traces in synthesized.items() - } + data: dict[str, list[MooncakeTrace]] = defaultdict(list) + for trace in items: + session_id = trace.session_id or self.session_id_generator.next() + data[session_id].append(trace) + return dict(data) + + # ------------------------------------------------------------------ + # Synthesis hooks + # ------------------------------------------------------------------ + + def _synthesis_exclude_fields(self) -> frozenset[str]: + return frozenset({"type"}) + + def _reconstruct_traces( + self, originals: list[MooncakeTrace], synth_dicts: list[dict[str, Any]] + ) -> list[MooncakeTrace]: + return [MooncakeTrace.model_validate(t) for t in synth_dicts] diff --git a/src/aiperf/plugin/categories.yaml b/src/aiperf/plugin/categories.yaml index dba7caeed..7db95adc3 100644 --- a/src/aiperf/plugin/categories.yaml +++ b/src/aiperf/plugin/categories.yaml @@ -86,6 +86,7 @@ dataset_composer: custom_dataset_loader: protocol: aiperf.dataset.protocols:CustomDatasetLoaderProtocol + metadata_class: aiperf.plugin.schema.schemas:CustomDatasetLoaderMetadata enum: CustomDatasetType description: | Custom dataset loaders parse different JSONL file formats into conversations. diff --git a/src/aiperf/plugin/enums.py b/src/aiperf/plugin/enums.py index 945fc0f68..8c7035d32 100644 --- a/src/aiperf/plugin/enums.py +++ b/src/aiperf/plugin/enums.py @@ -55,7 +55,7 @@ CustomDatasetTypeStr: TypeAlias = str CustomDatasetType = plugins.create_enum(PluginType.CUSTOM_DATASET_LOADER, "CustomDatasetType", module=__name__) -"""Dynamic enum for custom dataset loader. Example: CustomDatasetType.MOONCAKE_TRACE, CustomDatasetType.MULTI_TURN, CustomDatasetType.RANDOM_POOL""" +"""Dynamic enum for custom dataset loader. Example: CustomDatasetType.BAILIAN_TRACE, CustomDatasetType.MOONCAKE_TRACE, CustomDatasetType.MULTI_TURN""" EndpointTypeStr: TypeAlias = str EndpointType = plugins.create_enum(PluginType.ENDPOINT, "EndpointType", module=__name__) diff --git a/src/aiperf/plugin/plugins.py b/src/aiperf/plugin/plugins.py index 1864adc2f..253630f01 100644 --- a/src/aiperf/plugin/plugins.py +++ b/src/aiperf/plugin/plugins.py @@ -26,6 +26,7 @@ ) from aiperf.plugin.extensible_enums import ExtensibleStrEnum, _normalize_name from aiperf.plugin.schema.schemas import ( + CustomDatasetLoaderMetadata, EndpointMetadata, PlotMetadata, PluginsManifest, @@ -1165,12 +1166,39 @@ def get_service_metadata(name: str) -> ServiceMetadata: return get_entry("service", name).get_typed_metadata(ServiceMetadata) +def get_dataset_loader_metadata(name: str) -> CustomDatasetLoaderMetadata: + """Get typed metadata for a custom dataset loader plugin. + + Args: + name: Dataset loader plugin name (e.g., 'mooncake_trace', 'bailian_trace'). + + Returns: + Validated CustomDatasetLoaderMetadata instance. + """ + return get_entry("custom_dataset_loader", name).get_typed_metadata( + CustomDatasetLoaderMetadata + ) + + +def is_trace_dataset(name: str) -> bool: + """Check if a custom dataset loader is a trace-format dataset. 
+ + Args: + name: Dataset loader plugin name (e.g., 'mooncake_trace', 'single_turn'). + + Returns: + True if the loader handles trace-format datasets. + """ + return get_dataset_loader_metadata(name).is_trace + + # Mapping of categories to their metadata classes (for categories with typed metadata) _CATEGORY_METADATA_CLASSES: dict[str, type] = { "endpoint": EndpointMetadata, "transport": TransportMetadata, "plot": PlotMetadata, "service": ServiceMetadata, + "custom_dataset_loader": CustomDatasetLoaderMetadata, } diff --git a/src/aiperf/plugin/plugins.yaml b/src/aiperf/plugin/plugins.yaml index 80e0d1cf0..0e32fc86c 100644 --- a/src/aiperf/plugin/plugins.yaml +++ b/src/aiperf/plugin/plugins.yaml @@ -366,11 +366,27 @@ dataset_composer: # Auto-detection: the custom composer tries loaders in priority order based on can_load(). # ============================================================================= custom_dataset_loader: + bailian_trace: + class: aiperf.dataset.loader.bailian_trace:BailianTraceDatasetLoader + description: | + Alibaba Bailian trace dataset loader for Qwen model serving traces from + https://github.com/alibaba-edu/qwen-bailian-usagetraces-anon. + Loads JSONL traces with chat_id, parent_chat_id (-1 for root), timestamp + (seconds), input/output lengths, turn numbers, and hash_ids for salted + SipHash block-based prompt caching (block_size=16). Supports multi-turn + conversation chains and fixed_schedule timing mode. + metadata: + is_trace: true + default_block_size: 16 + mooncake_trace: class: aiperf.dataset.loader.mooncake_trace:MooncakeTraceDatasetLoader description: | Mooncake trace dataset loader for loading Alibaba Mooncake trace format with timestamp-based replay support. Designed for fixed_schedule timing mode. + metadata: + is_trace: true + default_block_size: 512 multi_turn: class: aiperf.dataset.loader.multi_turn:MultiTurnDatasetLoader diff --git a/src/aiperf/plugin/schema/plugins.schema.json b/src/aiperf/plugin/schema/plugins.schema.json index 177c7fff2..258c3651c 100644 --- a/src/aiperf/plugin/schema/plugins.schema.json +++ b/src/aiperf/plugin/schema/plugins.schema.json @@ -522,18 +522,30 @@ "type": "integer" }, "metadata": { - "anyOf": [ - { - "additionalProperties": true, - "type": "object" + "description": "Metadata schema for custom dataset loader plugins.\n\nDefines format-specific defaults for dataset loaders. When a loader specifies\n``block_size``, it overrides the user's ``--isl-block-size`` config default,\nensuring hash-based prompt generation uses the correct token block size for the\ntrace format (e.g. 16 for Bailian, 512 for Mooncake).\n\nReferenced by: categories.yaml custom_dataset_loader.metadata_class\nUsed in: plugins.yaml custom_dataset_loader entries", + "properties": { + "is_trace": { + "default": false, + "description": "Whether this loader handles trace-format datasets. Trace datasets use hash_ids-based prompt generation, support synthesis options, and prefer sequential sampling with fixed_schedule timing.", + "title": "Is Trace", + "type": "boolean" }, - { - "type": "null" + "default_block_size": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Default token block size for hash-based prompt caching. Used when the user does not explicitly set --isl-block-size. Must match the block size used to generate the trace's hash_ids (e.g. 
16 for Bailian, 512 for Mooncake).", + "title": "Default Block Size" } - ], - "default": null, - "description": "Category-specific configuration for this plugin type. The allowed fields depend on the category's metadata_class in categories.yaml.", - "title": "Metadata" + }, + "title": "CustomDatasetLoaderMetadata", + "type": "object" } }, "required": [ diff --git a/src/aiperf/plugin/schema/schemas.py b/src/aiperf/plugin/schema/schemas.py index 32ad4d72c..09729bd2e 100644 --- a/src/aiperf/plugin/schema/schemas.py +++ b/src/aiperf/plugin/schema/schemas.py @@ -315,6 +315,37 @@ class PlotMetadata(BaseModel): ) +class CustomDatasetLoaderMetadata(BaseModel): + """Metadata schema for custom dataset loader plugins. + + Defines format-specific defaults for dataset loaders. When a loader specifies + ``block_size``, it overrides the user's ``--isl-block-size`` config default, + ensuring hash-based prompt generation uses the correct token block size for the + trace format (e.g. 16 for Bailian, 512 for Mooncake). + + Referenced by: categories.yaml custom_dataset_loader.metadata_class + Used in: plugins.yaml custom_dataset_loader entries + """ + + is_trace: bool = Field( + default=False, + description=( + "Whether this loader handles trace-format datasets. " + "Trace datasets use hash_ids-based prompt generation, support synthesis " + "options, and prefer sequential sampling with fixed_schedule timing." + ), + ) + default_block_size: int | None = Field( + default=None, + description=( + "Default token block size for hash-based prompt caching. " + "Used when the user does not explicitly set --isl-block-size. " + "Must match the block size used to generate the trace's hash_ids " + "(e.g. 16 for Bailian, 512 for Mooncake)." + ), + ) + + class ServiceMetadata(BaseModel): """Metadata schema for service plugins. 
diff --git a/tests/unit/common/config/test_input_config.py b/tests/unit/common/config/test_input_config.py index 5d58ccd5f..2aa10f86f 100644 --- a/tests/unit/common/config/test_input_config.py +++ b/tests/unit/common/config/test_input_config.py @@ -193,16 +193,23 @@ def test_all_custom_dataset_types_require_file(dataset_type): # ============================================================================ -def test_synthesis_with_mooncake_trace_succeeds(): - """Test that synthesis options with mooncake_trace dataset type succeeds.""" +@pytest.mark.parametrize( + "dataset_type", + [ + CustomDatasetType.MOONCAKE_TRACE, + CustomDatasetType.BAILIAN_TRACE, + ], +) # fmt: skip +def test_synthesis_with_trace_dataset_succeeds(dataset_type): + """Test that synthesis options with trace dataset types succeed.""" with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: config = InputConfig( - custom_dataset_type=CustomDatasetType.MOONCAKE_TRACE, + custom_dataset_type=dataset_type, file=temp_file.name, synthesis=SynthesisConfig(speedup_ratio=2.0), ) assert config.synthesis.speedup_ratio == 2.0 - assert config.custom_dataset_type == CustomDatasetType.MOONCAKE_TRACE + assert config.custom_dataset_type == dataset_type @pytest.mark.parametrize( @@ -213,8 +220,8 @@ def test_synthesis_with_mooncake_trace_succeeds(): CustomDatasetType.RANDOM_POOL, ], ) # fmt: skip -def test_synthesis_with_non_mooncake_trace_raises_error(dataset_type): - """Test that synthesis options with non-mooncake_trace dataset type raises error.""" +def test_synthesis_with_non_trace_dataset_raises_error(dataset_type): + """Test that synthesis options with non-trace dataset type raises error.""" with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: with pytest.raises(ValidationError) as exc: InputConfig( @@ -223,7 +230,7 @@ def test_synthesis_with_non_mooncake_trace_raises_error(dataset_type): synthesis=SynthesisConfig(speedup_ratio=2.0), ) - assert "require --custom-dataset-type mooncake_trace" in str(exc.value) + assert "require a trace dataset type" in str(exc.value) def test_synthesis_with_auto_detect_dataset_type_succeeds(): @@ -252,8 +259,8 @@ def test_synthesis_with_auto_detect_dataset_type_succeeds(): SynthesisConfig(speedup_ratio=0.5, prefix_len_multiplier=1.5), ], ) # fmt: skip -def test_synthesis_various_options_require_mooncake_trace(synthesis_config): - """Test that various synthesis option combinations require mooncake_trace.""" +def test_synthesis_various_options_require_trace_dataset(synthesis_config): + """Test that various synthesis option combinations require a trace dataset.""" with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: with pytest.raises(ValidationError) as exc: InputConfig( @@ -262,7 +269,7 @@ def test_synthesis_various_options_require_mooncake_trace(synthesis_config): synthesis=synthesis_config, ) - assert "require --custom-dataset-type mooncake_trace" in str(exc.value) + assert "require a trace dataset type" in str(exc.value) def test_synthesis_defaults_with_any_dataset_type_succeeds(): @@ -320,8 +327,8 @@ def test_synthesis_max_osl_alone_does_not_trigger_synthesis(): CustomDatasetType.RANDOM_POOL, ], ) # fmt: skip -def test_synthesis_max_isl_requires_mooncake_trace(dataset_type): - """Test that max_isl requires mooncake_trace dataset type.""" +def test_synthesis_max_isl_requires_trace_dataset(dataset_type): + """Test that max_isl requires a trace dataset type.""" with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: with pytest.raises(ValidationError) as exc: 
InputConfig( @@ -330,7 +337,7 @@ def test_synthesis_max_isl_requires_mooncake_trace(dataset_type): synthesis=SynthesisConfig(max_isl=4096), ) - assert "require --custom-dataset-type mooncake_trace" in str(exc.value) + assert "require a trace dataset type" in str(exc.value) @pytest.mark.parametrize( @@ -341,8 +348,8 @@ def test_synthesis_max_isl_requires_mooncake_trace(dataset_type): CustomDatasetType.RANDOM_POOL, ], ) # fmt: skip -def test_synthesis_max_osl_requires_mooncake_trace(dataset_type): - """Test that max_osl requires mooncake_trace dataset type.""" +def test_synthesis_max_osl_requires_trace_dataset(dataset_type): + """Test that max_osl requires a trace dataset type.""" with tempfile.NamedTemporaryFile(suffix=".jsonl") as temp_file: with pytest.raises(ValidationError) as exc: InputConfig( @@ -351,4 +358,4 @@ def test_synthesis_max_osl_requires_mooncake_trace(dataset_type): synthesis=SynthesisConfig(max_osl=2048), ) - assert "require --custom-dataset-type mooncake_trace" in str(exc.value) + assert "require a trace dataset type" in str(exc.value) diff --git a/tests/unit/common/config/test_prompt_config.py b/tests/unit/common/config/test_prompt_config.py index 8bd362e49..655835988 100644 --- a/tests/unit/common/config/test_prompt_config.py +++ b/tests/unit/common/config/test_prompt_config.py @@ -33,7 +33,7 @@ def test_input_tokens_config_defaults(): config = InputTokensConfig() assert config.mean == InputTokensDefaults.MEAN assert config.stddev == InputTokensDefaults.STDDEV - assert config.block_size == InputTokensDefaults.BLOCK_SIZE + assert config.block_size is None def test_input_tokens_config_custom_values(): diff --git a/tests/unit/common/config/test_user_config_mooncake_trace.py b/tests/unit/common/config/test_user_config_mooncake_trace.py index 0e0844891..8ae4c0e05 100644 --- a/tests/unit/common/config/test_user_config_mooncake_trace.py +++ b/tests/unit/common/config/test_user_config_mooncake_trace.py @@ -187,8 +187,8 @@ def test_count_dataset_entries_with_edge_cases(self, mock_is_file, mock_exists): assert count == 3 # 3 non-empty/non-whitespace lines -class TestMooncakeTraceTimingDetection: - """Test _should_use_fixed_schedule_for_mooncake_trace() for automatic timing detection.""" +class TestTraceDatasetTimingDetection: + """Test _should_use_fixed_schedule_for_trace_dataset() for automatic timing detection.""" @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) @@ -210,7 +210,7 @@ def test_mooncake_trace_with_timestamps_enables_fixed_schedule( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_mooncake_trace() + result = config._should_use_fixed_schedule_for_trace_dataset() assert result is True @patch("pathlib.Path.exists", return_value=True) @@ -233,15 +233,13 @@ def test_mooncake_trace_without_timestamps_no_fixed_schedule( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_mooncake_trace() + result = config._should_use_fixed_schedule_for_trace_dataset() assert result is False @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) - def test_non_mooncake_trace_dataset_no_auto_detection( - self, mock_is_file, mock_exists - ): - """Test that non-mooncake_trace datasets don't trigger auto-detection.""" + def test_non_trace_dataset_no_auto_detection(self, mock_is_file, mock_exists): + """Test that non-trace datasets don't trigger auto-detection.""" 
mock_file_content = '{"timestamp": 1000, "data": "test"}\n' config = UserConfig( @@ -253,7 +251,7 @@ def test_non_mooncake_trace_dataset_no_auto_detection( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - result = config._should_use_fixed_schedule_for_mooncake_trace() + result = config._should_use_fixed_schedule_for_trace_dataset() assert result is False @patch("pathlib.Path.exists", return_value=True) @@ -262,7 +260,6 @@ def test_file_parsing_with_empty_lines_and_malformed_json( self, mock_is_file, mock_exists ): """Test file parsing handles empty lines and malformed JSON gracefully.""" - # Content with empty lines, whitespace, and malformed JSON mock_file_content = ( '{"input_length": 50, "timestamp": 1000}\n' "\n" # Empty line @@ -284,15 +281,14 @@ def test_file_parsing_with_empty_lines_and_malformed_json( ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - # Should handle malformed JSON gracefully and still detect timestamps - has_timestamps = config._should_use_fixed_schedule_for_mooncake_trace() - assert has_timestamps is True # Should find valid timestamps despite errors + has_timestamps = config._should_use_fixed_schedule_for_trace_dataset() + assert has_timestamps is True @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) def test_empty_file_timing_detection(self, mock_is_file, mock_exists): """Test timing detection with completely empty files.""" - mock_file_content = "" # Completely empty file + mock_file_content = "" config = UserConfig( endpoint=EndpointConfig(model_names=["test-model"]), @@ -303,8 +299,7 @@ def test_empty_file_timing_detection(self, mock_is_file, mock_exists): ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - # Should handle empty file gracefully - assert config._should_use_fixed_schedule_for_mooncake_trace() is False + assert config._should_use_fixed_schedule_for_trace_dataset() is False @patch("pathlib.Path.exists", return_value=True) @patch("pathlib.Path.is_file", return_value=True) @@ -326,5 +321,4 @@ def test_only_malformed_json_timing_detection(self, mock_is_file, mock_exists): ) with patch("builtins.open", mock_open(read_data=mock_file_content)): - # Should find no timestamps in malformed JSON - assert config._should_use_fixed_schedule_for_mooncake_trace() is False + assert config._should_use_fixed_schedule_for_trace_dataset() is False diff --git a/tests/unit/dataset/composer/test_custom_composer.py b/tests/unit/dataset/composer/test_custom_composer.py index 584a08c7a..75ec844b5 100644 --- a/tests/unit/dataset/composer/test_custom_composer.py +++ b/tests/unit/dataset/composer/test_custom_composer.py @@ -64,7 +64,7 @@ def test_create_loader_instance_dataset_types( composer._create_loader_instance(dataset_type) assert isinstance(composer.loader, expected_instance) - @patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) def test_create_dataset_trace( @@ -80,7 +80,7 @@ def test_create_dataset_trace( assert all(isinstance(turn, Turn) for c in conversations for turn in c.turns) assert all(len(turn.texts) == 1 for c in conversations for turn in c.turns) - @patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") 
@patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) def test_max_tokens_config( @@ -104,7 +104,7 @@ def test_max_tokens_config( # Should be roughly around the mean of 120 (within 3 stddev) assert 96 < turn.max_tokens < 144 - @patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) @patch("pathlib.Path.iterdir", return_value=[]) @@ -205,13 +205,22 @@ def test_set_sampling_strategy_does_not_override( class TestSynthesisValidation: """Test class for synthesis configuration validation.""" - def test_synthesis_allowed_with_mooncake_trace(self, trace_config, mock_tokenizer): - """Test that synthesis options are allowed with mooncake_trace dataset type.""" + @pytest.mark.parametrize( + "dataset_type", + [ + CustomDatasetType.MOONCAKE_TRACE, + CustomDatasetType.BAILIAN_TRACE, + ], + ) + def test_synthesis_allowed_with_trace_datasets( + self, trace_config, mock_tokenizer, dataset_type + ): + """Test that synthesis options are allowed with trace dataset types.""" trace_config.input.synthesis = SynthesisConfig(speedup_ratio=2.0) composer = CustomDatasetComposer(trace_config, mock_tokenizer) # Should not raise - composer._validate_synthesis_config(CustomDatasetType.MOONCAKE_TRACE) + composer._validate_synthesis_config(dataset_type) @pytest.mark.parametrize( "dataset_type", @@ -221,17 +230,17 @@ def test_synthesis_allowed_with_mooncake_trace(self, trace_config, mock_tokenize CustomDatasetType.RANDOM_POOL, ], ) - def test_synthesis_raises_error_with_non_mooncake_types( + def test_synthesis_raises_error_with_non_trace_types( self, custom_config, mock_tokenizer, dataset_type ): - """Test that synthesis options raise error with non-mooncake dataset types.""" + """Test that synthesis options raise error with non-trace dataset types.""" custom_config.input.synthesis = SynthesisConfig(speedup_ratio=2.0) composer = CustomDatasetComposer(custom_config, mock_tokenizer) with pytest.raises(ValueError) as exc: composer._validate_synthesis_config(dataset_type) - assert "only supported with mooncake_trace" in str(exc.value) + assert "only supported with trace datasets" in str(exc.value) assert dataset_type.value in str(exc.value) @pytest.mark.parametrize( @@ -253,7 +262,7 @@ def test_various_synthesis_options_raise_error( with pytest.raises(ValueError) as exc: composer._validate_synthesis_config(CustomDatasetType.SINGLE_TURN) - assert "only supported with mooncake_trace" in str(exc.value) + assert "only supported with trace datasets" in str(exc.value) def test_default_synthesis_allowed_with_any_type( self, custom_config, mock_tokenizer diff --git a/tests/unit/dataset/loader/test_bailian_trace.py b/tests/unit/dataset/loader/test_bailian_trace.py new file mode 100644 index 000000000..97885a993 --- /dev/null +++ b/tests/unit/dataset/loader/test_bailian_trace.py @@ -0,0 +1,739 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import logging +from unittest.mock import Mock, patch + +import pytest +from pydantic import ValidationError + +from aiperf.common.config import ( + EndpointConfig, + InputConfig, + InputTokensConfig, + PromptConfig, + SynthesisConfig, + UserConfig, +) +from aiperf.dataset.loader.bailian_trace import BailianTraceDatasetLoader +from aiperf.dataset.loader.models import BailianTrace + +# ============================================================================ +# BailianTrace Model Tests +# ============================================================================ + + +class TestBailianTrace: + """Validation and construction tests for the BailianTrace model.""" + + def test_create_minimal(self): + trace = BailianTrace( + chat_id=1, + timestamp=1700000000.0, + input_length=100, + output_length=40, + ) + assert trace.chat_id == 1 + assert trace.parent_chat_id == -1 + assert trace.timestamp == 1700000000.0 + assert trace.input_length == 100 + assert trace.output_length == 40 + assert trace.request_type == "" + assert trace.turn == 1 + assert trace.hash_ids == [] + + def test_create_full(self): + trace = BailianTrace( + chat_id=42, + parent_chat_id=10, + timestamp=1700000001.5, + input_length=256, + output_length=64, + request_type="chat", + turn=3, + hash_ids=[1, 2, 3, 4, 5], + ) + assert trace.chat_id == 42 + assert trace.parent_chat_id == 10 + assert trace.turn == 3 + assert trace.hash_ids == [1, 2, 3, 4, 5] + assert trace.request_type == "chat" + + def test_type_alias_deserialization(self): + """The JSONL 'type' field maps to 'request_type' via alias.""" + raw = ( + '{"chat_id": 1, "timestamp": 1.0, "input_length": 10, ' + '"output_length": 5, "type": "inference"}' + ) + trace = BailianTrace.model_validate_json(raw) + assert trace.request_type == "inference" + + def test_missing_required_chat_id(self): + with pytest.raises(ValidationError, match="chat_id"): + BailianTrace( + timestamp=1.0, + input_length=10, + output_length=5, + ) + + def test_missing_required_timestamp(self): + with pytest.raises(ValidationError, match="timestamp"): + BailianTrace( + chat_id=1, + input_length=10, + output_length=5, + ) + + def test_missing_required_input_length(self): + with pytest.raises(ValidationError, match="input_length"): + BailianTrace( + chat_id=1, + timestamp=1.0, + output_length=5, + ) + + def test_missing_required_output_length(self): + with pytest.raises(ValidationError, match="output_length"): + BailianTrace( + chat_id=1, + timestamp=1.0, + input_length=10, + ) + + +# ============================================================================ +# BailianTraceDatasetLoader Tests +# ============================================================================ + + +class TestBailianTraceDatasetLoader: + """Core loader functionality tests.""" + + @pytest.fixture + def mock_prompt_generator(self): + generator = Mock() + generator.generate.return_value = "Generated prompt text" + generator._decoded_cache = {} + generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + return generator + + @pytest.fixture + def default_user_config(self): + return UserConfig(endpoint=EndpointConfig(model_names=["test-model"])) + + def _make_user_config( + self, + start_offset: int | None = None, + end_offset: int | None = None, + file: str | None = None, + ) -> UserConfig: + has_offsets = start_offset is not None or end_offset is not None + input_config = ( + InputConfig( + file=file, + fixed_schedule=True, + fixed_schedule_start_offset=start_offset, + 
fixed_schedule_end_offset=end_offset, + ) + if has_offsets + else InputConfig() + ) + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=input_config, + ) + + # ---- basic loading ---- + + def test_load_basic( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + content = [ + '{"chat_id": 1, "parent_chat_id": -1, "timestamp": 1.0, "input_length": 100, "output_length": 40, "type": "text", "turn": 1, "hash_ids": [10, 20]}', + '{"chat_id": 2, "parent_chat_id": -1, "timestamp": 2.0, "input_length": 200, "output_length": 80, "type": "text", "turn": 1, "hash_ids": [30]}', + ] + filename = create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + assert len(dataset) == 2 + all_traces = [t for traces in dataset.values() for t in traces] + assert all_traces[0].input_length == 100 + assert all_traces[1].input_length == 200 + + def test_timestamps_converted_to_milliseconds( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + content = [ + '{"chat_id": 1, "timestamp": 1.5, "input_length": 10, "output_length": 5}', + ] + filename = create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + trace = list(dataset.values())[0][0] + assert trace.timestamp == 1500.0 + + def test_skips_empty_lines( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + content = [ + '{"chat_id": 1, "timestamp": 1.0, "input_length": 10, "output_length": 5}', + "", + '{"chat_id": 2, "timestamp": 2.0, "input_length": 20, "output_length": 10}', + ] + filename = create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + total = sum(len(v) for v in dataset.values()) + assert total == 2 + + # ---- multi-turn grouping ---- + + def test_groups_by_parent_chat_id( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Entries with the same root should be grouped into one session.""" + content = [ + '{"chat_id": 100, "parent_chat_id": -1, "timestamp": 1.0, "input_length": 50, "output_length": 20, "turn": 1}', + '{"chat_id": 101, "parent_chat_id": 100, "timestamp": 2.0, "input_length": 60, "output_length": 25, "turn": 2}', + '{"chat_id": 102, "parent_chat_id": 101, "timestamp": 3.0, "input_length": 70, "output_length": 30, "turn": 3}', + ] + filename = create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + # All three should be in one session rooted at chat_id=100 + assert len(dataset) == 1 + session = list(dataset.values())[0] + assert len(session) == 3 + assert [t.turn for t in session] == [1, 2, 3] + + def test_separate_sessions( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Independent root entries form separate sessions.""" + content = [ + '{"chat_id": 1, "parent_chat_id": -1, "timestamp": 1.0, "input_length": 50, "output_length": 20, "turn": 1}', + '{"chat_id": 2, "parent_chat_id": -1, "timestamp": 2.0, "input_length": 60, "output_length": 25, "turn": 1}', + ] + filename = 
create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + assert len(dataset) == 2 + + def test_turns_sorted_within_session( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Turns are sorted even if JSONL order differs.""" + content = [ + '{"chat_id": 102, "parent_chat_id": 101, "timestamp": 3.0, "input_length": 70, "output_length": 30, "turn": 3}', + '{"chat_id": 100, "parent_chat_id": -1, "timestamp": 1.0, "input_length": 50, "output_length": 20, "turn": 1}', + '{"chat_id": 101, "parent_chat_id": 100, "timestamp": 2.0, "input_length": 60, "output_length": 25, "turn": 2}', + ] + filename = create_jsonl_file(content) + + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + session = list(dataset.values())[0] + assert [t.turn for t in session] == [1, 2, 3] + + # ---- offset filtering ---- + + @pytest.mark.parametrize( + "start_offset,end_offset,expected_count,description", + [ + (None, None, 3, "no filtering"), + (1500, None, 2, "start offset only — keeps ts >= 1500 ms"), + (None, 2500, 2, "end offset only — keeps ts <= 2500 ms"), + (1500, 2500, 1, "both offsets — keeps ts in [1500, 2500]"), + ], + ) # fmt: skip + def test_offset_filtering( + self, + create_jsonl_file, + mock_prompt_generator, + start_offset, + end_offset, + expected_count, + description, + ): + """Timestamps are converted to ms before offset comparison.""" + content = [ + '{"chat_id": 1, "timestamp": 1.0, "input_length": 10, "output_length": 5}', + '{"chat_id": 2, "timestamp": 2.0, "input_length": 10, "output_length": 5}', + '{"chat_id": 3, "timestamp": 3.0, "input_length": 10, "output_length": 5}', + ] + filename = create_jsonl_file(content) + + user_config = self._make_user_config(start_offset, end_offset, file=filename) + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + total = sum(len(v) for v in dataset.values()) + assert total == expected_count, f"Failed for {description}" + + def test_offset_filtering_logs_skipped( + self, create_jsonl_file, mock_prompt_generator, caplog + ): + caplog.set_level(logging.INFO) + + content = [ + '{"chat_id": 1, "timestamp": 0.5, "input_length": 10, "output_length": 5}', + '{"chat_id": 2, "timestamp": 2.0, "input_length": 10, "output_length": 5}', + '{"chat_id": 3, "timestamp": 5.0, "input_length": 10, "output_length": 5}', + ] + filename = create_jsonl_file(content) + + user_config = self._make_user_config(1000, 3000, file=filename) + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + loader.load_dataset() + + assert "Skipped 2 traces" in caplog.text + + # ---- max_isl / max_osl ---- + + @pytest.mark.parametrize( + "max_isl,expected_count", + [ + (None, 3), + (500, 3), + (150, 2), + (50, 0), + ], + ) # fmt: skip + def test_max_isl_filtering( + self, create_jsonl_file, mock_prompt_generator, max_isl, expected_count + ): + content = [ + '{"chat_id": 1, "timestamp": 1.0, "input_length": 100, "output_length": 10}', + '{"chat_id": 2, "timestamp": 2.0, "input_length": 150, "output_length": 10}', + '{"chat_id": 3, "timestamp": 3.0, "input_length": 200, "output_length": 10}', + ] + filename = 
create_jsonl_file(content) + + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig(synthesis=SynthesisConfig(max_isl=max_isl)), + ) + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + total = sum(len(v) for v in dataset.values()) + assert total == expected_count + + @pytest.mark.parametrize( + "max_osl,expected_output_lengths", + [ + (None, [50, 100, 150]), + (500, [50, 100, 150]), + (100, [50, 100, 100]), + (25, [25, 25, 25]), + ], + ) # fmt: skip + def test_max_osl_capping( + self, create_jsonl_file, mock_prompt_generator, max_osl, expected_output_lengths + ): + content = [ + '{"chat_id": 1, "timestamp": 1.0, "input_length": 10, "output_length": 50}', + '{"chat_id": 2, "timestamp": 2.0, "input_length": 10, "output_length": 100}', + '{"chat_id": 3, "timestamp": 3.0, "input_length": 10, "output_length": 150}', + ] + filename = create_jsonl_file(content) + + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig(synthesis=SynthesisConfig(max_osl=max_osl)), + ) + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + actual = [t.output_length for traces in dataset.values() for t in traces] + assert actual == expected_output_lengths + + # ---- can_load ---- + + @pytest.mark.parametrize( + "data,expected", + [ + ({"chat_id": 1, "timestamp": 1.0, "input_length": 10, "output_length": 5}, True), + ({"chat_id": 1, "timestamp": 1.0, "input_length": 10, "output_length": 5, "type": "chat", "turn": 2, "hash_ids": [1]}, True), + ({"input_length": 10, "hash_ids": [1]}, False), # Mooncake, not Bailian + ({"text": "hello"}, False), + (None, False), + ], + ) # fmt: skip + def test_can_load(self, data, expected): + assert BailianTraceDatasetLoader.can_load(data=data) is expected + + # ---- convert_to_conversations ---- + + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") + def test_convert_to_conversations( + self, mock_parallel_decode, mock_prompt_generator, default_user_config + ): + mock_parallel_decode.return_value = ["decoded prompt 1", "decoded prompt 2"] + + trace_data = { + "100": [ + BailianTrace( + chat_id=100, + timestamp=1000.0, + input_length=100, + output_length=50, + hash_ids=[1, 2, 3], + ), + ], + "200": [ + BailianTrace( + chat_id=200, + timestamp=2000.0, + input_length=200, + output_length=80, + hash_ids=[4, 5], + ), + ], + } + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + conversations = loader.convert_to_conversations(trace_data) + + assert len(conversations) == 2 + assert conversations[0].session_id == "100" + assert conversations[0].turns[0].timestamp == 1000.0 + assert conversations[0].turns[0].max_tokens == 50 + + def test_convert_empty_data(self, mock_prompt_generator, default_user_config): + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + assert loader.convert_to_conversations({}) == [] + + def test_convert_without_hash_ids(self, mock_prompt_generator, default_user_config): + """When hash_ids is empty, falls back to normal prompt generation.""" + trace_data = { + "1": [ + BailianTrace( + chat_id=1, + timestamp=1000.0, + input_length=100, + 
output_length=50, + ), + ], + } + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + conversations = loader.convert_to_conversations(trace_data) + + assert len(conversations) == 1 + mock_prompt_generator.generate.assert_called_once_with( + mean=100, stddev=0, hash_ids=[] + ) + + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") + def test_parallel_decode_length_mismatch_raises( + self, mock_parallel_decode, mock_prompt_generator, default_user_config + ): + """strict=True in zip guards against silent data loss.""" + mock_parallel_decode.return_value = ["only one"] # expecting 2 + + trace_data = { + "1": [ + BailianTrace( + chat_id=1, + timestamp=1.0, + input_length=10, + output_length=5, + hash_ids=[1], + ) + ], + "2": [ + BailianTrace( + chat_id=2, + timestamp=2.0, + input_length=20, + output_length=10, + hash_ids=[2], + ) + ], + } + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + + with pytest.raises(ValueError, match="zip"): + loader.convert_to_conversations(trace_data) + + # ---- multi-turn conversation conversion ---- + + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") + def test_multi_turn_conversation_ordering( + self, mock_parallel_decode, mock_prompt_generator, default_user_config + ): + mock_parallel_decode.return_value = [ + "prompt turn 1", + "prompt turn 2", + "prompt turn 3", + ] + + trace_data = { + "100": [ + BailianTrace( + chat_id=100, + timestamp=1000.0, + input_length=50, + output_length=20, + turn=1, + hash_ids=[1], + ), + BailianTrace( + chat_id=101, + parent_chat_id=100, + timestamp=2000.0, + input_length=60, + output_length=25, + turn=2, + hash_ids=[2], + ), + BailianTrace( + chat_id=102, + parent_chat_id=101, + timestamp=3000.0, + input_length=70, + output_length=30, + turn=3, + hash_ids=[3], + ), + ], + } + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + conversations = loader.convert_to_conversations(trace_data) + + assert len(conversations) == 1 + conv = conversations[0] + assert len(conv.turns) == 3 + assert conv.turns[0].timestamp == 1000.0 + assert conv.turns[1].timestamp == 2000.0 + assert conv.turns[2].timestamp == 3000.0 + + +# ============================================================================ +# Synthesis Integration Tests +# ============================================================================ + + +def _make_synthesis_config( + speedup_ratio: float = 1.0, + prefix_len_multiplier: float = 1.0, + max_isl: int | None = None, + block_size: int = 16, +) -> UserConfig: + return UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig( + synthesis=SynthesisConfig( + speedup_ratio=speedup_ratio, + prefix_len_multiplier=prefix_len_multiplier, + max_isl=max_isl, + ), + prompt=PromptConfig( + input_tokens=InputTokensConfig(block_size=block_size), + ), + ), + ) + + +class TestBailianTraceSynthesisIntegration: + @pytest.fixture + def mock_prompt_generator(self): + generator = Mock() + generator.generate.return_value = "Generated prompt text" + generator._decoded_cache = {} + generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + return generator + + def test_speedup_ratio_scales_timestamps(self, mock_prompt_generator): + data = { + "1": [ + BailianTrace( + chat_id=1, + timestamp=1000.0, + 
input_length=16, + output_length=10, + hash_ids=[1], + ), + BailianTrace( + chat_id=2, + parent_chat_id=1, + timestamp=2000.0, + input_length=16, + output_length=10, + turn=2, + hash_ids=[2], + ), + ], + } + user_config = _make_synthesis_config(speedup_ratio=2.0) + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + result = loader._apply_synthesis(data) + + assert result["1"][0].timestamp == 500.0 + assert result["1"][1].timestamp == 1000.0 + + def test_synthesis_preserves_session_structure(self, mock_prompt_generator): + data = { + "1": [ + BailianTrace( + chat_id=1, + timestamp=1000.0, + input_length=16, + output_length=10, + hash_ids=[1], + ), + ], + "2": [ + BailianTrace( + chat_id=2, + timestamp=2000.0, + input_length=16, + output_length=10, + hash_ids=[2], + ), + ], + } + user_config = _make_synthesis_config(speedup_ratio=2.0) + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + result = loader._apply_synthesis(data) + + assert set(result.keys()) == {"1", "2"} + assert len(result["1"]) == 1 + assert len(result["2"]) == 1 + + def test_synthesis_returns_bailian_trace_objects(self, mock_prompt_generator): + data = { + "1": [ + BailianTrace( + chat_id=1, + timestamp=1000.0, + input_length=16, + output_length=10, + hash_ids=[1], + ), + ], + } + user_config = _make_synthesis_config(speedup_ratio=2.0) + + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + result = loader._apply_synthesis(data) + + for traces in result.values(): + for trace in traces: + assert isinstance(trace, BailianTrace) + + def test_empty_input(self, mock_prompt_generator): + user_config = _make_synthesis_config(speedup_ratio=2.0) + loader = BailianTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + assert loader._apply_synthesis({}) == {} + + def test_end_to_end_with_synthesis(self, create_jsonl_file, mock_prompt_generator): + content = [ + '{"chat_id": 1, "timestamp": 1.0, "input_length": 16, "output_length": 10, "hash_ids": [1]}', + '{"chat_id": 2, "timestamp": 2.0, "input_length": 16, "output_length": 10, "hash_ids": [2]}', + ] + filename = create_jsonl_file(content) + + user_config = _make_synthesis_config(speedup_ratio=2.0) + loader = BailianTraceDatasetLoader( + filename=filename, + user_config=user_config, + prompt_generator=mock_prompt_generator, + ) + dataset = loader.load_dataset() + + traces = [t for ts in dataset.values() for t in ts] + # Timestamps: 1.0s → 1000ms, 2.0s → 2000ms, then /2 = 500ms, 1000ms + assert traces[0].timestamp == 500.0 + assert traces[1].timestamp == 1000.0 diff --git a/tests/unit/dataset/loader/test_trace.py b/tests/unit/dataset/loader/test_trace.py index 2943a6ba1..d03922974 100644 --- a/tests/unit/dataset/loader/test_trace.py +++ b/tests/unit/dataset/loader/test_trace.py @@ -356,7 +356,7 @@ def test_load_dataset_logs_skipped_traces( # Check that the skipped traces message is logged assert f"Skipped {expected_skipped:,} traces" in caplog.text - @patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_convert_to_conversations( self, mock_parallel_decode, mock_prompt_generator, default_user_config ): @@ -826,7 +826,7 @@ def user_config_for_reproducibility(self): ), ) - 
@patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_mooncake_flow_reproducibility_with_same_seed( self, mock_parallel_decode, mock_tokenizer_cls, user_config_for_reproducibility ): @@ -913,7 +913,7 @@ def deterministic_decode(token_sequences, tokenizer_name): f"First run: {prompts1}, Second run: {prompts2}" ) - @patch("aiperf.dataset.loader.mooncake_trace.parallel_decode") + @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_parallel_decode_length_mismatch_raises( self, mock_parallel_decode, mock_prompt_generator, default_user_config ): From 66fd6899ef5467bed70b8bb7f15be108be219fc4 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Thu, 26 Feb 2026 21:04:53 -0800 Subject: [PATCH 2/7] fix: Address code review feedback for trace loader PR - Remove .value on enum in error message (use string-based enum directly) - Validate mean is not None before cached prompt generation - Add cycle detection in Bailian find_root to prevent infinite loops - Reset filtering counters per load_dataset() call to avoid over-reporting Signed-off-by: Anthony Casagrande --- src/aiperf/dataset/composer/custom.py | 2 +- src/aiperf/dataset/generator/prompt.py | 2 ++ src/aiperf/dataset/loader/bailian_trace.py | 4 ++++ src/aiperf/dataset/loader/base_trace_loader.py | 3 +++ 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 33fb2657e..fd2ee73d2 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -183,7 +183,7 @@ def _validate_synthesis_config(self, dataset_type: CustomDatasetType) -> None: f"Synthesis options (--synthesis-speedup-ratio, --synthesis-prefix-len-multiplier, " f"--synthesis-prefix-root-multiplier, --synthesis-prompt-len-multiplier) " f"are only supported with trace datasets, " - f"but got {dataset_type.value}. " + f"but got {dataset_type}. " f"Either remove synthesis options or use a trace dataset type." ) diff --git a/src/aiperf/dataset/generator/prompt.py b/src/aiperf/dataset/generator/prompt.py index 78da6bf70..9c76d4a32 100644 --- a/src/aiperf/dataset/generator/prompt.py +++ b/src/aiperf/dataset/generator/prompt.py @@ -167,6 +167,8 @@ def generate( A synthetic prompt as a string. """ if hash_ids: + if mean is None: + raise ValueError("mean must be provided when hash_ids is set.") block_size = ( self.config.input_tokens.block_size or InputTokensDefaults.BLOCK_SIZE ) diff --git a/src/aiperf/dataset/loader/bailian_trace.py b/src/aiperf/dataset/loader/bailian_trace.py index b18f57d24..b3e4012e8 100644 --- a/src/aiperf/dataset/loader/bailian_trace.py +++ b/src/aiperf/dataset/loader/bailian_trace.py @@ -89,8 +89,12 @@ def find_root(chat_id: int) -> int: return root_cache[chat_id] path: list[int] = [] + seen: set[int] = set() current = chat_id while current in by_chat_id and by_chat_id[current].parent_chat_id != -1: + if current in seen: + break + seen.add(current) parent = by_chat_id[current].parent_chat_id if parent == current or parent not in by_chat_id: break diff --git a/src/aiperf/dataset/loader/base_trace_loader.py b/src/aiperf/dataset/loader/base_trace_loader.py index 58efb03a7..7f9a602b3 100644 --- a/src/aiperf/dataset/loader/base_trace_loader.py +++ b/src/aiperf/dataset/loader/base_trace_loader.py @@ -167,6 +167,9 @@ def load_dataset(self) -> dict[str, list[TraceT]]: :meth:`_parse_trace`, :meth:`_preprocess_trace`, and :meth:`_group_traces`. 
""" + self._skipped_traces = 0 + self._skipped_max_isl = 0 + self._capped_max_osl = 0 items: list[TraceT] = [] with open(self.filename) as f: From c9ad63eaeb7865b6e5218d8bc834b9cbe52dc909 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Fri, 27 Feb 2026 18:17:52 -0800 Subject: [PATCH 3/7] fix: support tokenizers with non-standard kwargs (e.g. Kimi) Tokenizers like Kimi use `allow_special_tokens` instead of the standard `add_special_tokens` for encode, and their `decode()` doesn't accept `skip_special_tokens`. Passing unsupported kwargs triggers the slow `PreTrainedTokenizer.super()` fallback path, causing ~5000x slower decode (~204ms vs 0.04ms per 4500 tokens). After loading, inspect the tokenizer's method signatures and override the default call/encode/decode args to match. Signed-off-by: Anthony Casagrande --- src/aiperf/common/tokenizer.py | 21 ++ .../common/test_tokenizer_kwarg_overrides.py | 248 ++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 tests/unit/common/test_tokenizer_kwarg_overrides.py diff --git a/src/aiperf/common/tokenizer.py b/src/aiperf/common/tokenizer.py index 51a07975a..fb553a560 100644 --- a/src/aiperf/common/tokenizer.py +++ b/src/aiperf/common/tokenizer.py @@ -4,6 +4,7 @@ """HuggingFace tokenizer wrapper with sensible defaults.""" import contextlib +import inspect import io import logging import os @@ -47,6 +48,14 @@ def __init__(self, name: str, suggestions: list[tuple[str, int]]) -> None: ) +def _supports_kwarg(obj: object, method_name: str, kwarg: str) -> bool: + """Check if a method on an object accepts a specific keyword argument.""" + method = getattr(obj, method_name, None) + if method is None: + return False + return kwarg in inspect.signature(method).parameters + + def _is_offline_mode() -> bool: """Check if HuggingFace offline mode is enabled via environment variables.""" return bool(os.environ.get("HF_HUB_OFFLINE", "")) or bool( @@ -147,6 +156,16 @@ def _require_init(self) -> None: if self._tokenizer is None: raise NotInitializedError("Tokenizer is not initialized.") + def _apply_kwarg_overrides(self) -> None: + """Override default args for tokenizers that use non-standard kwargs (e.g. Kimi).""" + if self._tokenizer is None: + return + if _supports_kwarg(self._tokenizer, "encode", "allow_special_tokens"): + self._call_args = {"allow_special_tokens": False} + self._encode_args = {"allow_special_tokens": False} + if not _supports_kwarg(self._tokenizer, "decode", "skip_special_tokens"): + self._decode_args = {} + @staticmethod def resolve_alias(name: str) -> AliasResolutionResult: """Resolve a tokenizer name alias to its canonical repository ID.""" @@ -208,6 +227,7 @@ def from_pretrained( revision=revision, ) tokenizer_cls._resolved_name = resolved_name + tokenizer_cls._apply_kwarg_overrides() except AmbiguousTokenizerNameError: raise except Exception as e: @@ -285,6 +305,7 @@ class _OfflineModelInfo: revision=revision, local_files_only=True, ) + tokenizer_cls._apply_kwarg_overrides() return tokenizer_cls finally: huggingface_hub.model_info = _original_model_info diff --git a/tests/unit/common/test_tokenizer_kwarg_overrides.py b/tests/unit/common/test_tokenizer_kwarg_overrides.py new file mode 100644 index 000000000..819070d67 --- /dev/null +++ b/tests/unit/common/test_tokenizer_kwarg_overrides.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for tokenizer kwarg override detection and application. 
+ +Tokenizers like Kimi use non-standard kwargs (e.g. `allow_special_tokens` +instead of `add_special_tokens`). Passing unsupported kwargs triggers the +slow `PreTrainedTokenizer.super()` fallback. These tests verify that +`_supports_kwarg` and `_apply_kwarg_overrides` correctly detect and adapt. +""" + +import pytest + +from aiperf.common.tokenizer import Tokenizer, _supports_kwarg + +# -- Fake tokenizer backends for testing -- + + +class StandardTokenizerBackend: + """Mimics a standard HuggingFace tokenizer (e.g. Qwen).""" + + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> list[int]: + return list(range(len(text.split()))) + + def decode( + self, token_ids: list[int], skip_special_tokens: bool = False, **kwargs + ) -> str: + return " ".join(f"t{i}" for i in token_ids) + + def __call__(self, text: str, add_special_tokens: bool = True, **kwargs) -> dict: + return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)} + + bos_token_id = 1 + eos_token_id = 2 + + +class KimiLikeTokenizerBackend: + """Mimics Kimi's TikTokenTokenizer: uses allow_special_tokens, no skip_special_tokens.""" + + def encode( + self, text: str, allow_special_tokens: bool = True, **kwargs + ) -> list[int]: + if kwargs: + raise TypeError( + f"Unexpected kwargs would trigger slow super().encode: {kwargs}" + ) + return list(range(len(text.split()))) + + def decode(self, token_ids: list[int] | int, **kwargs) -> str: + if kwargs: + raise TypeError( + f"Unexpected kwargs would trigger slow super().decode: {kwargs}" + ) + if isinstance(token_ids, int): + token_ids = [token_ids] + return " ".join(f"t{i}" for i in token_ids) + + def __call__(self, text: str, allow_special_tokens: bool = True, **kwargs) -> dict: + return { + "input_ids": self.encode(text, allow_special_tokens=allow_special_tokens) + } + + bos_token_id = 1 + eos_token_id = 2 + + +class MinimalDecodeTokenizerBackend: + """Tokenizer with standard encode but minimal decode (no skip_special_tokens).""" + + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> list[int]: + return list(range(len(text.split()))) + + def decode(self, token_ids: list[int], **kwargs) -> str: + if kwargs: + raise TypeError(f"Unexpected kwargs: {kwargs}") + return " ".join(f"t{i}" for i in token_ids) + + bos_token_id = 1 + eos_token_id = 2 + + +class KwargsOnlyTokenizerBackend: + """Tokenizer that only accepts **kwargs (no named params beyond self/text).""" + + def encode(self, text, **kwargs): + return [0] + + def decode(self, token_ids, **kwargs): + return "decoded" + + bos_token_id = 0 + eos_token_id = 0 + + +# -- _supports_kwarg tests -- + + +class TestSupportsKwarg: + def test_detects_named_param(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is True + + def test_rejects_missing_param(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is False + + def test_detects_allow_special_tokens(self): + backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is True + + def test_rejects_add_special_tokens_on_kimi(self): + backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is False + + def test_detects_skip_special_tokens_on_standard(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "decode", "skip_special_tokens") is True + + def test_rejects_skip_special_tokens_on_kimi(self): + 
backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "decode", "skip_special_tokens") is False + + def test_missing_method_returns_false(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "nonexistent_method", "anything") is False + + def test_kwargs_only_method(self): + backend = KwargsOnlyTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is False + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is False + + @pytest.mark.parametrize( + ("method", "kwarg", "expected"), + [ + ("encode", "text", True), + ("encode", "add_special_tokens", True), + ("decode", "token_ids", True), + ("decode", "skip_special_tokens", True), + ("encode", "nonexistent", False), + ], + ) + def test_standard_backend_parametrized(self, method, kwarg, expected): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, method, kwarg) is expected + + +# -- _apply_kwarg_overrides tests -- + + +class TestApplyKwargOverrides: + @staticmethod + def _make_tokenizer(backend) -> Tokenizer: + tok = Tokenizer() + tok._tokenizer = backend + tok._apply_kwarg_overrides() + return tok + + def test_standard_tokenizer_keeps_defaults(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + assert tok._encode_args == {"add_special_tokens": False} + assert tok._call_args == {"add_special_tokens": False} + assert tok._decode_args == {"skip_special_tokens": True} + + def test_kimi_like_overrides_encode_and_call_args(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + assert tok._encode_args == {"allow_special_tokens": False} + assert tok._call_args == {"allow_special_tokens": False} + + def test_kimi_like_clears_decode_args(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + assert tok._decode_args == {} + + def test_minimal_decode_clears_decode_args(self): + tok = self._make_tokenizer(MinimalDecodeTokenizerBackend()) + assert tok._encode_args == {"add_special_tokens": False} + assert tok._decode_args == {} + + def test_none_tokenizer_is_noop(self): + tok = Tokenizer() + tok._apply_kwarg_overrides() + assert tok._encode_args == {"add_special_tokens": False} + assert tok._call_args == {"add_special_tokens": False} + assert tok._decode_args == {"skip_special_tokens": True} + + +# -- End-to-end: encode/decode through Tokenizer wrapper -- + + +class TestKwargOverridesEndToEnd: + @staticmethod + def _make_tokenizer(backend) -> Tokenizer: + tok = Tokenizer() + tok._tokenizer = backend + tok._apply_kwarg_overrides() + return tok + + def test_standard_encode_passes_correct_kwargs(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.encode("hello world") + assert isinstance(result, list) + + def test_standard_decode_passes_correct_kwargs(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.decode([0, 1, 2]) + assert isinstance(result, str) + + def test_kimi_encode_does_not_raise(self): + """Kimi backend raises TypeError if unexpected kwargs are passed.""" + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.encode("hello world") + assert isinstance(result, list) + + def test_kimi_decode_does_not_raise(self): + """Kimi backend raises TypeError if unexpected kwargs are passed.""" + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.decode([0, 1, 2]) + assert isinstance(result, str) + + def test_kimi_call_does_not_raise(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok("hello world") + 
assert "input_ids" in result + + def test_standard_encode_without_override_would_fail_on_kimi(self): + """Verify that without overrides, Kimi backend rejects add_special_tokens.""" + tok = Tokenizer() + tok._tokenizer = KimiLikeTokenizerBackend() + # Don't call _apply_kwarg_overrides - defaults still have add_special_tokens + with pytest.raises(TypeError, match="Unexpected kwargs"): + tok.encode("hello world") + + def test_standard_decode_without_override_would_fail_on_kimi(self): + """Verify that without overrides, Kimi backend rejects skip_special_tokens.""" + tok = Tokenizer() + tok._tokenizer = KimiLikeTokenizerBackend() + with pytest.raises(TypeError, match="Unexpected kwargs"): + tok.decode([0, 1, 2]) + + def test_user_kwargs_override_defaults(self): + """User-provided kwargs should override the defaults.""" + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.encode("hello", add_special_tokens=True) + assert isinstance(result, list) + + def test_kimi_user_kwargs_override_defaults(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.encode("hello", allow_special_tokens=True) + assert isinstance(result, list) From b65baa265a0dd8d4bcd7f02cd72bd488b4888016 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Fri, 27 Feb 2026 22:19:20 -0800 Subject: [PATCH 4/7] fix: Address code review feedback for Bailian trace loader Gracefully handle unrecognized 'type' field values during dataset type inference by skipping the explicit type shortcut and falling through to structural detection. This fixes Bailian traces (which use "type" for request type, not dataset type) auto-detecting correctly. Also updates CLI descriptions to reference both trace formats and adds tests for the type field fallback behavior. Signed-off-by: Anthony Casagrande --- docs/cli_options.md | 6 +- src/aiperf/common/config/input_config.py | 6 +- src/aiperf/dataset/composer/custom.py | 32 ++++----- .../dataset/loader/base_trace_loader.py | 1 + src/aiperf/dataset/loader/models.py | 3 +- tests/unit/dataset/loader/test_can_load.py | 66 +++++++++++++++++++ 6 files changed, 89 insertions(+), 25 deletions(-) diff --git a/docs/cli_options.md b/docs/cli_options.md index 767464382..414e9f174 100644 --- a/docs/cli_options.md +++ b/docs/cli_options.md @@ -258,11 +258,11 @@ Custom HTTP headers to include with every request. Specify as `Header:Value` pai #### `--input-file` `` -Path to file or directory containing benchmark dataset. Required when using `--custom-dataset-type`. Supported formats depend on dataset type: JSONL for `single_turn`/`multi_turn`, JSONL trace files for `mooncake_trace`, directories for `random_pool`. File is parsed according to `--custom-dataset-type` specification. +Path to file or directory containing benchmark dataset. Required when using `--custom-dataset-type`. Supported formats depend on dataset type: JSONL for `single_turn`/`multi_turn`, JSONL for `mooncake_trace`/`bailian_trace` (timestamped traces), directories for `random_pool`. File is parsed according to `--custom-dataset-type` specification. #### `--fixed-schedule` -Run requests according to timestamps specified in the input dataset. When enabled, AIPerf replays the exact timing pattern from the dataset. This mode is automatically enabled for `mooncake_trace` datasets. +Run requests according to timestamps specified in the input dataset. When enabled, AIPerf replays the exact timing pattern from the dataset. This mode is automatically enabled for trace datasets.
_Flag (no value required)_ #### `--fixed-schedule-auto-offset` @@ -292,7 +292,7 @@ Pre-configured public dataset to download and use for benchmarking (e.g., `share #### `--custom-dataset-type` `` -Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace` (timestamped trace files), `random_pool` (directory of reusable prompts). Requires `--input-file`. Mutually exclusive with `--public-dataset`. +Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), `mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts). Requires `--input-file`. Mutually exclusive with `--public-dataset`.
_Choices: [`bailian_trace`, `mooncake_trace`, `multi_turn`, `random_pool`, `single_turn`]_ #### `--dataset-sampling-strategy` `` diff --git a/src/aiperf/common/config/input_config.py b/src/aiperf/common/config/input_config.py index 162ecf846..4c861aea8 100644 --- a/src/aiperf/common/config/input_config.py +++ b/src/aiperf/common/config/input_config.py @@ -192,7 +192,7 @@ def validate_goodput(self) -> Self: Any, Field( description="Path to file or directory containing benchmark dataset. Required when using `--custom-dataset-type`. " - "Supported formats depend on dataset type: JSONL for `single_turn`/`multi_turn`, JSONL trace files for `mooncake_trace`, " + "Supported formats depend on dataset type: JSONL for `single_turn`/`multi_turn`, JSONL for `mooncake_trace`/`bailian_trace` (timestamped traces), " "directories for `random_pool`. File is parsed according to `--custom-dataset-type` specification.", ), BeforeValidator(parse_file), @@ -208,7 +208,7 @@ def validate_goodput(self) -> Self: bool, Field( description="Run requests according to timestamps specified in the input dataset. When enabled, AIPerf replays " - "the exact timing pattern from the dataset. This mode is automatically enabled for `mooncake_trace` datasets." + "the exact timing pattern from the dataset. This mode is automatically enabled for trace datasets." ), CLIParameter( name=( @@ -278,7 +278,7 @@ def validate_goodput(self) -> Self: Field( description="Format specification for custom dataset provided via `--input-file`. Determines parsing logic and expected file structure. " "Options: `single_turn` (JSONL with single exchanges), `multi_turn` (JSONL with conversation history), " - "`mooncake_trace` (timestamped trace files), `random_pool` (directory of reusable prompts). " + "`mooncake_trace`/`bailian_trace` (timestamped trace files), `random_pool` (directory of reusable prompts). " "Requires `--input-file`. Mutually exclusive with `--public-dataset`.", ), CLIParameter( diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index e916a468d..8306a41ce 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -107,25 +107,21 @@ def _infer_type( Raises: ValueError: If the type field is invalid or no loader can handle the data format """ - # Check for explicit type field first (most efficient) - if data is not None and "type" in data: - try: - # Try to convert the type string to enum - explicit_type = CustomDatasetType(data["type"]) - LoaderClass = plugins.get_class( - PluginType.CUSTOM_DATASET_LOADER, explicit_type - ) - if not LoaderClass.can_load(data, filename): - raise ValueError( - f"Explicit type field {explicit_type} specified, but loader {LoaderClass.__name__} " - "cannot handle the data format. Please specify --custom-dataset-type explicitly." - ) - self.info(f"Using explicit type field: {explicit_type}") - return explicit_type - except (ValueError, KeyError) as e: + # Check for explicit type field first (most efficient). + # Skip values that aren't known dataset types (e.g. Bailian's "type": "text" + # is a request type, not a dataset type) and fall through to structural detection. + if data is not None and data.get("type") in CustomDatasetType: + explicit_type = CustomDatasetType(data["type"]) + LoaderClass = plugins.get_class( + PluginType.CUSTOM_DATASET_LOADER, explicit_type + ) + if not LoaderClass.can_load(data, filename): raise ValueError( - f"Invalid type field value: {data['type']}. Please specify --custom-dataset-type explicitly." 
- ) from e + f"Explicit type field {explicit_type} specified, but loader {LoaderClass.__name__} " + "cannot handle the data format. Please specify --custom-dataset-type explicitly." + ) + self.info(f"Using explicit type field: {explicit_type}") + return explicit_type detected_type = None for entry, LoaderClass in plugins.iter_all(PluginType.CUSTOM_DATASET_LOADER): diff --git a/src/aiperf/dataset/loader/base_trace_loader.py b/src/aiperf/dataset/loader/base_trace_loader.py index 7f9a602b3..5acb14509 100644 --- a/src/aiperf/dataset/loader/base_trace_loader.py +++ b/src/aiperf/dataset/loader/base_trace_loader.py @@ -91,6 +91,7 @@ def _preprocess_trace(self, trace: TraceT) -> None: Called after parsing but before filtering. Default is a no-op. """ + pass @abstractmethod def _group_traces(self, items: list[TraceT]) -> dict[str, list[TraceT]]: diff --git a/src/aiperf/dataset/loader/models.py b/src/aiperf/dataset/loader/models.py index e3e8d6b4f..3517fffd7 100644 --- a/src/aiperf/dataset/loader/models.py +++ b/src/aiperf/dataset/loader/models.py @@ -263,7 +263,8 @@ class BailianTrace(AIPerfBaseModel): belong to the same session and are ordered by ``turn``. Important: Bailian traces use a block size of 16 tokens per salted SipHash - block. Set ``--isl-block-size 16`` when using this format. + block. Use ``--isl-block-size 16`` when using this format (this is set + automatically in CLI flows). Examples: - Root request: ``{"chat_id": 159, "parent_chat_id": -1, "timestamp": 61.114, "input_length": 521, "output_length": 132, "type": "text", "turn": 1, "hash_ids": [1089, 1090, 1091]}`` diff --git a/tests/unit/dataset/loader/test_can_load.py b/tests/unit/dataset/loader/test_can_load.py index 4c1b5c14b..42a19df54 100644 --- a/tests/unit/dataset/loader/test_can_load.py +++ b/tests/unit/dataset/loader/test_can_load.py @@ -6,6 +6,7 @@ import pytest from pytest import param +from aiperf.dataset.loader.bailian_trace import BailianTraceDatasetLoader from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader from aiperf.dataset.loader.multi_turn import MultiTurnDatasetLoader from aiperf.dataset.loader.random_pool import RandomPoolDatasetLoader @@ -167,6 +168,8 @@ class TestCustomDatasetComposerInferType: param({"input_length": 100, "output_length": 50}, None, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_input_length"), param({"type": "mooncake_trace", "input_length": 100}, None, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_explicit"), param({"text_input": "Hello"}, None, CustomDatasetType.MOONCAKE_TRACE, id="mooncake_text_input"), + param({"type": "bailian_trace", "chat_id": 1, "timestamp": 0.0, "input_length": 100, "output_length": 50}, None, CustomDatasetType.BAILIAN_TRACE, id="bailian_explicit"), + param({"chat_id": 1, "timestamp": 0.0, "input_length": 100, "output_length": 50, "type": "text"}, None, CustomDatasetType.BAILIAN_TRACE, id="bailian_structural_with_request_type"), ], ) # fmt: skip def test_infer_from_data( @@ -201,6 +204,15 @@ def test_infer_from_data_raises(self, create_user_config_and_composer, data): with pytest.raises(ValueError, match="No loader can handle"): composer._infer_type(data) + def test_infer_explicit_type_loader_rejects_raises( + self, create_user_config_and_composer + ): + """Test that a recognized type field with incompatible data raises ValueError.""" + _, composer = create_user_config_and_composer() + data = {"type": "single_turn", "input_length": 100} + with pytest.raises(ValueError, match="cannot handle the data format"): + composer._infer_type(data) + 
def test_infer_random_pool_with_directory(self, create_user_config_and_composer): """Test inferring RandomPool with directory path.""" _, composer = create_user_config_and_composer() @@ -357,3 +369,57 @@ def test_directory_path_uniquely_identifies_random_pool(self): assert SingleTurnDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip assert MultiTurnDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip assert MooncakeTraceDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip + assert BailianTraceDatasetLoader.can_load(data=None, filename=temp_path) is False # fmt: skip + + +class TestUnrecognizedTypeFieldFallback: + """Tests for graceful handling of unrecognized 'type' field values. + + Some trace formats (e.g. Bailian) include a 'type' field that represents + something other than the dataset type (e.g. request type: text/search/image). + The inference logic should fall back to structural detection instead of raising.""" + + def test_bailian_type_field_falls_through_to_structural_detection( + self, create_user_config_and_composer + ): + """Bailian data with type='text' should infer as bailian_trace, not raise.""" + _, composer = create_user_config_and_composer() + data = { + "chat_id": 159, + "parent_chat_id": -1, + "timestamp": 61.114, + "input_length": 521, + "output_length": 132, + "type": "text", + "turn": 1, + "hash_ids": [1089, 1090, 1091], + } + result = composer._infer_type(data) + assert result == CustomDatasetType.BAILIAN_TRACE + + @pytest.mark.parametrize( + "type_value", + [ + param("text", id="text"), + param("search", id="search"), + param("image", id="image"), + param("file", id="file"), + param("unknown_garbage", id="garbage"), + ], + ) # fmt: skip + def test_unrecognized_type_field_does_not_raise( + self, create_user_config_and_composer, type_value + ): + """Unrecognized type field values should not raise during inference.""" + _, composer = create_user_config_and_composer() + data = { + "chat_id": 1, + "parent_chat_id": -1, + "timestamp": 0.0, + "input_length": 100, + "output_length": 50, + "type": type_value, + "turn": 1, + } + result = composer._infer_type(data) + assert result == CustomDatasetType.BAILIAN_TRACE From 3617fdb9b67b7c2b123bfb34d479e627fa3f4a3a Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Fri, 27 Feb 2026 22:38:41 -0800 Subject: [PATCH 5/7] address code rabbit --- src/aiperf/plugin/schema/plugins.schema.json | 1 + src/aiperf/plugin/schema/schemas.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/aiperf/plugin/schema/plugins.schema.json b/src/aiperf/plugin/schema/plugins.schema.json index 258c3651c..e3c0e2fc8 100644 --- a/src/aiperf/plugin/schema/plugins.schema.json +++ b/src/aiperf/plugin/schema/plugins.schema.json @@ -533,6 +533,7 @@ "default_block_size": { "anyOf": [ { + "minimum": 1, "type": "integer" }, { diff --git a/src/aiperf/plugin/schema/schemas.py b/src/aiperf/plugin/schema/schemas.py index 09729bd2e..1e052f49a 100644 --- a/src/aiperf/plugin/schema/schemas.py +++ b/src/aiperf/plugin/schema/schemas.py @@ -337,6 +337,7 @@ class CustomDatasetLoaderMetadata(BaseModel): ) default_block_size: int | None = Field( default=None, + ge=1, description=( "Default token block size for hash-based prompt caching. " "Used when the user does not explicitly set --isl-block-size. 
" From a327d90a28219c42dfe621c649c75e2292e981f5 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Fri, 27 Feb 2026 22:48:22 -0800 Subject: [PATCH 6/7] feat: 2-3x faster trace dataset loading via HashIdRandomGenerator Add HashIdRandomGenerator that deterministically seeds per (trace_id, hash_id) pair, enabling parallel token generation across workers without lock contention or cache coordination. Extract parallel_convert module to leverage multiprocessing with shared-memory token corpus. Stream conversations through composers to the backing store instead of materializing the full dataset in memory. Signed-off-by: Anthony Casagrande --- src/aiperf/common/hash_id_random_generator.py | 77 ++ src/aiperf/dataset/composer/base.py | 83 +- src/aiperf/dataset/composer/custom.py | 19 +- src/aiperf/dataset/composer/synthetic.py | 21 +- .../dataset/composer/synthetic_rankings.py | 13 +- src/aiperf/dataset/dataset_manager.py | 56 +- .../dataset/generator/parallel_decode.py | 139 --- src/aiperf/dataset/generator/prompt.py | 125 +-- .../dataset/loader/base_trace_loader.py | 157 +-- src/aiperf/dataset/loader/parallel_convert.py | 286 ++++++ src/aiperf/dataset/protocols.py | 3 +- .../dataset/synthesis/rolling_hasher.py | 4 +- .../common/test_hash_id_random_generator.py | 277 ++++++ .../dataset/composer/test_base_composer.py | 38 +- .../dataset/composer/test_custom_composer.py | 24 +- .../composer/test_synthetic_composer.py | 42 +- .../test_synthetic_rankings_composer.py | 10 +- tests/unit/dataset/conftest.py | 55 +- .../dataset/generator/test_parallel_decode.py | 269 ----- .../generator/test_prompt_generator.py | 121 +-- .../unit/dataset/loader/test_bailian_trace.py | 71 +- .../dataset/loader/test_base_trace_loader.py | 605 +++++++++++ .../dataset/loader/test_parallel_convert.py | 938 ++++++++++++++++++ tests/unit/dataset/loader/test_trace.py | 81 +- tests/unit/dataset/test_dataset_manager.py | 20 +- .../test_dataset_manager_inputs_json.py | 5 +- 26 files changed, 2552 insertions(+), 987 deletions(-) create mode 100644 src/aiperf/common/hash_id_random_generator.py delete mode 100644 src/aiperf/dataset/generator/parallel_decode.py create mode 100644 src/aiperf/dataset/loader/parallel_convert.py create mode 100644 tests/unit/common/test_hash_id_random_generator.py delete mode 100644 tests/unit/dataset/generator/test_parallel_decode.py create mode 100644 tests/unit/dataset/loader/test_base_trace_loader.py create mode 100644 tests/unit/dataset/loader/test_parallel_convert.py diff --git a/src/aiperf/common/hash_id_random_generator.py b/src/aiperf/common/hash_id_random_generator.py new file mode 100644 index 000000000..d8a3d5380 --- /dev/null +++ b/src/aiperf/common/hash_id_random_generator.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Hash-ID-based random generator for parallel processing with reproducibility. + +Enables parallel processing of traces with hash_ids while maintaining +reproducibility. Each (trace_id, hash_id) pair produces a deterministic random +sequence regardless of worker count or processing order. + +Architecture: + Global Seed -> Base RNG -> (trace_id, hash_id) -> Deterministic tokens + +The trace_id (typically a content hash of the trace file) ensures that different +trace files with overlapping hash_id values produce different content, while the +same trace file always produces identical results. 
+""" + +import hashlib + +from aiperf.common.random_generator import RandomGenerator + +__all__ = ["HashIdRandomGenerator"] + + +class _DisabledNumpyRNG: + """Raises on any attribute access to prevent NumPy RNG usage.""" + + def __getattr__(self, name): + raise RuntimeError( + "HashIdRandomGenerator does not support NumPy RNG operations. " + "Use Python RNG methods (randrange, choice, etc.) instead." + ) + + +class HashIdRandomGenerator(RandomGenerator): + """RandomGenerator that re-seeds deterministically per (trace_id, hash_id). + + Designed for parallel processing where multiple workers need to generate + identical content for the same hash_id within a trace file. + + Thread Safety: + NOT thread-safe. Each worker process must have its own instance. + """ + + @classmethod + def from_base_rng(cls, base_rng: RandomGenerator) -> "HashIdRandomGenerator": + """Create from a base RandomGenerator (typically from rng.derive()).""" + base_seed = base_rng.seed or base_rng.randrange(0, 2**64) + return cls(base_seed, _internal=True) + + def __init__(self, base_seed: int, *, _internal: bool = False): + super().__init__(base_seed, _internal=_internal) + self._numpy_rng = _DisabledNumpyRNG() + self._trace_id: str = "" + + def set_trace_id(self, trace_id: str) -> None: + """Set trace identifier to scope hash_ids to a specific trace file. + + Args: + trace_id: Content hash or unique identifier for the trace file. + Different trace files must use different trace_ids. + """ + self._trace_id = trace_id + + def reseed_for_hash_id(self, hash_id: int) -> None: + """Re-seed RNG deterministically for a specific hash_id. + + After calling, all random operations use the derived seed until + the next reseed_for_hash_id call. + + Args: + hash_id: KV block hash ID from trace data. + """ + seed_bytes = hashlib.sha256( + f"{self.seed}:{self._trace_id}:{hash_id}".encode() + ).digest() + self._python_rng.seed(int.from_bytes(seed_bytes[:8], "big")) diff --git a/src/aiperf/dataset/composer/base.py b/src/aiperf/dataset/composer/base.py index 9b12bd8ef..3aa0b4cb6 100644 --- a/src/aiperf/dataset/composer/base.py +++ b/src/aiperf/dataset/composer/base.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterator from aiperf.common import random_generator as rng from aiperf.common.config import UserConfig @@ -41,12 +42,11 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None, **kwargs): self._turn_sequence_cache: dict[int, tuple[int, int]] = {} @abstractmethod - def create_dataset(self) -> list[Conversation]: - """ - Create a set of conversation objects from the given configuration. + def create_dataset(self) -> Iterator[Conversation]: + """Create conversation objects from the given configuration. - Returns: - list[Conversation]: A list of conversation objects. + Yields Conversation objects one at a time so callers can stream + them directly to the backing store without materializing the full list. """ ... @@ -151,63 +151,38 @@ def prefix_prompt_enabled(self) -> bool: and self.config.input.prompt.prefix_prompt.length > 0 ) - def _finalize_conversations(self, conversations: list[Conversation]) -> None: - """Finalize conversations by adding conversation-level context prompts. - - Injects shared system prompts and per-conversation user context prompts. - Note: Turn-level finalization (_finalize_turn) is handled by each composer - according to its needs (eager in synthetic, lazy in custom). 
- - Args: - conversations: List of conversations to finalize - """ - self._inject_context_prompts(conversations) + def _finalize_conversation( + self, conversation: Conversation, session_index: int + ) -> None: + """Inject context prompts into a single conversation. - def _inject_context_prompts(self, conversations: list[Conversation]) -> None: - """Inject shared system and user context prompts into conversations. - - Sets the system_message and context_message fields on Conversation objects, - which endpoint formatters will prepend to the first turn when creating payloads. + Sets the system_message and user_context_message fields, which + endpoint formatters prepend to the first turn when creating payloads. Args: - conversations: List of conversations to inject prompts into + conversation: Conversation to finalize. + session_index: Position of this conversation in the dataset + (used for per-session user context prompt generation). """ if self.prompt_generator is None: return config = self.config.input.prompt.prefix_prompt - has_shared_system = config.shared_system_prompt_length is not None - has_user_context = config.user_context_prompt_length is not None - if not (has_shared_system or has_user_context): - return + if config.shared_system_prompt_length is not None: + prompt = self._get_shared_system_prompt() + if prompt: + conversation.system_message = prompt - self.debug( - lambda: f"Injecting context prompts into {len(conversations)} conversations" - ) - - # Get shared system prompt once (same for all sessions) - shared_system_prompt = None - if has_shared_system: - shared_system_prompt = self.prompt_generator.get_shared_system_prompt() - - # Iterate through conversations and set conversation-level fields - for session_index, conversation in enumerate(conversations): - # Set shared system prompt - if shared_system_prompt: - conversation.system_message = shared_system_prompt - self.trace( - lambda conv=conversation: f"Set system_message on conversation {conv.session_id}" - ) + if config.user_context_prompt_length is not None: + conversation.user_context_message = ( + self.prompt_generator.generate_user_context_prompt(session_index) + ) - # Set user context prompt (unique per session) - if has_user_context: - user_context = self.prompt_generator.generate_user_context_prompt( - session_index - ) - conversation.user_context_message = user_context - self.trace( - lambda idx=session_index, - conv=conversation: f"Set user_context_message for session {idx} " - f"(conversation {conv.session_id})" - ) + def _get_shared_system_prompt(self) -> str | None: + """Return the shared system prompt, computing and caching on first call.""" + if not hasattr(self, "_shared_system_prompt_cache"): + self._shared_system_prompt_cache: str | None = ( + self.prompt_generator.get_shared_system_prompt() + ) + return self._shared_system_prompt_cache diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 8306a41ce..c452e4a31 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator from pathlib import Path from typing import Any @@ -19,11 +20,12 @@ class CustomDatasetComposer(BaseDatasetComposer): def __init__(self, config: UserConfig, tokenizer: Tokenizer | None): super().__init__(config, tokenizer) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Create 
conversations from a file or directory. - Returns: - list[Conversation]: A list of conversation objects. + Yields conversations one at a time, finalizing each inline so + the caller can stream them directly to the backing store without + materializing the full list. """ # TODO: (future) for K8s, we need to transfer file data from SC (across node) check_file_exists(self.config.input.file) @@ -44,14 +46,13 @@ def create_dataset(self) -> list[Conversation]: dataset = self.loader.load_dataset() conversations = self.loader.convert_to_conversations(dataset) - # Finalize all turns with metadata (custom datasets need this) - for conversation in conversations: + for session_index, conversation in enumerate(conversations): + # Finalize all turns with metadata (custom datasets need this) for turn in conversation.turns: self._finalize_turn(turn) - - # Finalize conversation-level context prompts - self._finalize_conversations(conversations) - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _infer_dataset_type(self, file_path: str) -> CustomDatasetType: """Infer the custom dataset type from the input file. diff --git a/src/aiperf/dataset/composer/synthetic.py b/src/aiperf/dataset/composer/synthetic.py index b8991aea2..6ce5895f2 100644 --- a/src/aiperf/dataset/composer/synthetic.py +++ b/src/aiperf/dataset/composer/synthetic.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator + from aiperf.common import random_generator as rng from aiperf.common.config import UserConfig from aiperf.common.config.config_defaults import InputDefaults @@ -39,17 +41,13 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None): "setting the mean to a positive value." ) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Create a synthetic conversation dataset from the given configuration. - It generates a set of conversations with a varying number of turns, - where each turn contains synthetic text, image, and audio payloads. - - Returns: - list[Conversation]: A list of conversation objects. + Yields conversations one at a time, where each conversation contains + a varying number of turns with synthetic text, image, and audio payloads. """ - conversations = [] - for _ in range(self.config.input.conversation.num_dataset_entries): + for session_index in range(self.config.input.conversation.num_dataset_entries): conversation = Conversation(session_id=self.session_id_generator.next()) num_turns = self._turn_sampler_rng.sample_positive_normal_integer( @@ -61,11 +59,10 @@ def create_dataset(self) -> list[Conversation]: for turn_idx in range(num_turns): turn = self._create_turn(is_first=(turn_idx == 0)) conversation.turns.append(turn) - conversations.append(conversation) - # Finalize all conversations (turn metadata + context prompts) - self._finalize_conversations(conversations) - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _create_turn(self, is_first: bool) -> Turn: """Create a turn object that contains synthetic payloads to send. 
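Because the composers above now yield conversations lazily, a caller can drain create_dataset() and hand each conversation to a sink without ever holding the full dataset. A minimal, self-contained sketch of that consumption pattern (the stand-in generator and sink are hypothetical; the real caller is the dataset manager's _configure_dataset loop later in this patch):

    from collections.abc import Callable, Iterator
    from types import SimpleNamespace

    def stream_to_store(conversations: Iterator, add: Callable[[str, object], None]) -> int:
        """Drain a lazy create_dataset() generator, handing each conversation to the
        sink as soon as it is produced, so only one conversation is resident at a time."""
        count = 0
        for conversation in conversations:
            add(conversation.session_id, conversation)
            count += 1
        return count

    # Stand-in objects so the sketch runs on its own.
    fake_conversations = (SimpleNamespace(session_id=f"session_{i}") for i in range(3))
    store: dict[str, object] = {}
    assert stream_to_store(fake_conversations, store.__setitem__) == 3
    assert sorted(store) == ["session_0", "session_1", "session_2"]
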
diff --git a/src/aiperf/dataset/composer/synthetic_rankings.py b/src/aiperf/dataset/composer/synthetic_rankings.py index ab90c9f07..c7acffb63 100644 --- a/src/aiperf/dataset/composer/synthetic_rankings.py +++ b/src/aiperf/dataset/composer/synthetic_rankings.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator + from aiperf.common import random_generator as rng from aiperf.common.config import InputDefaults, UserConfig from aiperf.common.models import Conversation, Text, Turn @@ -33,26 +35,25 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None): f"Using default sampling strategy for synthetic rankings dataset: {InputDefaults.DATASET_SAMPLING_STRATEGY}" ) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Generate synthetic dataset for the rankings endpoint. Each conversation contains one turn with one query and multiple passages. """ - conversations: list[Conversation] = [] num_entries = self.config.input.conversation.num_dataset_entries num_passages_mean = self.config.input.rankings.passages.mean num_passages_std = self.config.input.rankings.passages.stddev - for _ in range(num_entries): + for session_index in range(num_entries): num_passages = self._passages_rng.sample_positive_normal_integer( num_passages_mean, num_passages_std ) conversation = Conversation(session_id=self.session_id_generator.next()) turn = self._create_turn(num_passages=num_passages) conversation.turns.append(turn) - conversations.append(conversation) - - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _create_turn(self, num_passages: int) -> Turn: """Create a single ranking turn with one synthetic query and multiple synthetic passages. diff --git a/src/aiperf/dataset/dataset_manager.py b/src/aiperf/dataset/dataset_manager.py index e16339d0b..e9ac8ae26 100644 --- a/src/aiperf/dataset/dataset_manager.py +++ b/src/aiperf/dataset/dataset_manager.py @@ -5,6 +5,7 @@ import asyncio import gc import time +from collections.abc import Iterable from typing import TYPE_CHECKING import orjson @@ -31,6 +32,7 @@ from aiperf.common.mixins import ReplyClientMixin from aiperf.common.models import ( Conversation, + ConversationMetadata, DatasetClientMetadata, DatasetMetadata, InputsFile, @@ -86,11 +88,9 @@ def __init__( ) self.user_config = user_config self.tokenizer: Tokenizer | None = None - self.dataset: dict[ - str, Conversation - ] = {} # conversation ID -> Conversation mapping self.dataset_metadata: DatasetMetadata | None = None self._conversation_ids_cache: list[str] = [] + self._conversation_count: int = 0 self.dataset_configured = asyncio.Event() # In Kubernetes mode, use compress_only to stream directly to compressed files. 
@@ -132,7 +132,6 @@ async def _profile_configure_command( self.info(lambda: f"Configuring dataset for {self.service_id}") begin = time.perf_counter() await self._configure_dataset() - await self._generate_inputs_json_file() await self._configure_dataset_client_and_free_memory() duration = time.perf_counter() - begin @@ -140,8 +139,6 @@ async def _profile_configure_command( async def _configure_dataset_client_and_free_memory(self) -> None: """Configure the dataset client for serving fallback requests, then free memory.""" - conversation_count = len(self.dataset) - if not self._compress_only: client_metadata = self._backing_store.get_client_metadata() ClientStoreClass = plugins.get_class( @@ -149,25 +146,22 @@ async def _configure_dataset_client_and_free_memory(self) -> None: ) self._dataset_client = ClientStoreClass(client_metadata=client_metadata) await self._dataset_client.initialize() + await self._generate_inputs_json_file() self.dataset_configured.set() - # Reassign to new empty containers (not .clear()) to release object references, - # then run gc.collect() twice to ensure circular references are cleaned up. - self.dataset = {} - self._conversation_ids_cache = [] + # Run gc.collect() twice to ensure circular references are cleaned up. gc.collect() gc.collect() if self._compress_only: self.info( - f"Kubernetes mode: skipped local client, freed {conversation_count} " - "conversations from memory (workers handle all requests)" + f"Kubernetes mode: skipped local client, compressed {self._conversation_count} " + "conversations into backing store)" ) else: self.info( - f"Dataset client initialized and freed {conversation_count} " - "conversations from memory" + f"Dataset client initialized, {self._conversation_count} conversations in backing store" ) async def _configure_tokenizer(self) -> None: @@ -185,7 +179,7 @@ async def _configure_tokenizer(self) -> None: resolve_alias=tokenizer_config.should_resolve_alias, ) - def _generate_input_payloads( + async def _generate_input_payloads( self, model_endpoint: ModelEndpointInfo, ) -> InputsFile: @@ -201,7 +195,10 @@ def _generate_input_payloads( f"class: {endpoint.__class__.__name__}", ) session_payloads_map: dict[str, list] = {} - for conversation in self.dataset.values(): + for conv_metadata in self.dataset_metadata.conversations: + conversation = await self._dataset_client.get_conversation( + conv_metadata.conversation_id + ) session_id = conversation.session_id if session_id not in session_payloads_map: session_payloads_map[session_id] = [] @@ -245,7 +242,7 @@ async def _generate_inputs_json_file(self) -> None: file_path.parent.mkdir(parents=True, exist_ok=True) model_endpoint = ModelEndpointInfo.from_user_config(self.user_config) - inputs = self._generate_input_payloads(model_endpoint) + inputs = await self._generate_input_payloads(model_endpoint) temp_file_path.write_bytes( orjson.dumps( @@ -293,7 +290,7 @@ async def _load_public_dataset(self) -> list[Conversation]: ) return await loader.convert_to_conversations(dataset) - def _load_custom_dataset(self) -> list[Conversation]: + def _load_custom_dataset(self) -> Iterable[Conversation]: ComposerClass = plugins.get_class( PluginType.DATASET_COMPOSER, ComposerType.CUSTOM ) @@ -303,7 +300,7 @@ def _load_custom_dataset(self) -> list[Conversation]: def _is_rankings_endpoint(self, endpoint_type: str) -> bool: return "rankings" in endpoint_type.lower() - def _load_synthetic_dataset(self) -> list[Conversation]: + def _load_synthetic_dataset(self) -> Iterable[Conversation]: endpoint_type = 
self.user_config.endpoint.type if self._is_rankings_endpoint(endpoint_type): @@ -334,16 +331,17 @@ async def _configure_dataset(self) -> None: else: conversations = self._load_synthetic_dataset() - self.dataset = {conv.session_id: conv for conv in conversations} - self._conversation_ids_cache = [ - conversation.session_id for conversation in conversations - ] - - # Initialize backing store and stream conversations to mmap files - # Workers read directly from these files + # Stream conversations to backing store and collect metadata on the fly. + # Each conversation is written to mmap and then can be GC'd — we never + # hold the full dataset in memory. await self._backing_store.initialize() - conversations_dict = {conv.session_id: conv for conv in conversations} - await self._backing_store.add_conversations(conversations_dict) + metadata_list: list[ConversationMetadata] = [] + for conversation in conversations: + await self._backing_store.add_conversation( + conversation.session_id, conversation + ) + metadata_list.append(conversation.metadata()) + self._conversation_count = len(metadata_list) await self._backing_store.finalize() # In Kubernetes mode (compress_only=True), files are already compressed # during finalize(). In local mode, uncompressed files are used directly. @@ -362,7 +360,7 @@ async def _configure_dataset(self) -> None: ) self.dataset_metadata = DatasetMetadata( - conversations=[conversation.metadata() for conversation in conversations], + conversations=metadata_list, sampling_strategy=self.user_config.input.dataset_sampling_strategy, ) self.info( diff --git a/src/aiperf/dataset/generator/parallel_decode.py b/src/aiperf/dataset/generator/parallel_decode.py deleted file mode 100644 index f8558261c..000000000 --- a/src/aiperf/dataset/generator/parallel_decode.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Parallel decode utilities for batch tokenizer operations. - -This module provides functions to decode multiple token sequences in parallel -using ProcessPoolExecutor, bypassing Python's GIL for CPU-bound tokenizer -operations. - -The daemon flag on the current process is temporarily cleared because Python's -multiprocessing refuses to spawn children from daemon processes, and AIPerf -services run as daemons. -""" - -import multiprocessing as mp -import os -from concurrent.futures import ProcessPoolExecutor -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from aiperf.common.tokenizer import Tokenizer - -# Module-level tokenizer for worker processes (initialized once per worker) -_worker_tokenizer: "Tokenizer | None" = None -_worker_tokenizer_name: str | None = None - - -def _init_worker(tokenizer_name: str) -> None: - """Initialize tokenizer in worker process. - - This function is called once per worker process when the ProcessPoolExecutor - starts. It loads the tokenizer so subsequent decode calls don't need to reload it. - - Args: - tokenizer_name: Name or path of the pretrained tokenizer to load. - """ - global _worker_tokenizer, _worker_tokenizer_name - if _worker_tokenizer is None or _worker_tokenizer_name != tokenizer_name: - # The main process already downloaded and cached the tokenizer, so force - # offline mode to skip network requests and alias resolution. 
- os.environ["HF_HUB_OFFLINE"] = "1" - os.environ["TRANSFORMERS_OFFLINE"] = "1" - - from aiperf.common.tokenizer import Tokenizer - - _worker_tokenizer = Tokenizer.from_pretrained( - tokenizer_name, resolve_alias=False - ) - _worker_tokenizer_name = tokenizer_name - - -def _decode_tokens(token_ids: list[int]) -> str: - """Decode tokens using worker's tokenizer. - - Args: - token_ids: List of token IDs to decode. - - Returns: - Decoded string. - - Raises: - RuntimeError: If worker tokenizer is not initialized. - """ - if _worker_tokenizer is None: - raise RuntimeError("Worker tokenizer not initialized") - return _worker_tokenizer.decode(token_ids, skip_special_tokens=False) - - -def parallel_decode( - token_sequences: list[list[int]], - tokenizer_name: str, - max_workers: int | None = None, - chunksize: int = 50, -) -> list[str]: - """Decode multiple token sequences in parallel using ProcessPoolExecutor. - - This function is optimized for batch decoding of many token sequences. - For small batches (< 10 sequences), it falls back to sequential decoding - to avoid process spawn overhead. - - Args: - token_sequences: List of token ID lists to decode. - tokenizer_name: Name or path of the pretrained tokenizer to use in workers. - max_workers: Number of worker processes. Defaults to min(cpu_count, 8). - chunksize: Number of items per worker batch for map(). - - Returns: - List of decoded strings in the same order as input. - """ - if not token_sequences: - return [] - - # For small batches, sequential is faster (avoid process overhead) - if len(token_sequences) < 10: - from aiperf.common.tokenizer import Tokenizer - - tokenizer = Tokenizer.from_pretrained(tokenizer_name) - return [ - tokenizer.decode(tokens, skip_special_tokens=False) - for tokens in token_sequences - ] - - num_workers = max_workers or min(mp.cpu_count() or 4, 8) - - # Temporarily clear the daemon flag so ProcessPoolExecutor can spawn workers. - # Python's multiprocessing refuses to spawn children from daemon processes, - # and AIPerf services run as daemons. - # - # Alternatives considered: - # - billiard: bypasses the daemon restriction natively, but crashes with - # BrokenProcessPool on macOS due to terminal FD inheritance issues. - # - loky: robust reusable executor, but still requires the same daemon flag - # hack, so no advantage over stdlib. 
- was_daemon = mp.current_process().daemon - try: - if was_daemon: - _set_daemon(False) - with ProcessPoolExecutor( - max_workers=num_workers, - initializer=_init_worker, - initargs=(tokenizer_name,), - ) as executor: - results = list( - executor.map(_decode_tokens, token_sequences, chunksize=chunksize) - ) - finally: - if was_daemon: - _set_daemon(True) - - return results - - -def _set_daemon(daemon: bool) -> None: - """Set the daemon flag on the current process.""" - try: - mp.current_process().daemon = daemon - except AssertionError: - # Fallback to using the internal _config dictionary if assertions are enabled - mp.current_process()._config["daemon"] = daemon diff --git a/src/aiperf/dataset/generator/prompt.py b/src/aiperf/dataset/generator/prompt.py index 9c76d4a32..f3ada6133 100644 --- a/src/aiperf/dataset/generator/prompt.py +++ b/src/aiperf/dataset/generator/prompt.py @@ -13,12 +13,49 @@ InvalidStateError, NotInitializedError, ) +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator from aiperf.common.tokenizer import Tokenizer from aiperf.dataset.generator.base import BaseGenerator DEFAULT_CORPUS_FILE = "assets/shakespeare.txt" +def sample_tokens_from_corpus( + corpus: list[int], + num_tokens: int, + rng_to_use: rng.RandomGenerator, + sep_token: int | None = None, +) -> list[int]: + """Sample tokens from a corpus with optional separator token. + + Args: + corpus: Token corpus as a list of token IDs. + num_tokens: Number of tokens to sample. + rng_to_use: RandomGenerator for sampling start position. + sep_token: Optional separator token to prepend (BOS/EOS). + + Returns: + List of sampled token IDs. + """ + corpus_len = len(corpus) + tokens: list[int] = [] + + if sep_token is not None: + tokens.append(sep_token) + num_tokens -= 1 + + start = rng_to_use.randrange(corpus_len) + end = start + num_tokens + + if end <= corpus_len: + tokens.extend(corpus[start:end]) + else: + tokens.extend(corpus[start:]) + tokens.extend(corpus[: end - corpus_len]) + + return tokens + + class PromptGenerator(BaseGenerator): """A class for generating synthetic prompts from a text corpus. @@ -47,15 +84,16 @@ def __init__(self, config: PromptConfig, tokenizer: Tokenizer, **kwargs): self._corpus_rng = rng.derive("dataset.prompt.corpus") self._prefix_rng = rng.derive("dataset.prompt.prefix") + # Hash-ID-based RNG for deterministic per-hash_id generation. + # Re-seeds itself for each hash_id, enabling identical random + # sequences per hash block regardless of processing order or workers. + self._hash_id_corpus_rng = HashIdRandomGenerator.from_base_rng(self._corpus_rng) + super().__init__(config=config, tokenizer=tokenizer, **kwargs) # Cached prompts: block ID -> list of tokens self._cache: dict[int, list[int]] = {} - # Decoded string cache: (hash_ids tuple, num_tokens, block_size) -> decoded string - # This avoids redundant tokenizer.decode() calls for repeated hash_id combinations - self._decoded_cache: dict[tuple[tuple[int, ...], int, int], str] = {} - # TODO: move this under initialize() method # Initialize corpus if not already done if self._tokenized_corpus is None: @@ -154,6 +192,7 @@ def generate( mean: int | None = None, stddev: int | None = None, hash_ids: list[int] | None = None, + block_size: int | None = None, ) -> str: """Generate a synthetic prompt with the configuration parameters. Serves as a wrapper around other internal methods to provide a unified interface. @@ -162,6 +201,8 @@ def generate( mean: The mean of the normal distribution. 
stddev: The standard deviation of the normal distribution. hash_ids: A list of hash indices used for token reuse. + block_size: Override block size for hash-ID generation. + Defaults to config value or :data:`InputTokensDefaults.BLOCK_SIZE`. Returns: A synthetic prompt as a string. @@ -169,10 +210,12 @@ def generate( if hash_ids: if mean is None: raise ValueError("mean must be provided when hash_ids is set.") - block_size = ( - self.config.input_tokens.block_size or InputTokensDefaults.BLOCK_SIZE + effective_block_size = ( + block_size + or self.config.input_tokens.block_size + or InputTokensDefaults.BLOCK_SIZE ) - return self._generate_cached_prompt(mean, hash_ids, block_size) + return self._generate_cached_prompt(mean, hash_ids, effective_block_size) num_tokens = self.calculate_num_tokens(mean, stddev) return self.generate_prompt(num_tokens) @@ -208,48 +251,12 @@ def _generate_cached_prompt( hash_ids: list[int], block_size: int, ) -> str: - """ - Generate a prompt containing exactly `num_tokens` by reusing previously generated prompts - stored in `_cache`. Each hash index in `hash_ids` corresponds to a block of - `block_size` tokens. If a hash index is found in `_cache`, its stored prompt is reused. - Otherwise, a new prompt is generated using `generate_prompt()` and stored in `_cache`. + """Generate a prompt by reusing previously generated token blocks. - Args: - num_tokens: The number of tokens required in the prompt. - hash_ids: A list of hash IDs to use for token reuse. - block_size: The number of tokens allocated per hash block. - - Returns: - str: A synthetic prompt as a string. - - Raises: - ConfigurationError: If the input parameters are not compatible. - """ - # Check decoded string cache first to avoid redundant decode calls - cache_key = (tuple(hash_ids), num_tokens, block_size) - if cache_key in self._decoded_cache: - return self._decoded_cache[cache_key] - - # Build token sequence using _build_token_sequence (shared logic) - final_prompt = self._build_token_sequence(num_tokens, hash_ids, block_size) - - # Decode and cache the result - decoded = self.tokenizer.decode(final_prompt, skip_special_tokens=False) - self._decoded_cache[cache_key] = decoded - return decoded - - def _build_token_sequence( - self, - num_tokens: int, - hash_ids: list[int], - block_size: int, - ) -> list[int]: - """ - Build a token sequence without decoding. Used for batch parallel decode. - - Each hash index in `hash_ids` corresponds to a block of `block_size` tokens. - If a hash index is found in `_cache`, its stored tokens are reused. - Otherwise, new tokens are sampled and stored in `_cache`. + Each hash_id in `hash_ids` corresponds to a block of `block_size` tokens. + If a hash_id is found in `_cache`, its stored tokens are reused. Otherwise, + tokens are generated deterministically using HashIdRandomGenerator re-seeding + and stored in `_cache`. Args: num_tokens: The number of tokens required in the prompt. @@ -257,7 +264,7 @@ def _build_token_sequence( block_size: The number of tokens allocated per hash block. Returns: - list[int]: A list of token IDs. + str: A synthetic prompt as a string. Raises: ConfigurationError: If the input parameters are not compatible. @@ -280,21 +287,17 @@ def _build_token_sequence( current_block_size = final_block_size if hash_id not in self._cache: - # To ensure that the prompt doesn't merge chunks, we insert a BOS or EOS token - # at the beginning. Length is maintained and the prompt generates the expected - # number of tokens. 
If no BOS or EOS token is available, we don't insert one. - prompt_tokens: list[int] = [] - if self.tokenizer.block_separation_token_id is not None: - prompt_tokens += [self.tokenizer.block_separation_token_id] - prompt_tokens += self._sample_tokens(current_block_size - 1) - else: - prompt_tokens += self._sample_tokens(current_block_size) - - self._cache[hash_id] = prompt_tokens # store to cache + self._hash_id_corpus_rng.reseed_for_hash_id(hash_id) + self._cache[hash_id] = sample_tokens_from_corpus( + self._tokenized_corpus, + current_block_size, + self._hash_id_corpus_rng, + self.tokenizer.block_separation_token_id, + ) final_prompt.extend(self._cache[hash_id]) - return final_prompt + return self.tokenizer.decode(final_prompt, skip_special_tokens=False) def _sample_tokens(self, num_tokens: int) -> list[int]: """Generate a list of token IDs containing exactly `num_tokens` number of tokens diff --git a/src/aiperf/dataset/loader/base_trace_loader.py b/src/aiperf/dataset/loader/base_trace_loader.py index 5acb14509..0ad629fe1 100644 --- a/src/aiperf/dataset/loader/base_trace_loader.py +++ b/src/aiperf/dataset/loader/base_trace_loader.py @@ -1,29 +1,50 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import hashlib from abc import abstractmethod +from collections.abc import Iterator from typing import Any, Generic, TypeVar from aiperf.common.config.config_defaults import InputTokensDefaults from aiperf.common.config.user_config import UserConfig from aiperf.common.models import Conversation, Text, Turn -from aiperf.dataset.generator.parallel_decode import parallel_decode from aiperf.dataset.generator.prompt import PromptGenerator from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.loader.parallel_convert import parallel_convert from aiperf.dataset.synthesis.models import SynthesisParams from aiperf.dataset.synthesis.synthesizer import Synthesizer from aiperf.plugin.enums import DatasetSamplingStrategy TraceT = TypeVar("TraceT") +_MIN_TRACES_FOR_PARALLEL = 10 + + +def _compute_file_hash(filepath: str) -> str: + """Compute SHA256 hash of file content (first 16 hex chars). + + Falls back to hashing the filepath string if the file cannot be read. + """ + try: + hasher = hashlib.sha256() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest()[:16] + except (OSError, TypeError): + return hashlib.sha256(filepath.encode()).hexdigest()[:16] + class BaseTraceDatasetLoader(BaseFileLoader, Generic[TraceT]): """Base class for trace dataset loaders with hash_ids-based prompt generation. Provides common infrastructure for loading trace-format datasets (Mooncake, Bailian, etc.) including shared initialization, timestamp - filtering, 3-phase prompt generation with parallel decode, and - synthesis integration. + filtering, parallel prompt generation with deterministic per-hash_id + re-seeding, and synthesis integration. Subclasses must implement: - `can_load`: data format detection @@ -50,14 +71,7 @@ def __init__( self._end_offset = user_config.input.fixed_schedule_end_offset self._max_isl = user_config.input.synthesis.max_isl self._max_osl = user_config.input.synthesis.max_osl - - # Use the resolved tokenizer name so worker processes can load from cache - # without needing alias resolution or network access. 
- self._tokenizer_name = ( - prompt_generator.tokenizer.resolved_name - or user_config.tokenizer.name - or user_config.endpoint.model_names[0] - ) + self._trace_id: str = "" # Precedence: user CLI --isl-block-size > plugin metadata default > hardcoded fallback user_block_size = user_config.input.prompt.input_tokens.block_size @@ -171,6 +185,10 @@ def load_dataset(self) -> dict[str, list[TraceT]]: self._skipped_traces = 0 self._skipped_max_isl = 0 self._capped_max_osl = 0 + + self._trace_id = _compute_file_hash(self.filename) + self.prompt_generator._hash_id_corpus_rng.set_trace_id(self._trace_id) + self.debug(lambda: f"Trace ID: {self._trace_id} for {self.filename}") items: list[TraceT] = [] with open(self.filename) as f: @@ -202,7 +220,7 @@ def load_dataset(self) -> dict[str, list[TraceT]]: return data # ------------------------------------------------------------------ - # convert_to_conversations — 3-phase prompt generation + # convert_to_conversations # ------------------------------------------------------------------ def _get_text_input(self, trace: TraceT) -> str | None: @@ -227,77 +245,70 @@ def _build_turn(self, trace: TraceT, prompt: str) -> Turn: ) def convert_to_conversations( - self, data: dict[str, list[TraceT]] - ) -> list[Conversation]: - """Convert trace sessions to :class:`Conversation` objects. + self, + data: dict[str, list[TraceT]], + num_workers: int | None = None, + batch_size: int = 100, + ) -> Iterator[Conversation]: + """Convert trace sessions to conversations using parallel workers. - Uses a three-phase approach for optimal performance: + Uses multiprocessing Pool with shared memory for the token corpus. + Each worker gets its own HashIdRandomGenerator to produce deterministic + token sequences per hash_id regardless of worker count or order. - 1. Build token sequences, checking the string cache first. - 2. Batch parallel decode for all cache misses. - 3. Assemble final :class:`Conversation` objects. + Falls back to single-threaded conversion for small datasets. + + Yields: + Conversation objects in session order. 
""" - # Phase 1: Build token sequences and identify cache misses - pending_decodes: list[tuple[str, int, list[int], tuple]] = [] - conversations_data: dict[str, list[tuple[TraceT, str | None]]] = {} + sessions = list(data.items()) + if not sessions: + return + + total_traces = sum(len(traces) for _, traces in sessions) + if total_traces < _MIN_TRACES_FOR_PARALLEL: + yield from self._convert_single_threaded(sessions) + return + + pg = self.prompt_generator + serialized = [ + (sid, [t.model_dump() for t in traces]) # type: ignore[union-attr] + for sid, traces in sessions + ] + + yield from parallel_convert( + sessions=serialized, + tokenizer_name=pg.tokenizer.resolved_name, + corpus=pg._tokenized_corpus, + base_seed=pg._hash_id_corpus_rng.seed, + block_size=self._block_size, + sep_token=pg.tokenizer.block_separation_token_id, + trace_id=self._trace_id, + num_workers=num_workers, + batch_size=batch_size, + ) - for session_id, traces in data.items(): - conversations_data[session_id] = [] - for idx, trace in enumerate(traces): + def _convert_single_threaded( + self, sessions: list[tuple[str, list[TraceT]]] + ) -> Iterator[Conversation]: + """Fallback single-threaded conversion for small datasets.""" + for session_id, traces in sessions: + conversation = Conversation(session_id=session_id) + for trace in traces: text_input = self._get_text_input(trace) if text_input is not None: - conversations_data[session_id].append((trace, text_input)) - continue - - hash_ids: list[int] = getattr(trace, "hash_ids", None) or [] - input_length: int = getattr(trace, "input_length", 0) - - if hash_ids: - cache_key = ( - tuple(hash_ids), - input_length, - self._block_size, - ) - if cache_key in self.prompt_generator._decoded_cache: - prompt = self.prompt_generator._decoded_cache[cache_key] - conversations_data[session_id].append((trace, prompt)) - else: - tokens = self.prompt_generator._build_token_sequence( - input_length, hash_ids, self._block_size - ) - pending_decodes.append((session_id, idx, tokens, cache_key)) - conversations_data[session_id].append((trace, None)) + prompt = text_input else: + hash_ids: list[int] = getattr(trace, "hash_ids", None) or [] + input_length: int = getattr(trace, "input_length", 0) prompt = self.prompt_generator.generate( - mean=input_length, stddev=0, hash_ids=[] + mean=input_length, + stddev=0, + hash_ids=hash_ids, + block_size=self._block_size, ) - conversations_data[session_id].append((trace, prompt)) - - # Phase 2: Batch parallel decode for all cache misses - if pending_decodes: - self.debug( - lambda: f"Parallel decoding {len(pending_decodes)} prompts " - f"({len(data)} conversations)" - ) - token_sequences = [p[2] for p in pending_decodes] - decoded_prompts = parallel_decode(token_sequences, self._tokenizer_name) - - for (session_id, idx, _, cache_key), prompt in zip( - pending_decodes, decoded_prompts, strict=True - ): - self.prompt_generator._decoded_cache[cache_key] = prompt - trace, _ = conversations_data[session_id][idx] - conversations_data[session_id][idx] = (trace, prompt) - - # Phase 3: Build final conversation objects - conversations: list[Conversation] = [] - for session_id, trace_prompt_pairs in conversations_data.items(): - conversation = Conversation(session_id=session_id) - for trace, prompt in trace_prompt_pairs: conversation.turns.append(self._build_turn(trace, prompt)) - conversations.append(conversation) - - return conversations + yield conversation # ------------------------------------------------------------------ # Synthesis — shared orchestration with 
subclass hooks diff --git a/src/aiperf/dataset/loader/parallel_convert.py b/src/aiperf/dataset/loader/parallel_convert.py new file mode 100644 index 000000000..d8ac27fa4 --- /dev/null +++ b/src/aiperf/dataset/loader/parallel_convert.py @@ -0,0 +1,286 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Parallel conversion of trace sessions to conversations. + +Uses multiprocessing Pool with shared memory for the token corpus. Each worker +gets its own HashIdRandomGenerator to produce deterministic token sequences per +hash_id regardless of worker count or processing order. + +The daemon flag on the current process is temporarily cleared because Python's +multiprocessing refuses to spawn children from daemon processes, and AIPerf +services run as daemons. +""" + +from __future__ import annotations + +import multiprocessing as mp +import os +import sys +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field +from multiprocessing import Pool, shared_memory + +import numpy as np + +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator +from aiperf.common.models import Conversation, Text, Turn +from aiperf.common.tokenizer import Tokenizer + + +@dataclass(slots=True) +class _WorkerInitArgs: + """Arguments passed to each worker process via Pool initargs.""" + + shm_name: str + corpus_len: int + tokenizer_name: str + base_seed: int + block_size: int + sep_token: int | None + trace_id: str + + +@dataclass(slots=True) +class _WorkerState: + """Per-worker process state, initialized once via _init_worker.""" + + tokenizer: Tokenizer + corpus: np.ndarray + shm: shared_memory.SharedMemory # prevent GC from unmapping corpus buffer + hash_rng: HashIdRandomGenerator + block_size: int + sep_token: int | None + sample_tokens: Callable[..., list[int]] + block_cache: dict[int, list[int]] = field(default_factory=dict) + + +# Set once per worker process by _init_worker; read by _process_batch. +_worker_state: _WorkerState | None = None + + +def _init_worker(args: _WorkerInitArgs) -> None: + """Initialize worker process with shared corpus and tokenizer. + + Called once per worker when the Pool is created. Attaches to the + shared-memory corpus, creates a per-worker HashIdRandomGenerator + (seeded by trace_id for file-level determinism), and loads the + tokenizer from local cache (offline mode). + """ + global _worker_state + + from aiperf.dataset.generator.prompt import sample_tokens_from_corpus + + # The main process already downloaded and cached the tokenizer, so force + # offline mode to skip network requests and alias resolution. + os.environ["HF_HUB_OFFLINE"] = "1" + os.environ["TRANSFORMERS_OFFLINE"] = "1" + + shm = shared_memory.SharedMemory(name=args.shm_name) + + # Each worker gets its own RNG so reseed_for_hash_id calls are independent. + hash_rng = HashIdRandomGenerator(args.base_seed, _internal=True) + hash_rng.set_trace_id(args.trace_id) + + _worker_state = _WorkerState( + tokenizer=Tokenizer.from_pretrained(args.tokenizer_name, resolve_alias=False), + corpus=np.ndarray((args.corpus_len,), dtype=np.int32, buffer=shm.buf), + shm=shm, + hash_rng=hash_rng, + block_size=args.block_size, + sep_token=args.sep_token, + sample_tokens=sample_tokens_from_corpus, + ) + + +def _process_batch( + batch: list[tuple[str, list[dict]]], +) -> list[tuple[str, list[tuple]]]: + """Process a batch of sessions, converting hash_ids to prompts. 
+ + Each trace dict must have 'input_length', 'output_length', 'timestamp', + 'delay', and optionally 'hash_ids' and 'text_input'. + """ + assert _worker_state is not None + hash_rng = _worker_state.hash_rng + corpus = _worker_state.corpus + block_size = _worker_state.block_size + sep_token = _worker_state.sep_token + decode = _worker_state.tokenizer.decode + sample_tokens = _worker_state.sample_tokens + block_cache = _worker_state.block_cache + + def get_block_tokens(hash_id: int, size: int) -> list[int]: + if hash_id in block_cache: + return block_cache[hash_id] + hash_rng.reseed_for_hash_id(hash_id) + tokens = sample_tokens(corpus, size, hash_rng, sep_token) + block_cache[hash_id] = tokens + return tokens + + results = [] + for session_id, traces in batch: + turns = [] + for trace in traces: + if trace.get("text_input"): + # Literal prompt provided by the trace (no generation needed). + prompt = trace["text_input"] + elif trace.get("hash_ids"): + # Generate prompt from hash_id blocks. All blocks are full-sized + # except the last, which gets the remainder tokens. + hash_ids = trace["hash_ids"] + input_length = trace["input_length"] + final_block_size = input_length - (len(hash_ids) - 1) * block_size + + tokens: list[int] = [] + for i, hid in enumerate(hash_ids): + size = final_block_size if i == len(hash_ids) - 1 else block_size + tokens.extend(get_block_tokens(hid, size)) + prompt = decode(tokens, skip_special_tokens=False) + else: + prompt = "" + + turns.append( + ( + trace.get("timestamp"), + trace.get("delay"), + prompt, + trace.get("output_length"), + ) + ) + results.append((session_id, turns)) + + return results + + +def _has_broken_stdio() -> bool: + """Check if any stdio stream has an invalid file descriptor.""" + for stream in (sys.stdin, sys.stdout, sys.stderr): + try: + fd = stream.fileno() + if fd < 0: + return True + os.fstat(fd) + except (OSError, ValueError, AttributeError): + return True + return False + + +def _ensure_valid_stdio_fds() -> None: + """Redirect broken stdio to /dev/null before spawning Pool workers. + + Under the Textual terminal UI, child service processes inherit + Textual-managed sys.stdin/stdout/stderr objects whose fileno() may + return -1. When Pool workers fork and call util._close_stdin(), the + invalid FD propagates to _posixsubprocess.fork_exec causing + "bad value(s) in fds_to_keep". Only redirects when a problem is + detected so non-dashboard modes keep normal stdio. + """ + if not _has_broken_stdio(): + return + + devnull = os.open(os.devnull, os.O_RDWR) + for fd in (0, 1, 2): + os.dup2(devnull, fd) + if devnull > 2: + os.close(devnull) + sys.stdin = os.fdopen(0, "r", closefd=False) + sys.stdout = os.fdopen(1, "w", closefd=False) + sys.stderr = os.fdopen(2, "w", closefd=False) + + +def _set_daemon(daemon: bool) -> None: + """Set the daemon flag on the current process. + + Python's multiprocessing refuses to spawn children from daemon processes, + and AIPerf services run as daemons. This temporarily clears the flag. + """ + try: + mp.current_process().daemon = daemon + except AssertionError: + mp.current_process()._config["daemon"] = daemon + + +def parallel_convert( + sessions: list[tuple[str, list[dict]]], + *, + tokenizer_name: str, + corpus: list[int], + base_seed: int, + block_size: int, + sep_token: int | None, + trace_id: str, + num_workers: int | None = None, + batch_size: int = 100, +) -> Iterator[Conversation]: + """Convert trace sessions to conversations using parallel workers. 
+ + Yields Conversation objects one at a time as batches complete, using + ``pool.imap`` to preserve insertion order while avoiding materializing + all results in memory at once. + + Args: + sessions: List of (session_id, [trace_dict, ...]) tuples. + tokenizer_name: HuggingFace tokenizer name (already cached locally). + corpus: Tokenized corpus as a list of token IDs. + base_seed: Base seed for HashIdRandomGenerator. + block_size: Number of tokens per hash block. + sep_token: Optional separator token prepended to each block. + trace_id: File-derived trace ID for deterministic per-file seeding. + num_workers: Number of worker processes. Defaults to min(cpu_count, 16). + batch_size: Number of sessions per worker batch. + + Yields: + Conversation objects in the same order as the input sessions. + """ + _ensure_valid_stdio_fds() + + corpus_len = len(corpus) + shm = shared_memory.SharedMemory( + create=True, size=corpus_len * np.dtype(np.int32).itemsize + ) + + try: + np.ndarray((corpus_len,), dtype=np.int32, buffer=shm.buf)[:] = corpus + + batches = [ + sessions[i : i + batch_size] for i in range(0, len(sessions), batch_size) + ] + + workers = num_workers or min(os.cpu_count() or 4, 16) + + was_daemon = mp.current_process().daemon + try: + if was_daemon: + _set_daemon(False) + init_args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=corpus_len, + tokenizer_name=tokenizer_name, + base_seed=base_seed, + block_size=block_size, + sep_token=sep_token, + trace_id=trace_id, + ) + with Pool(workers, _init_worker, (init_args,)) as pool: + # imap preserves submission order (unlike imap_unordered) + for batch_result in pool.imap(_process_batch, batches): + for sid, turns in batch_result: + yield Conversation( + session_id=sid, + turns=[ + Turn( + timestamp=ts, + delay=delay, + texts=[Text(name="text", contents=[prompt])], + max_tokens=max_tokens, + ) + for ts, delay, prompt, max_tokens in turns + ], + ) + finally: + if was_daemon: + _set_daemon(True) + finally: + shm.close() + shm.unlink() diff --git a/src/aiperf/dataset/protocols.py b/src/aiperf/dataset/protocols.py index abf738328..a48c9f3fd 100644 --- a/src/aiperf/dataset/protocols.py +++ b/src/aiperf/dataset/protocols.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from aiperf.common.models import Conversation @@ -47,7 +48,7 @@ def load_dataset(self) -> dict[str, list["CustomDatasetT"]]: ... def convert_to_conversations( self, custom_data: dict[str, list["CustomDatasetT"]] - ) -> list[Conversation]: ... + ) -> Iterable[Conversation]: ... 
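With convert_to_conversations relaxed from list[Conversation] to Iterable[Conversation] in the protocol above, a loader can satisfy it with a plain generator. A toy sketch under that assumption (the Conversation/Turn/Text constructors mirror how the new parallel_convert module builds them; the function body is illustrative, not the shipped loader logic):

    from collections.abc import Iterable

    from aiperf.common.models import Conversation, Text, Turn

    def convert_to_conversations(data: dict[str, list[dict]]) -> Iterable[Conversation]:
        """Yield each Conversation as soon as its turns are built, so callers can
        stream results to the backing store instead of materializing a full list."""
        for session_id, traces in data.items():
            yield Conversation(
                session_id=session_id,
                turns=[
                    Turn(texts=[Text(name="text", contents=[t.get("text_input", "")])])
                    for t in traces
                ],
            )
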
@runtime_checkable diff --git a/src/aiperf/dataset/synthesis/rolling_hasher.py b/src/aiperf/dataset/synthesis/rolling_hasher.py index c1fc1ae28..5beb55008 100644 --- a/src/aiperf/dataset/synthesis/rolling_hasher.py +++ b/src/aiperf/dataset/synthesis/rolling_hasher.py @@ -225,7 +225,9 @@ def hashes_to_texts( if hash_ids: # Use PromptGenerator to generate text from hash_ids # This uses the Shakespeare corpus and caches blocks by hash_id - text = prompt_generator.generate(mean=input_len, hash_ids=hash_ids) + text = prompt_generator.generate( + mean=input_len, hash_ids=hash_ids, block_size=block_size + ) else: # No hash_ids - generate plain text of target length text = prompt_generator.generate(mean=input_len) diff --git a/tests/unit/common/test_hash_id_random_generator.py b/tests/unit/common/test_hash_id_random_generator.py new file mode 100644 index 000000000..02afaee6d --- /dev/null +++ b/tests/unit/common/test_hash_id_random_generator.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for HashIdRandomGenerator parallel processing with reproducibility.""" + +import pytest + +from aiperf.common import random_generator as rng +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator + + +class TestHashIdRandomGenerator: + """Test HashIdRandomGenerator for parallel processing reproducibility.""" + + @pytest.fixture(autouse=True) + def setup_rng(self): + """Initialize global RNG before each test.""" + rng.reset() + rng.init(42) + yield + rng.reset() + + def test_deterministic_seeding_per_hash_id(self): + """Test that the same hash_id always produces the same random sequence.""" + base_rng = rng.derive("test.base") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Generate values for hash_id=123 + hash_rng.reseed_for_hash_id(123) + values_1 = [hash_rng.randrange(1000) for _ in range(10)] + + # Generate values for hash_id=456 + hash_rng.reseed_for_hash_id(456) + values_2 = [hash_rng.randrange(1000) for _ in range(10)] + + # Generate values for hash_id=123 again (should match values_1) + hash_rng.reseed_for_hash_id(123) + values_3 = [hash_rng.randrange(1000) for _ in range(10)] + + # Same hash_id produces same sequence + assert values_1 == values_3 + # Different hash_ids produce different sequences + assert values_1 != values_2 + + def test_independence_across_workers(self): + """Test that different worker instances produce identical results for same hash_id.""" + # Simulate Worker 1 + base_rng_1 = rng.derive("worker.corpus") + hash_rng_1 = HashIdRandomGenerator.from_base_rng(base_rng_1) + + # Reset and re-init to simulate Worker 2 with same global seed + rng.reset() + rng.init(42) + + # Simulate Worker 2 + base_rng_2 = rng.derive("worker.corpus") + hash_rng_2 = HashIdRandomGenerator.from_base_rng(base_rng_2) + + # Both workers process hash_id=789 + hash_rng_1.reseed_for_hash_id(789) + worker1_values = [hash_rng_1.randrange(1000) for _ in range(10)] + + hash_rng_2.reseed_for_hash_id(789) + worker2_values = [hash_rng_2.randrange(1000) for _ in range(10)] + + # Workers produce identical results + assert worker1_values == worker2_values + + def test_order_independence(self): + """Test that processing order doesn't affect reproducibility.""" + base_rng = rng.derive("test.order") + + # Process in order: 100, 200, 300 + hash_rng_1 = HashIdRandomGenerator.from_base_rng(base_rng) + results_order1 = {} + for hash_id in [100, 200, 300]: + 
hash_rng_1.reseed_for_hash_id(hash_id) + results_order1[hash_id] = [hash_rng_1.randrange(1000) for _ in range(5)] + + # Process in different order: 300, 100, 200 + hash_rng_2 = HashIdRandomGenerator.from_base_rng(base_rng) + results_order2 = {} + for hash_id in [300, 100, 200]: + hash_rng_2.reseed_for_hash_id(hash_id) + results_order2[hash_id] = [hash_rng_2.randrange(1000) for _ in range(5)] + + # Same hash_id produces same results regardless of order + assert results_order1[100] == results_order2[100] + assert results_order1[200] == results_order2[200] + assert results_order1[300] == results_order2[300] + + def test_parallel_cache_simulation(self): + """Simulate parallel workers with individual caches.""" + # Worker 1 processes hash_ids: 1, 2, 3 + base_rng_w1 = rng.derive("worker.corpus") + hash_rng_w1 = HashIdRandomGenerator.from_base_rng(base_rng_w1) + cache_w1 = {} + + for hash_id in [1, 2, 3]: + hash_rng_w1.reseed_for_hash_id(hash_id) + cache_w1[hash_id] = [hash_rng_w1.randrange(1000) for _ in range(5)] + + # Reset for Worker 2 simulation + rng.reset() + rng.init(42) + + # Worker 2 processes hash_ids: 3, 4, 5 (overlapping hash_id=3) + base_rng_w2 = rng.derive("worker.corpus") + hash_rng_w2 = HashIdRandomGenerator.from_base_rng(base_rng_w2) + cache_w2 = {} + + for hash_id in [3, 4, 5]: + hash_rng_w2.reseed_for_hash_id(hash_id) + cache_w2[hash_id] = [hash_rng_w2.randrange(1000) for _ in range(5)] + + # Both workers produce identical results for hash_id=3 + assert cache_w1[3] == cache_w2[3] + + # Different hash_ids produce different results + assert cache_w1[1] != cache_w1[2] + assert cache_w2[4] != cache_w2[5] + + def test_non_deterministic_mode(self): + """Test that non-deterministic mode works (seed=None).""" + rng.reset() + rng.init(None) + + base_rng = rng.derive("test.nondeterministic") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Should not raise errors + hash_rng.reseed_for_hash_id(123) + values = [hash_rng.randrange(1000) for _ in range(10)] + + assert len(values) == 10 + # Ensure that an actual seed is created for the HashIdRandomGenerator + assert hash_rng.seed is not None + + def test_multiple_random_operations(self): + """Test various random operations after reseeding.""" + base_rng = rng.derive("test.operations") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Test for hash_id=555 + hash_rng.reseed_for_hash_id(555) + int_val = hash_rng.randrange(100, 200) + float_val = hash_rng.uniform(0.0, 1.0) + choice_val = hash_rng.choice([10, 20, 30, 40]) + + # Re-seed with same hash_id and verify reproducibility + hash_rng.reseed_for_hash_id(555) + int_val_2 = hash_rng.randrange(100, 200) + float_val_2 = hash_rng.uniform(0.0, 1.0) + choice_val_2 = hash_rng.choice([10, 20, 30, 40]) + + assert int_val == int_val_2 + assert float_val == float_val_2 + assert choice_val == choice_val_2 + + def test_hash_collision_independence(self): + """Test that different hash_ids produce independent sequences.""" + base_rng = rng.derive("test.collision") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_ids = [1, 12345, 999999, 7777777, 123456789] + sequences = {} + + for hash_id in hash_ids: + hash_rng.reseed_for_hash_id(hash_id) + sequences[hash_id] = [hash_rng.randrange(1000) for _ in range(20)] + + # All sequences should be different + unique_sequences = set(tuple(seq) for seq in sequences.values()) + assert len(unique_sequences) == len(hash_ids) + + def test_reseed_state_isolation(self): + """Test that reseeding properly isolates state between 
hash_ids.""" + base_rng = rng.derive("test.isolation") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Generate partial sequence for hash_id=111 + hash_rng.reseed_for_hash_id(111) + partial_seq_111 = [hash_rng.randrange(1000) for _ in range(3)] + + # Switch to hash_id=222 and generate values + hash_rng.reseed_for_hash_id(222) + _ = [hash_rng.randrange(1000) for _ in range(5)] + + # Return to hash_id=111 and continue - should continue from fresh state + hash_rng.reseed_for_hash_id(111) + full_seq_111 = [hash_rng.randrange(1000) for _ in range(10)] + + # The first 3 values should match the partial sequence + assert full_seq_111[:3] == partial_seq_111 + + def test_trace_id_isolation(self): + """Test that different trace_ids produce different sequences for same hash_id.""" + base_rng = rng.derive("test.trace_id") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.set_trace_id("trace_file_a") + hash_rng.reseed_for_hash_id(100) + values_a = [hash_rng.randrange(1000) for _ in range(10)] + + hash_rng.set_trace_id("trace_file_b") + hash_rng.reseed_for_hash_id(100) + values_b = [hash_rng.randrange(1000) for _ in range(10)] + + assert values_a != values_b + + +class TestHashIdRandomGeneratorEdgeCases: + """Test edge cases and error conditions.""" + + @pytest.fixture(autouse=True) + def setup_rng(self): + """Initialize global RNG before each test.""" + rng.reset() + rng.init(42) + yield + rng.reset() + + def test_zero_hash_id(self): + """Test with hash_id=0.""" + base_rng = rng.derive("test.zero") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.reseed_for_hash_id(0) + values = [hash_rng.randrange(1000) for _ in range(5)] + + assert len(values) == 5 + + def test_negative_hash_id(self): + """Test with negative hash_id.""" + base_rng = rng.derive("test.negative") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.reseed_for_hash_id(-123) + values = [hash_rng.randrange(1000) for _ in range(5)] + + # Should not produce the same as positive 123 + hash_rng.reseed_for_hash_id(123) + values_positive = [hash_rng.randrange(1000) for _ in range(5)] + + assert values != values_positive + + def test_large_hash_id(self): + """Test with very large hash_id.""" + base_rng = rng.derive("test.large") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + large_hash_id = 999999999999999999 + hash_rng.reseed_for_hash_id(large_hash_id) + values = [hash_rng.randrange(1000) for _ in range(5)] + + assert len(values) == 5 + + @pytest.mark.parametrize( + "operation", + [ + lambda rng: rng.integers(0, 100, size=2), + lambda rng: rng.random_batch(2), + lambda rng: rng.shuffle([1, 2, 3, 4, 5]), + lambda rng: rng.numpy_choice([1, 2, 3, 4, 5], size=2), + lambda rng: rng.normal(0, 1, size=2), + ], + ) + def test_numpy_rng_raises_exception(self, operation): + """Test that using NumPy RNG operations raises an exception.""" + base_rng = rng.derive("test.numpy") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + hash_rng.reseed_for_hash_id(123) + + with pytest.raises( + RuntimeError, match="HashIdRandomGenerator does not support NumPy RNG" + ): + operation(hash_rng) diff --git a/tests/unit/dataset/composer/test_base_composer.py b/tests/unit/dataset/composer/test_base_composer.py index a889be039..da8422733 100644 --- a/tests/unit/dataset/composer/test_base_composer.py +++ b/tests/unit/dataset/composer/test_base_composer.py @@ -245,10 +245,10 @@ def test_prefix_prompt_enabled_property(self, base_config, mock_tokenizer): composer2 = 
ConcreteBaseComposer(base_config, mock_tokenizer) assert composer2.prefix_prompt_enabled is False - def test_inject_context_prompts_with_shared_system_prompt( + def test_finalize_conversation_with_shared_system_prompt( self, base_config, mock_tokenizer ): - """Test _inject_context_prompts with shared system prompt.""" + """Test _finalize_conversation with shared system prompt.""" base_config.input.prompt.prefix_prompt.shared_system_prompt_length = 50 base_config.input.prompt.prefix_prompt.length = 0 base_config.input.conversation.num = 3 @@ -259,7 +259,6 @@ def test_inject_context_prompts_with_shared_system_prompt( ): composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -268,13 +267,13 @@ def test_inject_context_prompts_with_shared_system_prompt( Conversation(session_id="conv_2"), ] - # Mock the prompt generator method with patch.object( composer.prompt_generator, "get_shared_system_prompt", return_value="shared system prompt text", ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # All conversations should have the same system message assert conversations[0].system_message == "shared system prompt text" @@ -285,17 +284,16 @@ def test_inject_context_prompts_with_shared_system_prompt( assert conversations[1].user_context_message is None assert conversations[2].user_context_message is None - def test_inject_context_prompts_with_user_context_prompt( + def test_finalize_conversation_with_user_context_prompt( self, base_config, mock_tokenizer ): - """Test _inject_context_prompts with user context prompts.""" + """Test _finalize_conversation with user context prompts.""" base_config.input.prompt.prefix_prompt.user_context_prompt_length = 30 base_config.input.prompt.prefix_prompt.length = 0 base_config.input.conversation.num = 3 composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -304,7 +302,6 @@ def test_inject_context_prompts_with_user_context_prompt( Conversation(session_id="conv_2"), ] - # Mock the prompt generator method def mock_generate_user_context(index): return f"user context {index}" @@ -313,7 +310,8 @@ def mock_generate_user_context(index): "generate_user_context_prompt", side_effect=mock_generate_user_context, ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # Each conversation should have unique user context assert conversations[0].user_context_message == "user context 0" @@ -324,10 +322,8 @@ def mock_generate_user_context(index): assert conversations[1].system_message is None assert conversations[2].system_message is None - def test_inject_context_prompts_with_both_prompts( - self, base_config, mock_tokenizer - ): - """Test _inject_context_prompts with both shared system and user context prompts.""" + def test_finalize_conversation_with_both_prompts(self, base_config, mock_tokenizer): + """Test _finalize_conversation with both shared system and user context prompts.""" base_config.input.prompt.prefix_prompt.shared_system_prompt_length = 50 base_config.input.prompt.prefix_prompt.user_context_prompt_length = 30 base_config.input.prompt.prefix_prompt.length = 0 @@ -339,7 +335,6 @@ def test_inject_context_prompts_with_both_prompts( ): composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # 
Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -347,7 +342,6 @@ def test_inject_context_prompts_with_both_prompts( Conversation(session_id="conv_1"), ] - # Mock both prompt generator methods def mock_generate_user_context(index): return f"user context {index}" @@ -363,7 +357,8 @@ def mock_generate_user_context(index): side_effect=mock_generate_user_context, ), ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # Both conversations should have system message assert conversations[0].system_message == "shared system prompt" @@ -372,8 +367,8 @@ def mock_generate_user_context(index): assert conversations[0].user_context_message == "user context 0" assert conversations[1].user_context_message == "user context 1" - def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenizer): - """Test _inject_context_prompts when no context prompts are configured.""" + def test_finalize_conversation_with_no_prompts(self, base_config, mock_tokenizer): + """Test _finalize_conversation when no context prompts are configured.""" base_config.input.prompt.prefix_prompt.length = 0 base_config.input.prompt.prefix_prompt.shared_system_prompt_length = None base_config.input.prompt.prefix_prompt.user_context_prompt_length = None @@ -381,7 +376,6 @@ def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenize composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -389,8 +383,8 @@ def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenize Conversation(session_id="conv_1"), ] - # Should not call any prompt generator methods - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # No messages should be set assert conversations[0].system_message is None diff --git a/tests/unit/dataset/composer/test_custom_composer.py b/tests/unit/dataset/composer/test_custom_composer.py index 75ec844b5..f861df296 100644 --- a/tests/unit/dataset/composer/test_custom_composer.py +++ b/tests/unit/dataset/composer/test_custom_composer.py @@ -64,34 +64,26 @@ def test_create_loader_instance_dataset_types( composer._create_loader_instance(dataset_type) assert isinstance(composer.loader, expected_instance) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) - def test_create_dataset_trace( - self, mock_check_file, mock_parallel_decode, trace_config, mock_tokenizer - ): + def test_create_dataset_trace(self, mock_check_file, trace_config, mock_tokenizer): """Test that create_dataset returns correct type.""" - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] composer = CustomDatasetComposer(trace_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 3 assert all(isinstance(c, Conversation) for c in conversations) assert all(isinstance(turn, Turn) for c in conversations for turn in c.turns) assert all(len(turn.texts) == 1 for c in conversations for turn in c.turns) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", 
mock_open(read_data=MOCK_TRACE_CONTENT)) - def test_max_tokens_config( - self, mock_check_file, mock_parallel_decode, trace_config, mock_tokenizer - ): - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] + def test_max_tokens_config(self, mock_check_file, trace_config, mock_tokenizer): trace_config.input.prompt.output_tokens.mean = 120 trace_config.input.prompt.output_tokens.stddev = 8.0 composer = CustomDatasetComposer(trace_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) > 0 # With global RNG, verify max_tokens is set to a positive integer @@ -104,7 +96,6 @@ def test_max_tokens_config( # Should be roughly around the mean of 120 (within 3 stddev) assert 96 < turn.max_tokens < 144 - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) @patch("pathlib.Path.iterdir", return_value=[]) @@ -112,17 +103,15 @@ def test_max_tokens_mooncake( self, mock_iterdir, mock_check_file, - mock_parallel_decode, custom_config, mock_tokenizer, ): """Test that max_tokens can be set from the custom file""" - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] mock_check_file.return_value = None custom_config.input.custom_dataset_type = CustomDatasetType.MOONCAKE_TRACE composer = CustomDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: for turn in conversation.turns: @@ -151,9 +140,8 @@ def test_create_dataset_empty_result( mock_get_class.return_value = mock_loader_class composer = CustomDatasetComposer(custom_config, mock_tokenizer) - result = composer.create_dataset() + result = list(composer.create_dataset()) - assert isinstance(result, list) assert len(result) == 0 diff --git a/tests/unit/dataset/composer/test_synthetic_composer.py b/tests/unit/dataset/composer/test_synthetic_composer.py index 0c9e1ea85..c333bfe9c 100644 --- a/tests/unit/dataset/composer/test_synthetic_composer.py +++ b/tests/unit/dataset/composer/test_synthetic_composer.py @@ -97,7 +97,7 @@ def test_initialization_with_all_zero_mean(self, mock_tokenizer): def test_create_dataset_basic(self, synthetic_config, mock_tokenizer): """Test basic dataset creation with text-only conversations.""" composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test create_dataset returns correct number of conversations assert len(conversations) == 5 # num_conversations @@ -119,7 +119,7 @@ def test_create_dataset_basic(self, synthetic_config, mock_tokenizer): def test_create_dataset_with_images(self, image_config, mock_tokenizer): """Test dataset creation with image generation enabled.""" composer = SyntheticDatasetComposer(image_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations include image payloads assert len(conversations) == 3 @@ -140,7 +140,7 @@ def test_create_dataset_with_images(self, image_config, mock_tokenizer): def test_create_dataset_with_audio(self, audio_config, mock_tokenizer): """Test dataset creation with audio generation enabled.""" composer = SyntheticDatasetComposer(audio_config, mock_tokenizer) - conversations = composer.create_dataset() + 
conversations = list(composer.create_dataset()) # Test conversations include audio payloads assert len(conversations) == 3 @@ -160,7 +160,7 @@ def test_create_dataset_with_audio(self, audio_config, mock_tokenizer): def test_create_dataset_multimodal(self, multimodal_config, mock_tokenizer): """Test dataset creation with both image and audio enabled.""" composer = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations include both image and audio payloads assert ( @@ -183,7 +183,7 @@ def test_create_dataset_with_prefix_prompts( ): """Test dataset creation with prefix prompts enabled.""" composer = SyntheticDatasetComposer(prefix_prompt_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 5 for conversation in conversations: @@ -198,7 +198,7 @@ def test_create_dataset_with_prefix_prompts( def test_create_dataset_multiple_turns(self, multiturn_config, mock_tokenizer): """Test dataset creation with multiple turns and delays.""" composer = SyntheticDatasetComposer(multiturn_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations have multiple turns assert len(conversations) == 4 @@ -303,7 +303,7 @@ def test_turn_delays_from_config_options(self, mock_tokenizer): rng.reset() rng.init(42) # Set seed for reproducibility composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Verify conversations were created assert len(conversations) == 5 @@ -329,7 +329,7 @@ def test_turn_delays_from_config_options(self, mock_tokenizer): rng.reset() rng.init(42) # Reset seed composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: # First turn should still have no delay @@ -365,7 +365,7 @@ def test_turn_delays_with_zero_mean(self, mock_tokenizer): rng.reset() rng.init(42) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: # All turns should have None delay when mean=0 @@ -511,7 +511,7 @@ def test_zero_conversations(self, synthetic_config, mock_tokenizer): synthetic_config.input.conversation.num_dataset_entries = 0 composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 0 @@ -536,7 +536,7 @@ def test_edge_case_statistical_parameters(self, mock_tokenizer): ) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test with very small/large mean and stddev values assert len(conversations) == 2 @@ -555,7 +555,7 @@ def test_multi_turn_does_not_control_dataset_entries(self, mock_tokenizer): ) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Verify that num_dataset_entries controls the number of conversations generated assert len(conversations) == 10 @@ -568,7 +568,7 @@ def test_different_conversation_counts( 
synthetic_config.input.conversation.num_dataset_entries = num_conversations composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Parametrized test for different num_conversations values assert len(conversations) == num_conversations @@ -579,7 +579,7 @@ def test_different_batch_sizes(self, synthetic_config, batch_size, mock_tokenize synthetic_config.input.prompt.batch_size = batch_size composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Parametrized test for different batch_size values assert len(conversations) > 0 @@ -605,7 +605,7 @@ def test_missing_required_generators(self, synthetic_config, mock_tokenizer): with pytest.raises( ValueError, match="Text prompt generation requires a tokenizer" ): - composer.create_dataset() + list(composer.create_dataset()) def test_reproducibility_with_fixed_seed(self, multimodal_config, mock_tokenizer): """Test that dataset generation is reproducible with fixed random seed.""" @@ -620,12 +620,12 @@ def test_reproducibility_with_fixed_seed(self, multimodal_config, mock_tokenizer rng.reset() rng.init(42) composer1 = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations1 = composer1.create_dataset() + conversations1 = list(composer1.create_dataset()) rng.reset() rng.init(42) composer2 = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations2 = composer2.create_dataset() + conversations2 = list(composer2.create_dataset()) # Basic structure should be the same assert len(conversations1) == len(conversations2) @@ -653,7 +653,7 @@ def test_model_selection_random(self, custom_config, mock_tokenizer): custom_config.endpoint.model_names = ["test-model-1", "test-model-2"] composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # With random selection, verify models are from the valid set for conversation in conversations: @@ -665,7 +665,7 @@ def test_model_selection_round_robin(self, custom_config, mock_tokenizer): custom_config.endpoint.model_names = ["test-model-1", "test-model-2"] composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Check that models are selected in round-robin fashion for i, conversation in enumerate(conversations): @@ -682,7 +682,7 @@ def test_max_tokens_integration_with_mean(self, custom_config, mock_tokenizer): custom_config.input.prompt.output_tokens.stddev = 5.0 composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # With global RNG, verify max_tokens is set to a positive integer # around the mean of 100 @@ -699,7 +699,7 @@ def test_max_tokens_not_set_when_mean_none(self, custom_config, mock_tokenizer): custom_config.input.prompt.output_tokens.stddev = None composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: for turn in conversation.turns: diff --git a/tests/unit/dataset/composer/test_synthetic_rankings_composer.py b/tests/unit/dataset/composer/test_synthetic_rankings_composer.py index f86b64d86..422adf7c0 100644 --- 
a/tests/unit/dataset/composer/test_synthetic_rankings_composer.py +++ b/tests/unit/dataset/composer/test_synthetic_rankings_composer.py @@ -18,7 +18,7 @@ def test_create_dataset_structure(synthetic_config, mock_tokenizer): synthetic_config.input.rankings.passages.stddev = 1 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) assert len(dataset) == synthetic_config.input.conversation.num_dataset_entries for conv in dataset: @@ -42,7 +42,7 @@ def test_passage_count_distribution(synthetic_config, mock_tokenizer): synthetic_config.input.rankings.passages.stddev = 2 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) passage_counts = [len(conv.turns[0].texts[1].contents) for conv in dataset] assert all(1 <= c <= 10 for c in passage_counts) @@ -56,10 +56,10 @@ def test_reproducibility_fixed_seed(synthetic_config, mock_tokenizer): synthetic_config.input.random_seed = 42 composer1 = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - data1 = composer1.create_dataset() + data1 = list(composer1.create_dataset()) composer2 = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - data2 = composer2.create_dataset() + data2 = list(composer2.create_dataset()) # Session IDs differ (fresh), but text contents should match for c1, c2 in zip(data1, data2, strict=True): @@ -78,7 +78,7 @@ def test_rankings_specific_token_options(synthetic_config, mock_tokenizer): synthetic_config.input.random_seed = 42 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) # Verify that data was generated assert len(dataset) > 0 diff --git a/tests/unit/dataset/conftest.py b/tests/unit/dataset/conftest.py index d5fabd459..704eb8994 100644 --- a/tests/unit/dataset/conftest.py +++ b/tests/unit/dataset/conftest.py @@ -5,16 +5,16 @@ """ from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest import aiperf.endpoints # noqa: F401 # Import to register endpoints import aiperf.transports # noqa: F401 # Import to register transports from aiperf.common.config import EndpointConfig, OutputConfig, ServiceConfig, UserConfig -from aiperf.common.models import Conversation +from aiperf.common.models import Conversation, DatasetMetadata from aiperf.dataset.dataset_manager import DatasetManager -from aiperf.plugin.enums import EndpointType +from aiperf.plugin.enums import DatasetSamplingStrategy, EndpointType @pytest.fixture @@ -32,37 +32,64 @@ def user_config(tmp_path: Path) -> UserConfig: @pytest.fixture -def empty_dataset_manager( +def dataset_manager( user_config: UserConfig, ) -> DatasetManager: - """Create a DatasetManager instance with empty dataset.""" - manager = DatasetManager( + """Create a DatasetManager instance.""" + return DatasetManager( service_config=ServiceConfig(), user_config=user_config, service_id="test_dataset_manager", ) - manager.dataset = {} - return manager -@pytest.fixture -def populated_dataset_manager( +def _create_dataset_manager_with_client( user_config: UserConfig, - sample_conversations: dict[str, Conversation], + conversations: dict[str, Conversation], ) -> DatasetManager: - """Create a DatasetManager instance with sample data.""" + """Create a DatasetManager with a mock dataset client backed by 
conversations.""" manager = DatasetManager( service_config=ServiceConfig(), user_config=user_config, service_id="test_dataset_manager", ) - manager.dataset = sample_conversations + + async def mock_get_conversation(conversation_id: str) -> Conversation: + if conversation_id not in conversations: + raise KeyError(conversation_id) + return conversations[conversation_id] + + mock_client = AsyncMock() + mock_client.get_conversation = AsyncMock(side_effect=mock_get_conversation) + manager._dataset_client = mock_client + + manager.dataset_metadata = DatasetMetadata( + conversations=[conv.metadata() for conv in conversations.values()], + sampling_strategy=DatasetSamplingStrategy.RANDOM, + ) return manager +@pytest.fixture +def populated_dataset_manager( + user_config: UserConfig, + sample_conversations: dict[str, Conversation], +) -> DatasetManager: + """Create a DatasetManager with a mock dataset client for payload tests.""" + return _create_dataset_manager_with_client(user_config, sample_conversations) + + +@pytest.fixture +def empty_dataset_manager( + user_config: UserConfig, +) -> DatasetManager: + """Create a DatasetManager with an empty dataset client.""" + return _create_dataset_manager_with_client(user_config, {}) + + @pytest.fixture def capture_file_writes(): - """Provide a fixture to capture file write operations for testing purposes.""" + """Capture file write operations for testing.""" class FileWriteCapture: def __init__(self): diff --git a/tests/unit/dataset/generator/test_parallel_decode.py b/tests/unit/dataset/generator/test_parallel_decode.py deleted file mode 100644 index a6b0cca79..000000000 --- a/tests/unit/dataset/generator/test_parallel_decode.py +++ /dev/null @@ -1,269 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -"""Unit tests for parallel_decode module.""" - -import importlib -import os -from unittest.mock import MagicMock, call, patch - -import pytest - -from aiperf.dataset.generator.parallel_decode import _set_daemon, parallel_decode - -# Import the module directly (not through __init__.py which exports the function) -pd_module = importlib.import_module("aiperf.dataset.generator.parallel_decode") - - -class TestParallelDecode: - """Test suite for parallel_decode module.""" - - def test_parallel_decode_empty_list(self): - """Test parallel_decode with empty input returns empty list.""" - result = parallel_decode([], "gpt2") - assert result == [] - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_parallel_decode_small_batch_sequential(self, mock_tokenizer_class): - """Test that small batches (< 10) use sequential decoding.""" - mock_tokenizer = MagicMock() - mock_tokenizer.decode.return_value = "decoded" - mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer - - token_sequences = [[1, 2, 3], [4, 5, 6]] # Less than 10 - result = parallel_decode(token_sequences, "gpt2") - - # Should use sequential decoding (Tokenizer.from_pretrained called once) - mock_tokenizer_class.from_pretrained.assert_called_once_with("gpt2") - assert mock_tokenizer.decode.call_count == 2 - assert result == ["decoded", "decoded"] - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_large_batch_uses_executor(self, mock_executor_class): - """Test that large batches (>= 10) use ProcessPoolExecutor.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] # 15 sequences - result = parallel_decode(token_sequences, "gpt2") - - # Should use ProcessPoolExecutor - mock_executor_class.assert_called_once() - mock_executor.map.assert_called_once() - assert len(result) == 15 - - @patch.object(pd_module, "mp") - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_respects_max_workers(self, mock_executor_class, mock_mp): - """Test that max_workers parameter is respected.""" - mock_mp.cpu_count.return_value = 16 - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] - parallel_decode(token_sequences, "gpt2", max_workers=4) - - # Should be called with max_workers=4 - call_kwargs = mock_executor_class.call_args.kwargs - assert call_kwargs["max_workers"] == 4 - - @patch.object(pd_module, "mp") - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_default_max_workers_capped_at_8( - self, mock_executor_class, mock_mp - ): - """Test that default max_workers is capped at 8.""" - mock_mp.cpu_count.return_value = 64 # Lots of CPUs - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] - parallel_decode(token_sequences, "gpt2") - - # Should be capped at 8 - call_kwargs = 
mock_executor_class.call_args.kwargs - assert call_kwargs["max_workers"] == 8 - - -class TestSetDaemon: - """Test suite for _set_daemon helper.""" - - def test_set_daemon_uses_property(self): - """_set_daemon sets daemon via the public property when possible.""" - mock_proc = MagicMock() - mock_proc.daemon = True - with patch.object(pd_module.mp, "current_process", return_value=mock_proc): - _set_daemon(False) - assert mock_proc.daemon is False - - def test_set_daemon_falls_back_to_config_on_assertion_error(self): - """_set_daemon falls back to _config when property raises AssertionError.""" - mock_proc = MagicMock() - type(mock_proc).daemon = property( - fget=lambda self: self._config.get("daemon"), - fset=MagicMock(side_effect=AssertionError), - ) - mock_proc._config = {"daemon": True} - with patch.object(pd_module.mp, "current_process", return_value=mock_proc): - _set_daemon(False) - assert mock_proc._config["daemon"] is False - - -class TestParallelDecodeDaemonFlag: - """Test that parallel_decode properly manages the daemon flag.""" - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_cleared_before_executor_and_restored_after( - self, mock_executor_class - ): - """Daemon flag is cleared before spawning and restored after.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = True - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - assert mock_set.call_args_list == [call(False), call(True)] - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_not_toggled_for_non_daemon_process(self, mock_executor_class): - """Daemon flag is not toggled when the process is not a daemon.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = False - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - mock_set.assert_not_called() - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_restored_on_executor_error(self, mock_executor_class): - """Daemon flag is restored even when the executor raises.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.side_effect = RuntimeError("boom") - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = True - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - pytest.raises(RuntimeError, match="boom"), - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - assert mock_set.call_args_list == [call(False), call(True)] - - -class TestWorkerFunctions: - """Test suite for worker functions.""" - - def test_decode_tokens_raises_without_init(self): - """Test 
that _decode_tokens raises if worker not initialized.""" - pd_module._worker_tokenizer = None - - with pytest.raises(RuntimeError, match="not initialized"): - pd_module._decode_tokens([1, 2, 3]) - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_loads_tokenizer(self, mock_tokenizer_class): - """Test that _init_worker loads the tokenizer.""" - pd_module._worker_tokenizer = None - pd_module._worker_tokenizer_name = None - - mock_tokenizer = MagicMock() - mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer - - pd_module._init_worker("gpt2") - - mock_tokenizer_class.from_pretrained.assert_called_once_with( - "gpt2", resolve_alias=False - ) - assert pd_module._worker_tokenizer is mock_tokenizer - assert pd_module._worker_tokenizer_name == "gpt2" - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_sets_offline_mode(self, mock_tokenizer_class, monkeypatch): - """Test that _init_worker enables HuggingFace offline mode.""" - monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) - monkeypatch.delenv("TRANSFORMERS_OFFLINE", raising=False) - pd_module._worker_tokenizer = None - pd_module._worker_tokenizer_name = None - - mock_tokenizer_class.from_pretrained.return_value = MagicMock() - - pd_module._init_worker("gpt2") - - assert os.environ["HF_HUB_OFFLINE"] == "1" - assert os.environ["TRANSFORMERS_OFFLINE"] == "1" - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_reuses_tokenizer_same_name(self, mock_tokenizer_class): - """Test that _init_worker reuses tokenizer if same name.""" - mock_tokenizer = MagicMock() - pd_module._worker_tokenizer = mock_tokenizer - pd_module._worker_tokenizer_name = "gpt2" - - pd_module._init_worker("gpt2") - - # Should NOT call from_pretrained again - mock_tokenizer_class.from_pretrained.assert_not_called() - assert pd_module._worker_tokenizer is mock_tokenizer - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_reloads_tokenizer_different_name(self, mock_tokenizer_class): - """Test that _init_worker reloads tokenizer if different name.""" - old_tokenizer = MagicMock() - pd_module._worker_tokenizer = old_tokenizer - pd_module._worker_tokenizer_name = "gpt2" - - new_tokenizer = MagicMock() - mock_tokenizer_class.from_pretrained.return_value = new_tokenizer - - pd_module._init_worker("llama") - - mock_tokenizer_class.from_pretrained.assert_called_once_with( - "llama", resolve_alias=False - ) - assert pd_module._worker_tokenizer is new_tokenizer - assert pd_module._worker_tokenizer_name == "llama" - - def test_decode_tokens_uses_worker_tokenizer(self): - """Test that _decode_tokens uses the worker tokenizer.""" - mock_tokenizer = MagicMock() - mock_tokenizer.decode.return_value = "decoded text" - pd_module._worker_tokenizer = mock_tokenizer - - result = pd_module._decode_tokens([1, 2, 3]) - - mock_tokenizer.decode.assert_called_once_with( - [1, 2, 3], skip_special_tokens=False - ) - assert result == "decoded text" diff --git a/tests/unit/dataset/generator/test_prompt_generator.py b/tests/unit/dataset/generator/test_prompt_generator.py index 0bd151bad..bdb2e33d7 100644 --- a/tests/unit/dataset/generator/test_prompt_generator.py +++ b/tests/unit/dataset/generator/test_prompt_generator.py @@ -661,125 +661,30 @@ def test_generate_user_context_prompt_corpus_not_initialized(self, mock_tokenize assert "corpus" in str(exc_info.value).lower() # ============================================================================ - # Decoded String Cache Tests + # HashIdRandomGenerator Integration Tests # 
============================================================================ - def test_decoded_cache_initialized_empty(self, basic_config): - """Test that decoded cache is initialized as empty dict.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - assert hasattr(generator, "_decoded_cache") - assert isinstance(generator._decoded_cache, dict) - assert len(generator._decoded_cache) == 0 + def test_hash_id_rng_initialized(self, basic_config): + """Test that HashIdRandomGenerator is initialized in PromptGenerator.""" + from aiperf.common.hash_id_random_generator import HashIdRandomGenerator - def test_decoded_cache_populated_on_first_call(self, basic_config): - """Test that decoded cache is populated after first call.""" tokenizer, config = basic_config generator = PromptGenerator(config, tokenizer) - _ = generator._generate_cached_prompt(10, [1, 2], 5) + assert hasattr(generator, "_hash_id_corpus_rng") + assert isinstance(generator._hash_id_corpus_rng, HashIdRandomGenerator) - # Should have one entry in decoded cache - expected_key = ((1, 2), 10, 5) - assert expected_key in generator._decoded_cache - assert isinstance(generator._decoded_cache[expected_key], str) - - def test_decoded_cache_hit_on_repeated_call(self, basic_config): - """Test that decoded cache is hit on repeated calls with same params.""" + def test_generate_cached_prompt_deterministic_per_hash_id(self, basic_config): + """Test that same hash_ids produce identical prompts across calls.""" tokenizer, config = basic_config generator = PromptGenerator(config, tokenizer) - # First call - should populate cache result1 = generator._generate_cached_prompt(10, [1, 2], 5) - # Second call with same params - should hit cache - with patch.object(generator.tokenizer, "decode") as mock_decode: - result2 = generator._generate_cached_prompt(10, [1, 2], 5) - mock_decode.assert_not_called() # Decode should NOT be called - - assert result1 == result2 - - def test_decoded_cache_miss_different_hash_ids(self, basic_config): - """Test that different hash_ids create different cache entries.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._generate_cached_prompt(10, [1, 2], 5) - _ = generator._generate_cached_prompt(10, [3, 4], 5) - - # Both should be cached separately - assert ((1, 2), 10, 5) in generator._decoded_cache - assert ((3, 4), 10, 5) in generator._decoded_cache - assert len(generator._decoded_cache) == 2 + # Clear block cache to force regeneration + generator._cache.clear() - def test_decoded_cache_miss_different_num_tokens(self, basic_config): - """Test that different num_tokens creates different cache entry.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._generate_cached_prompt(10, [1, 2], 5) - _ = generator._generate_cached_prompt(8, [1, 2], 5) # Different final block - - # Should have two separate entries - assert ((1, 2), 10, 5) in generator._decoded_cache - assert ((1, 2), 8, 5) in generator._decoded_cache - assert len(generator._decoded_cache) == 2 - - def test_decoded_cache_key_structure(self, basic_config): - """Test that cache key is (tuple(hash_ids), num_tokens, block_size).""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - # 12 tokens = 5 + 5 + 2 (valid final block size) - generator._generate_cached_prompt(12, [1, 2, 3], 5) - - expected_key = ((1, 2, 3), 12, 5) - assert expected_key in generator._decoded_cache - - # 
============================================================================ - # _build_token_sequence Method Tests - # ============================================================================ + result2 = generator._generate_cached_prompt(10, [1, 2], 5) - def test_build_token_sequence_returns_tokens(self, basic_config): - """Test that _build_token_sequence returns a list of token IDs.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - tokens = generator._build_token_sequence(10, [1, 2], 5) - - assert isinstance(tokens, list) - assert all(isinstance(t, int) for t in tokens) - assert len(tokens) == 10 - - def test_build_token_sequence_populates_cache(self, basic_config): - """Test that _build_token_sequence populates the token block cache.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._build_token_sequence(10, [1, 2], 5) - - # Token block cache should be populated - assert 1 in generator._cache - assert 2 in generator._cache - - def test_build_token_sequence_does_not_populate_decoded_cache(self, basic_config): - """Test that _build_token_sequence does NOT populate decoded cache.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._build_token_sequence(10, [1, 2], 5) - - # Decoded cache should remain empty - assert len(generator._decoded_cache) == 0 - - def test_build_token_sequence_same_validation_as_generate_cached( - self, basic_config - ): - """Test that _build_token_sequence has same validation as _generate_cached_prompt.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - # This should raise same error as _generate_cached_prompt - with pytest.raises(ConfigurationError): - generator._build_token_sequence(10, [1, 2, 3], 5) # final_block_size = 0 + # Same hash_ids should produce identical results due to HashIdRandomGenerator + assert result1 == result2 diff --git a/tests/unit/dataset/loader/test_bailian_trace.py b/tests/unit/dataset/loader/test_bailian_trace.py index 97885a993..7fb1e021d 100644 --- a/tests/unit/dataset/loader/test_bailian_trace.py +++ b/tests/unit/dataset/loader/test_bailian_trace.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 import logging -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from pydantic import ValidationError @@ -112,8 +112,7 @@ class TestBailianTraceDatasetLoader: def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -415,12 +414,7 @@ def test_can_load(self, data, expected): # ---- convert_to_conversations ---- - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_convert_to_conversations( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - mock_parallel_decode.return_value = ["decoded prompt 1", "decoded prompt 2"] - + def test_convert_to_conversations(self, mock_prompt_generator, default_user_config): trace_data = { "100": [ BailianTrace( @@ -447,7 +441,7 @@ def test_convert_to_conversations( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 2 assert conversations[0].session_id == "100" @@ -460,7 +454,7 @@ def test_convert_empty_data(self, mock_prompt_generator, default_user_config): user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - assert loader.convert_to_conversations({}) == [] + assert list(loader.convert_to_conversations({})) == [] def test_convert_without_hash_ids(self, mock_prompt_generator, default_user_config): """When hash_ids is empty, falls back to normal prompt generation.""" @@ -480,62 +474,18 @@ def test_convert_without_hash_ids(self, mock_prompt_generator, default_user_conf user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 mock_prompt_generator.generate.assert_called_once_with( - mean=100, stddev=0, hash_ids=[] + mean=100, stddev=0, hash_ids=[], block_size=512 ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_parallel_decode_length_mismatch_raises( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - """strict=True in zip guards against silent data loss.""" - mock_parallel_decode.return_value = ["only one"] # expecting 2 - - trace_data = { - "1": [ - BailianTrace( - chat_id=1, - timestamp=1.0, - input_length=10, - output_length=5, - hash_ids=[1], - ) - ], - "2": [ - BailianTrace( - chat_id=2, - timestamp=2.0, - input_length=20, - output_length=10, - hash_ids=[2], - ) - ], - } - - loader = BailianTraceDatasetLoader( - filename="dummy.jsonl", - user_config=default_user_config, - prompt_generator=mock_prompt_generator, - ) - - with pytest.raises(ValueError, match="zip"): - loader.convert_to_conversations(trace_data) - # ---- multi-turn conversation conversion ---- - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_multi_turn_conversation_ordering( - self, mock_parallel_decode, mock_prompt_generator, default_user_config + self, mock_prompt_generator, default_user_config ): - mock_parallel_decode.return_value = [ - "prompt turn 1", - "prompt turn 2", - "prompt turn 3", - ] - trace_data = { "100": [ BailianTrace( @@ -572,7 +522,7 @@ def test_multi_turn_conversation_ordering( 
user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 conv = conversations[0] @@ -613,8 +563,7 @@ class TestBailianTraceSynthesisIntegration: def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator def test_speedup_ratio_scales_timestamps(self, mock_prompt_generator): diff --git a/tests/unit/dataset/loader/test_base_trace_loader.py b/tests/unit/dataset/loader/test_base_trace_loader.py new file mode 100644 index 000000000..bbc36e19a --- /dev/null +++ b/tests/unit/dataset/loader/test_base_trace_loader.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for BaseTraceDatasetLoader. + +Covers file hashing, trace_id lifecycle, parallel vs single-threaded threshold, +convert_to_conversations dispatching, and integration with parallel_convert. +""" + +import hashlib +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from aiperf.common.config import ( + EndpointConfig, + InputConfig, + InputTokensConfig, + PromptConfig, + UserConfig, +) +from aiperf.common.config.config_defaults import InputTokensDefaults +from aiperf.common.models import Conversation +from aiperf.dataset.loader.base_trace_loader import ( + _MIN_TRACES_FOR_PARALLEL, + _compute_file_hash, +) +from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader + +# ----------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------- + + +@pytest.fixture +def create_jsonl_file(): + """Create a temporary JSONL file with custom content.""" + filenames = [] + + def _create_file(content_lines): + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for line in content_lines: + f.write(line + "\n") + filenames.append(f.name) + return f.name + + yield _create_file + + for fn in filenames: + Path(fn).unlink(missing_ok=True) + + +@pytest.fixture +def default_user_config() -> UserConfig: + return UserConfig(endpoint=EndpointConfig(model_names=["test-model"])) + + +@pytest.fixture +def mock_prompt_generator(): + """Mock PromptGenerator with required attributes for BaseTraceDatasetLoader.""" + generator = Mock() + generator.generate.return_value = "Generated prompt text" + generator._cache = {} + generator._tokenized_corpus = list(range(100, 200)) + generator._hash_id_corpus_rng = Mock() + generator._hash_id_corpus_rng.seed = 42 + generator.tokenizer = Mock() + generator.tokenizer.resolved_name = "test-model" + generator.tokenizer.block_separation_token_id = None + return generator + + +# ----------------------------------------------------------------------- +# _compute_file_hash +# ----------------------------------------------------------------------- + + +class TestComputeFileHash: + """Tests for _compute_file_hash function.""" + + def test_hash_is_16_hex_chars(self, create_jsonl_file): + """Hash should be 16 hex characters.""" + filename = create_jsonl_file(["test content"]) + result = _compute_file_hash(filename) + assert len(result) == 16 + assert all(c in 
"0123456789abcdef" for c in result) + + def test_deterministic_for_same_content(self, create_jsonl_file): + """Same file content should produce same hash.""" + f1 = create_jsonl_file(["line 1", "line 2"]) + f2 = create_jsonl_file(["line 1", "line 2"]) + assert _compute_file_hash(f1) == _compute_file_hash(f2) + + def test_different_content_different_hash(self, create_jsonl_file): + """Different file content should produce different hash.""" + f1 = create_jsonl_file(["line 1"]) + f2 = create_jsonl_file(["line 2"]) + assert _compute_file_hash(f1) != _compute_file_hash(f2) + + def test_fallback_on_file_error(self): + """Non-existent file should fall back to hashing the filepath string.""" + result = _compute_file_hash("/nonexistent/path/file.jsonl") + expected = hashlib.sha256(b"/nonexistent/path/file.jsonl").hexdigest()[:16] + assert result == expected + + def test_fallback_on_type_error(self): + """TypeError (e.g. from mock_open) should fall back to filepath hash.""" + with patch("builtins.open") as mock_open: + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.read.return_value = "string not bytes" + mock_open.return_value = mock_file + + result = _compute_file_hash("test.jsonl") + expected = hashlib.sha256(b"test.jsonl").hexdigest()[:16] + assert result == expected + + def test_hash_matches_sha256(self, create_jsonl_file): + """Hash should match first 16 chars of SHA-256.""" + content = ["test line one", "test line two"] + filename = create_jsonl_file(content) + + with open(filename, "rb") as f: + expected = hashlib.sha256(f.read()).hexdigest()[:16] + + assert _compute_file_hash(filename) == expected + + +# ----------------------------------------------------------------------- +# load_dataset — trace_id lifecycle +# ----------------------------------------------------------------------- + + +class TestLoadDatasetTraceId: + """Tests that load_dataset computes trace_id and sets it on the RNG.""" + + def test_load_dataset_sets_trace_id( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """load_dataset should compute file hash and set trace_id.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader.load_dataset() + + assert loader._trace_id != "" + assert len(loader._trace_id) == 16 + mock_prompt_generator._hash_id_corpus_rng.set_trace_id.assert_called_once_with( + loader._trace_id + ) + + def test_trace_id_matches_file_hash( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """trace_id should match the SHA-256 hash of the file.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + expected_hash = _compute_file_hash(filename) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader.load_dataset() + + assert loader._trace_id == expected_hash + + def test_different_files_different_trace_ids( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Different files should produce different trace_ids.""" + f1 = create_jsonl_file(['{"input_length": 100, "hash_ids": [1]}']) + f2 = create_jsonl_file(['{"input_length": 200, "hash_ids": [2]}']) + + loader1 = 
MooncakeTraceDatasetLoader( + filename=f1, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader1.load_dataset() + + loader2 = MooncakeTraceDatasetLoader( + filename=f2, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader2.load_dataset() + + assert loader1._trace_id != loader2._trace_id + + +# ----------------------------------------------------------------------- +# convert_to_conversations — threshold dispatching +# ----------------------------------------------------------------------- + + +class TestConvertToConversationsDispatching: + """Tests for parallel vs single-threaded dispatching.""" + + def test_empty_data_returns_empty(self, mock_prompt_generator, default_user_config): + """Empty data dict should return empty list.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + result = list(loader.convert_to_conversations({})) + assert result == [] + + def test_small_dataset_uses_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Datasets with fewer than _MIN_TRACES_FOR_PARALLEL traces use single-threaded.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert" + ) as mock_parallel: + result = list(loader.convert_to_conversations(data)) + + mock_parallel.assert_not_called() + assert len(result) == 1 + assert isinstance(result[0], Conversation) + + def test_large_dataset_uses_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Datasets with >= _MIN_TRACES_FOR_PARALLEL traces use parallel conversion.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + mock_conversations = [Conversation(session_id=f"s{i}") for i in range(3)] + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=mock_conversations, + ) as mock_parallel: + result = list(loader.convert_to_conversations(data)) + + mock_parallel.assert_called_once() + assert result == mock_conversations + + def test_parallel_convert_receives_correct_args( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """parallel_convert should receive the right parameters from the loader.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data, num_workers=8, batch_size=50)) + + call_kwargs = mock_parallel.call_args[1] + assert call_kwargs["tokenizer_name"] == "test-model" + assert 
call_kwargs["base_seed"] == 42 + assert call_kwargs["block_size"] == InputTokensDefaults.BLOCK_SIZE + assert call_kwargs["sep_token"] is None + assert call_kwargs["trace_id"] == loader._trace_id + assert call_kwargs["num_workers"] == 8 + assert call_kwargs["batch_size"] == 50 + + def test_exactly_threshold_uses_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Exactly _MIN_TRACES_FOR_PARALLEL traces should use parallel.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + mock_parallel.assert_called_once() + + def test_below_threshold_uses_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """One below _MIN_TRACES_FOR_PARALLEL uses single-threaded.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL - 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert" + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + mock_parallel.assert_not_called() + + +# ----------------------------------------------------------------------- +# _convert_single_threaded +# ----------------------------------------------------------------------- + + +class TestConvertSingleThreaded: + """Tests for the single-threaded conversion fallback.""" + + def test_text_input_used_directly( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces with text_input should use the literal text.""" + content = ['{"text_input": "Hello world", "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + assert len(conversations) == 1 + assert conversations[0].turns[0].texts[0].contents == ["Hello world"] + mock_prompt_generator.generate.assert_not_called() + + def test_hash_ids_calls_generate( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces with hash_ids should call prompt_generator.generate().""" + content = ['{"input_length": 100, "hash_ids": [1, 2], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + assert len(conversations) == 1 + mock_prompt_generator.generate.assert_called_once_with( + mean=100, + stddev=0, + hash_ids=[1, 2], + block_size=InputTokensDefaults.BLOCK_SIZE, + ) + + def test_no_input_calls_generate_with_empty_hash_ids( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + 
"""Traces with input_length but no hash_ids still call generate.""" + content = ['{"input_length": 50, "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + list(loader.convert_to_conversations(data)) + + mock_prompt_generator.generate.assert_called_once_with( + mean=50, + stddev=0, + hash_ids=[], + block_size=InputTokensDefaults.BLOCK_SIZE, + ) + + def test_turn_fields_populated( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Turn objects should have correct timestamp, delay, max_tokens.""" + content = [ + '{"input_length": 100, "hash_ids": [1], "timestamp": 5000, "output_length": 42}' + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + turn = conversations[0].turns[0] + assert turn.timestamp == 5000 + assert turn.max_tokens == 42 + + def test_multi_session_conversion( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Multiple sessions each produce a separate Conversation.""" + content = [ + '{"session_id": "s1", "input_length": 100, "hash_ids": [1], "timestamp": 1000}', + '{"session_id": "s2", "input_length": 200, "hash_ids": [2], "timestamp": 2000}', + '{"session_id": "s1", "input_length": 150, "hash_ids": [3], "delay": 50}', + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + session_ids = {c.session_id for c in conversations} + assert "s1" in session_ids + assert "s2" in session_ids + + s1_conv = next(c for c in conversations if c.session_id == "s1") + assert len(s1_conv.turns) == 2 + + +# ----------------------------------------------------------------------- +# Block size precedence +# ----------------------------------------------------------------------- + + +class TestBlockSizePrecedence: + """Tests that block size follows the correct precedence chain.""" + + def test_default_block_size(self, mock_prompt_generator, default_user_config): + """Without any overrides, block_size should be InputTokensDefaults.BLOCK_SIZE.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + assert loader._block_size == InputTokensDefaults.BLOCK_SIZE + + def test_plugin_default_block_size( + self, mock_prompt_generator, default_user_config + ): + """Plugin metadata default should override the hardcoded fallback.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + assert loader._block_size == 16 + + def test_user_cli_overrides_plugin_default(self, mock_prompt_generator): + """User CLI --isl-block-size should override plugin metadata default.""" + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig( + prompt=PromptConfig( + input_tokens=InputTokensConfig(block_size=64), + ), + ), + ) + loader = MooncakeTraceDatasetLoader( + 
filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + assert loader._block_size == 64 + + def test_block_size_passed_to_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Single-threaded path should pass block_size to generate().""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + data = loader.load_dataset() + list(loader.convert_to_conversations(data)) + + mock_prompt_generator.generate.assert_called_once_with( + mean=100, + stddev=0, + hash_ids=[1], + block_size=16, + ) + + def test_block_size_passed_to_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Parallel path should pass block_size to parallel_convert.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + assert mock_parallel.call_args[1]["block_size"] == 16 + + +# ----------------------------------------------------------------------- +# Serialization for parallel path +# ----------------------------------------------------------------------- + + +class TestSessionSerialization: + """Tests that traces are correctly serialized to dicts for parallel workers.""" + + def test_traces_serialized_via_model_dump( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces should be serialized via model_dump() before passing to parallel_convert.""" + content = [ + f'{{"input_length": {100 + i}, "hash_ids": [{i}], "timestamp": {i * 1000}, "output_length": {20 + i}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + + sessions = mock_parallel.call_args[1]["sessions"] + assert len(sessions) > 0 + + # Each session should be (str, list[dict]) + for sid, traces in sessions: + assert isinstance(sid, str) + for trace_dict in traces: + assert isinstance(trace_dict, dict) + assert "input_length" in trace_dict diff --git a/tests/unit/dataset/loader/test_parallel_convert.py b/tests/unit/dataset/loader/test_parallel_convert.py new file mode 100644 index 000000000..f557b6ff6 --- /dev/null +++ b/tests/unit/dataset/loader/test_parallel_convert.py @@ -0,0 +1,938 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for parallel_convert module. + +Covers worker initialization, batch processing, daemon flag handling, +shared memory management, and end-to-end parallel conversion. 
+""" + +import multiprocessing as mp +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest + +import aiperf.dataset.loader.parallel_convert as parallel_convert_mod +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator +from aiperf.common.models import Conversation +from aiperf.dataset.generator.prompt import sample_tokens_from_corpus +from aiperf.dataset.loader.parallel_convert import ( + _init_worker, + _process_batch, + _set_daemon, + _WorkerInitArgs, + _WorkerState, + parallel_convert, +) + +# ----------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------- + + +@pytest.fixture +def sample_corpus(): + """A small corpus of token IDs for testing.""" + return list(range(100, 200)) + + +@pytest.fixture +def sample_corpus_array(sample_corpus): + """Corpus as numpy int32 array.""" + return np.array(sample_corpus, dtype=np.int32) + + +@pytest.fixture +def mock_tokenizer(): + """Mock tokenizer that decodes tokens into readable strings.""" + tok = MagicMock() + tok.decode.side_effect = lambda ids, **kw: " ".join(f"t{i}" for i in ids) + return tok + + +@pytest.fixture +def setup_worker(sample_corpus_array, mock_tokenizer): + """Set up module-level _worker_state for _process_batch tests.""" + seed = 42 + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id("test_trace") + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tokenizer, + corpus=sample_corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=10, + sep_token=None, + sample_tokens=sample_tokens_from_corpus, + ) + yield + parallel_convert_mod._worker_state = None + + +@pytest.fixture +def setup_worker_with_sep(sample_corpus_array, mock_tokenizer): + """Set up worker with a separator token (BOS/EOS).""" + seed = 42 + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id("test_trace") + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tokenizer, + corpus=sample_corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=10, + sep_token=1, + sample_tokens=sample_tokens_from_corpus, + ) + yield + parallel_convert_mod._worker_state = None + + +# ----------------------------------------------------------------------- +# _set_daemon +# ----------------------------------------------------------------------- + + +class TestSetDaemon: + """Tests for daemon flag manipulation.""" + + def test_set_daemon_true(self): + """Setting daemon to True should work on a non-daemon process.""" + original = mp.current_process().daemon + try: + _set_daemon(True) + assert mp.current_process().daemon is True + finally: + _set_daemon(original) + + def test_set_daemon_false(self): + """Setting daemon to False should work.""" + original = mp.current_process().daemon + try: + _set_daemon(False) + assert mp.current_process().daemon is False + finally: + _set_daemon(original) + + def test_set_daemon_fallback_on_assertion_error(self): + """If daemon= setter raises AssertionError, fallback to _config.""" + proc = mp.current_process() + original = proc.daemon + + with patch.object( + type(proc), "daemon", property(fset=Mock(side_effect=AssertionError)) + ): + _set_daemon(True) + assert proc._config["daemon"] is True + + # Restore + proc._config["daemon"] = original + + +# ----------------------------------------------------------------------- +# _process_batch +# 
----------------------------------------------------------------------- + + +class TestProcessBatch: + """Tests for _process_batch worker function.""" + + def test_text_input_traces(self, setup_worker): + """Traces with text_input should use the literal text.""" + batch = [ + ( + "session-1", + [ + { + "text_input": "Hello world", + "timestamp": 100, + "delay": None, + "output_length": 10, + }, + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 1 + sid, turns = results[0] + assert sid == "session-1" + assert len(turns) == 1 + ts, delay, prompt, max_tokens = turns[0] + assert prompt == "Hello world" + assert ts == 100 + assert max_tokens == 10 + + def test_hash_ids_traces(self, setup_worker, mock_tokenizer): + """Traces with hash_ids should generate tokens and decode.""" + batch = [ + ( + "session-1", + [ + { + "hash_ids": [1, 2], + "input_length": 15, + "timestamp": 200, + "delay": 5, + "output_length": 20, + }, + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 1 + sid, turns = results[0] + assert sid == "session-1" + assert len(turns) == 1 + ts, delay, prompt, max_tokens = turns[0] + assert ts == 200 + assert delay == 5 + assert max_tokens == 20 + # decode was called with generated tokens + mock_tokenizer.decode.assert_called_once() + assert isinstance(prompt, str) + assert len(prompt) > 0 + + def test_empty_trace_no_hash_ids_no_text(self, setup_worker): + """Traces without hash_ids or text_input produce empty prompt.""" + batch = [ + ( + "session-1", + [ + { + "timestamp": 300, + "delay": None, + "output_length": 5, + "input_length": 0, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + _, _, prompt, _ = turns[0] + assert prompt == "" + + def test_multiple_sessions_in_batch(self, setup_worker, mock_tokenizer): + """Multiple sessions in one batch are all processed.""" + batch = [ + ( + "s1", + [ + { + "text_input": "prompt A", + "timestamp": 1, + "delay": None, + "output_length": 10, + } + ], + ), + ( + "s2", + [ + { + "text_input": "prompt B", + "timestamp": 2, + "delay": None, + "output_length": 20, + } + ], + ), + ( + "s3", + [ + { + "text_input": "prompt C", + "timestamp": 3, + "delay": None, + "output_length": 30, + } + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 3 + assert results[0][0] == "s1" + assert results[1][0] == "s2" + assert results[2][0] == "s3" + assert results[0][1][0][2] == "prompt A" + assert results[1][1][0][2] == "prompt B" + assert results[2][1][0][2] == "prompt C" + + def test_multi_turn_session(self, setup_worker, mock_tokenizer): + """A session with multiple turns (traces) processes all turns.""" + batch = [ + ( + "session-1", + [ + { + "text_input": "turn 1", + "timestamp": 100, + "delay": None, + "output_length": 10, + }, + { + "text_input": "turn 2", + "timestamp": 200, + "delay": 50, + "output_length": 20, + }, + { + "text_input": "turn 3", + "timestamp": 300, + "delay": 100, + "output_length": 30, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + assert len(turns) == 3 + assert turns[0][2] == "turn 1" + assert turns[1][2] == "turn 2" + assert turns[2][2] == "turn 3" + + def test_hash_id_block_cache_reuse(self, setup_worker, mock_tokenizer): + """Same hash_id within a batch should reuse cached tokens.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [42], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ( + "s2", + [ + { + "hash_ids": [42], + "input_length": 10, + "timestamp": 
2, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + results = _process_batch(batch) + + # Both sessions with same hash_id should get same decoded output + prompt_1 = results[0][1][0][2] + prompt_2 = results[1][1][0][2] + assert prompt_1 == prompt_2 + + def test_different_hash_ids_produce_different_prompts( + self, setup_worker, mock_tokenizer + ): + """Different hash_ids should produce different token sequences.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [100], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ( + "s2", + [ + { + "hash_ids": [200], + "input_length": 10, + "timestamp": 2, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + results = _process_batch(batch) + + prompt_1 = results[0][1][0][2] + prompt_2 = results[1][1][0][2] + assert prompt_1 != prompt_2 + + def test_final_block_size_calculation(self, setup_worker, mock_tokenizer): + """Last hash block should get the remainder tokens.""" + # block_size=10, input_length=25, 3 hash_ids + # first two blocks: 10 tokens each, last block: 25 - 2*10 = 5 tokens + batch = [ + ( + "s1", + [ + { + "hash_ids": [1, 2, 3], + "input_length": 25, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + _process_batch(batch) + + # Verify decode was called (tokens were generated) + assert mock_tokenizer.decode.call_count == 1 + decoded_tokens = mock_tokenizer.decode.call_args[0][0] + # 10 + 10 + 5 = 25 total tokens + assert len(decoded_tokens) == 25 + + def test_separator_token_prepended(self, setup_worker_with_sep, mock_tokenizer): + """When sep_token is set, each block should have it prepended.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [1], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + _process_batch(batch) + + decoded_tokens = mock_tokenizer.decode.call_args[0][0] + # With sep_token=1, first token in block should be 1 + assert decoded_tokens[0] == 1 + + def test_mixed_text_input_and_hash_ids(self, setup_worker, mock_tokenizer): + """Session with both text_input and hash_id traces.""" + batch = [ + ( + "s1", + [ + { + "text_input": "literal text", + "timestamp": 1, + "delay": None, + "output_length": 10, + }, + { + "hash_ids": [5], + "input_length": 10, + "timestamp": 2, + "delay": None, + "output_length": 20, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + assert turns[0][2] == "literal text" + assert turns[1][2] != "literal text" # Generated prompt + + def test_none_fields_preserved(self, setup_worker): + """None values for timestamp, delay, output_length are preserved.""" + batch = [ + ( + "s1", + [ + { + "text_input": "test", + "timestamp": None, + "delay": None, + "output_length": None, + }, + ], + ), + ] + results = _process_batch(batch) + + ts, delay, _, max_tokens = results[0][1][0] + assert ts is None + assert delay is None + assert max_tokens is None + + +# ----------------------------------------------------------------------- +# _init_worker +# ----------------------------------------------------------------------- + + +class TestInitWorker: + """Tests for _init_worker function.""" + + def test_init_worker_sets_up_state(self, sample_corpus_array, tmp_path): + """_init_worker should populate _worker dict with all required fields.""" + from multiprocessing import shared_memory + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + 
dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + mock_tok = MagicMock() + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=1, + trace_id="abc123", + ) + with patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=mock_tok, + ): + _init_worker(args) + + state = parallel_convert_mod._worker_state + assert state is not None + assert state.tokenizer is mock_tok + assert state.block_size == 10 + assert state.sep_token == 1 + assert isinstance(state.hash_rng, HashIdRandomGenerator) + assert np.array_equal(state.corpus, sample_corpus_array) + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + + def test_init_worker_sets_offline_env(self, sample_corpus_array): + """Worker should set HF offline environment variables.""" + import os + from multiprocessing import shared_memory + + original_hf = os.environ.get("HF_HUB_OFFLINE") + original_tf = os.environ.get("TRANSFORMERS_OFFLINE") + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=None, + trace_id="abc", + ) + with patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=MagicMock(), + ): + _init_worker(args) + + assert os.environ.get("HF_HUB_OFFLINE") == "1" + assert os.environ.get("TRANSFORMERS_OFFLINE") == "1" + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + # Restore env + if original_hf is None: + os.environ.pop("HF_HUB_OFFLINE", None) + else: + os.environ["HF_HUB_OFFLINE"] = original_hf + if original_tf is None: + os.environ.pop("TRANSFORMERS_OFFLINE", None) + else: + os.environ["TRANSFORMERS_OFFLINE"] = original_tf + + +# ----------------------------------------------------------------------- +# parallel_convert — end-to-end +# ----------------------------------------------------------------------- + + +class TestParallelConvert: + """Tests for the parallel_convert orchestration function.""" + + def test_empty_sessions_returns_empty(self, sample_corpus): + """Empty input returns empty output.""" + result = list( + parallel_convert( + sessions=[], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + assert result == [] + + def test_returns_conversation_objects(self, sample_corpus): + """Output should be a list of Conversation objects.""" + sessions = [ + ( + "s1", + [ + { + "text_input": "hello", + "timestamp": 1, + "delay": None, + "output_length": 5, + } + ], + ), + ] + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [("s1", [(1, None, "hello", 5)])], + ] + + result = list( + parallel_convert( + sessions=sessions, + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert len(result) == 1 + assert isinstance(result[0], Conversation) + assert 
result[0].session_id == "s1" + assert len(result[0].turns) == 1 + assert result[0].turns[0].timestamp == 1 + assert result[0].turns[0].max_tokens == 5 + + def test_batching_splits_sessions(self, sample_corpus): + """Sessions should be split into batches of batch_size.""" + sessions = [ + ( + f"s{i}", + [ + { + "text_input": f"p{i}", + "timestamp": i, + "delay": None, + "output_length": 5, + } + ], + ) + for i in range(5) + ] + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [(f"s{i}", [(i, None, f"p{i}", 5)]) for i in range(2)], + [(f"s{i}", [(i, None, f"p{i}", 5)]) for i in range(2, 5)], + ] + + list( + parallel_convert( + sessions=sessions, + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + batch_size=2, + ) + ) + + # map was called with batches of size 2 + batches = mock_pool_instance.imap.call_args[0][1] + assert len(batches) == 3 # 5 sessions / 2 batch_size = 3 batches + assert len(batches[0]) == 2 + assert len(batches[1]) == 2 + assert len(batches[2]) == 1 + + def test_daemon_flag_restored(self, sample_corpus): + """Daemon flag should be restored after Pool finishes.""" + original_daemon = mp.current_process().daemon + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.return_value = [] + + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert mp.current_process().daemon == original_daemon + + def test_shared_memory_cleanup(self, sample_corpus): + """Shared memory should be cleaned up even on errors.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.side_effect = RuntimeError("Pool error") + + with pytest.raises(RuntimeError, match="Pool error"): + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + # No leaked shared memory (if it leaked, subsequent tests would detect it) + + def test_multi_turn_conversations(self, sample_corpus): + """Sessions with multiple turns should produce multi-turn Conversations.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [("s1", [(100, None, "turn 1", 10), (200, 50, "turn 2", 20)])], + ] + + result = list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "turn 1", + "timestamp": 100, + "delay": 
None, + "output_length": 10, + }, + { + "text_input": "turn 2", + "timestamp": 200, + "delay": 50, + "output_length": 20, + }, + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert len(result) == 1 + conv = result[0] + assert len(conv.turns) == 2 + assert conv.turns[0].timestamp == 100 + assert conv.turns[0].delay is None + assert conv.turns[0].max_tokens == 10 + assert conv.turns[1].timestamp == 200 + assert conv.turns[1].delay == 50 + assert conv.turns[1].max_tokens == 20 + + def test_pool_receives_correct_init_args(self, sample_corpus): + """Pool should be initialized with correct arguments.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.return_value = [] + + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="my-tokenizer", + corpus=sample_corpus, + base_seed=12345, + block_size=64, + sep_token=7, + trace_id="trace_abc", + num_workers=4, + ) + ) + + call_args = MockPool.call_args + assert call_args[0][0] == 4 # num_workers + assert call_args[0][1] is _init_worker + initargs = call_args[0][2] + assert len(initargs) == 1 + args = initargs[0] + assert isinstance(args, _WorkerInitArgs) + assert args.tokenizer_name == "my-tokenizer" + assert args.base_seed == 12345 + assert args.block_size == 64 + assert args.sep_token == 7 + assert args.trace_id == "trace_abc" + + +# ----------------------------------------------------------------------- +# Determinism: _process_batch produces identical results with same seed +# ----------------------------------------------------------------------- + + +class TestProcessBatchDeterminism: + """Tests that _process_batch is deterministic across invocations.""" + + def _setup_and_process( + self, corpus_array, hash_ids, input_length, block_size, trace_id, seed=42 + ): + """Helper: set up worker state and process a single batch.""" + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id(trace_id) + + mock_tok = MagicMock() + mock_tok.decode.side_effect = lambda ids, **kw: ",".join(str(i) for i in ids) + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tok, + corpus=corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=block_size, + sep_token=None, + sample_tokens=sample_tokens_from_corpus, + ) + + batch = [ + ( + "s1", + [ + { + "hash_ids": hash_ids, + "input_length": input_length, + "timestamp": 1, + "delay": None, + "output_length": 5, + } + ], + ), + ] + result = _process_batch(batch) + parallel_convert_mod._worker_state = None + return result[0][1][0][2] # prompt string + + def test_same_seed_same_trace_id_same_result(self, sample_corpus_array): + """Identical seed + trace_id + hash_ids = identical prompt.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + prompt_2 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + assert prompt_1 == prompt_2 + + def test_different_trace_id_different_result(self, sample_corpus_array): + """Different trace_ids produce different prompts.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + prompt_2 = 
self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_b" + ) + assert prompt_1 != prompt_2 + + def test_different_seed_different_result(self, sample_corpus_array): + """Different seeds produce different prompts.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1], 10, 10, "trace_a", seed=42 + ) + prompt_2 = self._setup_and_process( + sample_corpus_array, [1], 10, 10, "trace_a", seed=99 + ) + assert prompt_1 != prompt_2 + + def test_different_hash_ids_different_result(self, sample_corpus_array): + """Different hash_ids produce different prompts.""" + prompt_1 = self._setup_and_process(sample_corpus_array, [10], 10, 10, "trace_a") + prompt_2 = self._setup_and_process(sample_corpus_array, [20], 10, 10, "trace_a") + assert prompt_1 != prompt_2 diff --git a/tests/unit/dataset/loader/test_trace.py b/tests/unit/dataset/loader/test_trace.py index d03922974..9c7854f6a 100644 --- a/tests/unit/dataset/loader/test_trace.py +++ b/tests/unit/dataset/loader/test_trace.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import logging -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from pydantic import ValidationError @@ -111,9 +111,7 @@ def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" # Required for convert_to_conversations() to check string cache - generator._decoded_cache = {} - # Mock _build_token_sequence to return a simple token list - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -356,18 +354,8 @@ def test_load_dataset_logs_skipped_traces( # Check that the skipped traces message is logged assert f"Skipped {expected_skipped:,} traces" in caplog.text - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_convert_to_conversations( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): + def test_convert_to_conversations(self, mock_prompt_generator, default_user_config): """Test conversion of trace data to conversations.""" - # Mock parallel_decode to return decoded prompts - mock_parallel_decode.return_value = [ - "decoded prompt 1", - "decoded prompt 2", - "decoded prompt 3", - ] - # Setup trace data trace_data = { "session-1": [ @@ -401,7 +389,7 @@ def test_convert_to_conversations( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 3 @@ -432,7 +420,7 @@ def test_convert_to_conversations_empty_data( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations({}) + conversations = list(loader.convert_to_conversations({})) assert len(conversations) == 0 @@ -453,7 +441,7 @@ def test_convert_to_conversations_with_text_input( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 # One conversation with multiple turns conversation = conversations[0] @@ -801,8 +789,8 @@ def test_load_dataset_max_isl_and_max_osl_combined( class TestMooncakeTraceReproducibility: """Tests for reproducibility of Mooncake trace prompt generation. 
- These tests verify that the two-phase Mooncake flow with parallel_decode - yields identical prompts across runs when the RNG is seeded consistently. + These tests verify that HashIdRandomGenerator-based generation yields + identical prompts across runs when the RNG is seeded consistently. """ @pytest.fixture @@ -810,8 +798,7 @@ def mock_prompt_generator(self): """Create a mock prompt generator for testing.""" generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -826,9 +813,8 @@ def user_config_for_reproducibility(self): ), ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_mooncake_flow_reproducibility_with_same_seed( - self, mock_parallel_decode, mock_tokenizer_cls, user_config_for_reproducibility + self, mock_tokenizer_cls, user_config_for_reproducibility ): """Verify Mooncake flow produces identical prompts across runs with same seed. @@ -838,16 +824,7 @@ def test_mooncake_flow_reproducibility_with_same_seed( from aiperf.common import random_generator as rng from aiperf.dataset.generator import PromptGenerator - # Mock parallel_decode to return deterministic results based on input - def deterministic_decode(token_sequences, tokenizer_name): - return [ - f"decoded_prompt_{i}_{len(seq)}" - for i, seq in enumerate(token_sequences) - ] - - mock_parallel_decode.side_effect = deterministic_decode - - # Create trace data with hash_ids to exercise the two-phase flow + # Create trace data with hash_ids trace_data = { "session-1": [ MooncakeTrace( @@ -913,39 +890,6 @@ def deterministic_decode(token_sequences, tokenizer_name): f"First run: {prompts1}, Second run: {prompts2}" ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_parallel_decode_length_mismatch_raises( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - """Verify that length mismatch between pending_decodes and decoded_prompts raises. - - This tests the strict=True behavior in zip() that guards against silent data loss. 
- """ - # Mock parallel_decode to return FEWER results than expected - mock_parallel_decode.return_value = ["decoded prompt 1"] # Only 1, expecting 3 - - trace_data = { - "session-1": [ - MooncakeTrace(input_length=100, hash_ids=[1, 2], timestamp=1000), - ], - "session-2": [ - MooncakeTrace(input_length=200, hash_ids=[3, 4, 5], timestamp=2000), - ], - "session-3": [ - MooncakeTrace(input_length=150, hash_ids=[6], timestamp=3000), - ], - } - - loader = MooncakeTraceDatasetLoader( - filename="dummy.jsonl", - user_config=default_user_config, - prompt_generator=mock_prompt_generator, - ) - - # Should raise ValueError due to strict=True in zip - with pytest.raises(ValueError, match="zip"): - loader.convert_to_conversations(trace_data) - # ============================================================================ # Synthesis Integration Tests @@ -986,8 +930,7 @@ def mock_prompt_generator(self): """Create a mock prompt generator for testing.""" generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture diff --git a/tests/unit/dataset/test_dataset_manager.py b/tests/unit/dataset/test_dataset_manager.py index 81f97d577..e54861276 100644 --- a/tests/unit/dataset/test_dataset_manager.py +++ b/tests/unit/dataset/test_dataset_manager.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from pydantic import ValidationError @@ -404,9 +404,8 @@ async def test_in_memory_dataset_freed_after_client_initialization( ProfileConfigureCommand(config=user_config, service_id="test_service") ) - # After configuration, in-memory dataset should be empty - assert dataset_manager.dataset == {} - assert dataset_manager._conversation_ids_cache == [] + # After configuration, conversation count should reflect dataset size + assert dataset_manager._conversation_count > 0 @pytest.mark.asyncio async def test_dataset_configured_event_set_after_client_initialization( @@ -463,8 +462,7 @@ async def test_handle_conversation_request_uses_dataset_client( 0 ].conversation_id - # Verify in-memory dataset is empty (freed) - assert dataset_manager.dataset == {} + # Verify in-memory dataset was not materialized # Request should still work via dataset client request = ConversationRequestMessage( @@ -489,8 +487,7 @@ async def test_handle_conversation_turn_request_uses_dataset_client( 0 ].conversation_id - # Verify in-memory dataset is empty (freed) - assert dataset_manager.dataset == {} + # Verify in-memory dataset was not materialized # Request should still work via dataset client request = ConversationTurnRequestMessage( @@ -572,17 +569,12 @@ async def test_configure_client_compress_only_skips_client_creation( """In compress_only mode, _configure_dataset_client_and_free_memory skips client creation.""" service_config = ServiceConfig(service_run_type=ServiceRunType.KUBERNETES) manager = DatasetManager(service_config, base_user_config) - # Simulate some dataset entries - manager.dataset = {"conv1": MagicMock(), "conv2": MagicMock()} - manager._conversation_ids_cache = ["conv1", "conv2"] + manager._conversation_count = 2 await manager._configure_dataset_client_and_free_memory() # Should have set dataset_configured event assert manager.dataset_configured.is_set() - # Should have freed memory (cleared dataset) - assert 
manager.dataset == {} - assert manager._conversation_ids_cache == [] # Should NOT have created a dataset client assert manager._dataset_client is None diff --git a/tests/unit/dataset/test_dataset_manager_inputs_json.py b/tests/unit/dataset/test_dataset_manager_inputs_json.py index 6ceb2beb1..2b92f2015 100644 --- a/tests/unit/dataset/test_dataset_manager_inputs_json.py +++ b/tests/unit/dataset/test_dataset_manager_inputs_json.py @@ -112,7 +112,10 @@ async def test_generate_inputs_json_session_order_preservation( written_json = json.loads(capture_file_writes.written_content) session_ids = [session["session_id"] for session in written_json["data"]] - expected_order = list(populated_dataset_manager.dataset.keys()) + expected_order = [ + conv.conversation_id + for conv in populated_dataset_manager.dataset_metadata.conversations + ] assert session_ids == expected_order @pytest.mark.asyncio From cc5bd8dd8b2f1f2e4e748cedc1054e8525d714d4 Mon Sep 17 00:00:00 2001 From: Anthony Casagrande Date: Tue, 3 Mar 2026 20:59:25 -0800 Subject: [PATCH 7/7] fix: pass trust_remote_code and revision through parallel_decode parallel_decode and _init_worker were not forwarding tokenizer config args, causing failures with tokenizers like Kimi that require trust_remote_code=True. Signed-off-by: Anthony Casagrande --- .../dataset/generator/parallel_decode.py | 25 +++++- src/aiperf/dataset/loader/mooncake_trace.py | 9 ++- .../dataset/generator/test_parallel_decode.py | 77 ++++++++++++++++++- 3 files changed, 103 insertions(+), 8 deletions(-) diff --git a/src/aiperf/dataset/generator/parallel_decode.py b/src/aiperf/dataset/generator/parallel_decode.py index f8558261c..03f41f3f4 100644 --- a/src/aiperf/dataset/generator/parallel_decode.py +++ b/src/aiperf/dataset/generator/parallel_decode.py @@ -25,7 +25,11 @@ _worker_tokenizer_name: str | None = None -def _init_worker(tokenizer_name: str) -> None: +def _init_worker( + tokenizer_name: str, + trust_remote_code: bool = False, + revision: str = "main", +) -> None: """Initialize tokenizer in worker process. This function is called once per worker process when the ProcessPoolExecutor @@ -33,6 +37,8 @@ def _init_worker(tokenizer_name: str) -> None: Args: tokenizer_name: Name or path of the pretrained tokenizer to load. + trust_remote_code: Whether to trust remote code when loading. + revision: The specific model version to use. """ global _worker_tokenizer, _worker_tokenizer_name if _worker_tokenizer is None or _worker_tokenizer_name != tokenizer_name: @@ -44,7 +50,10 @@ def _init_worker(tokenizer_name: str) -> None: from aiperf.common.tokenizer import Tokenizer _worker_tokenizer = Tokenizer.from_pretrained( - tokenizer_name, resolve_alias=False + tokenizer_name, + trust_remote_code=trust_remote_code, + revision=revision, + resolve_alias=False, ) _worker_tokenizer_name = tokenizer_name @@ -71,6 +80,8 @@ def parallel_decode( tokenizer_name: str, max_workers: int | None = None, chunksize: int = 50, + trust_remote_code: bool = False, + revision: str = "main", ) -> list[str]: """Decode multiple token sequences in parallel using ProcessPoolExecutor. @@ -83,6 +94,8 @@ def parallel_decode( tokenizer_name: Name or path of the pretrained tokenizer to use in workers. max_workers: Number of worker processes. Defaults to min(cpu_count, 8). chunksize: Number of items per worker batch for map(). + trust_remote_code: Whether to trust remote code when loading. + revision: The specific model version to use. Returns: List of decoded strings in the same order as input. 
@@ -94,7 +107,11 @@ def parallel_decode( if len(token_sequences) < 10: from aiperf.common.tokenizer import Tokenizer - tokenizer = Tokenizer.from_pretrained(tokenizer_name) + tokenizer = Tokenizer.from_pretrained( + tokenizer_name, + trust_remote_code=trust_remote_code, + revision=revision, + ) return [ tokenizer.decode(tokens, skip_special_tokens=False) for tokens in token_sequences @@ -118,7 +135,7 @@ def parallel_decode( with ProcessPoolExecutor( max_workers=num_workers, initializer=_init_worker, - initargs=(tokenizer_name,), + initargs=(tokenizer_name, trust_remote_code, revision), ) as executor: results = list( executor.map(_decode_tokens, token_sequences, chunksize=chunksize) diff --git a/src/aiperf/dataset/loader/mooncake_trace.py b/src/aiperf/dataset/loader/mooncake_trace.py index f46703f4d..351798423 100644 --- a/src/aiperf/dataset/loader/mooncake_trace.py +++ b/src/aiperf/dataset/loader/mooncake_trace.py @@ -65,6 +65,8 @@ def __init__( or user_config.tokenizer.name or user_config.endpoint.model_names[0] ) + self._trust_remote_code = user_config.tokenizer.trust_remote_code + self._revision = user_config.tokenizer.revision self._block_size = user_config.input.prompt.input_tokens.block_size @classmethod @@ -228,7 +230,12 @@ def convert_to_conversations( f"({len(data)} conversations)" ) token_sequences = [p[2] for p in pending_decodes] - decoded_prompts = parallel_decode(token_sequences, self._tokenizer_name) + decoded_prompts = parallel_decode( + token_sequences, + self._tokenizer_name, + trust_remote_code=self._trust_remote_code, + revision=self._revision, + ) # Fill in placeholders and update cache for (session_id, idx, _, cache_key), prompt in zip( diff --git a/tests/unit/dataset/generator/test_parallel_decode.py b/tests/unit/dataset/generator/test_parallel_decode.py index a6b0cca79..6096ab1be 100644 --- a/tests/unit/dataset/generator/test_parallel_decode.py +++ b/tests/unit/dataset/generator/test_parallel_decode.py @@ -34,7 +34,9 @@ def test_parallel_decode_small_batch_sequential(self, mock_tokenizer_class): result = parallel_decode(token_sequences, "gpt2") # Should use sequential decoding (Tokenizer.from_pretrained called once) - mock_tokenizer_class.from_pretrained.assert_called_once_with("gpt2") + mock_tokenizer_class.from_pretrained.assert_called_once_with( + "gpt2", trust_remote_code=False, revision="main" + ) assert mock_tokenizer.decode.call_count == 2 assert result == ["decoded", "decoded"] @@ -204,7 +206,10 @@ def test_init_worker_loads_tokenizer(self, mock_tokenizer_class): pd_module._init_worker("gpt2") mock_tokenizer_class.from_pretrained.assert_called_once_with( - "gpt2", resolve_alias=False + "gpt2", + trust_remote_code=False, + revision="main", + resolve_alias=False, ) assert pd_module._worker_tokenizer is mock_tokenizer assert pd_module._worker_tokenizer_name == "gpt2" @@ -250,7 +255,10 @@ def test_init_worker_reloads_tokenizer_different_name(self, mock_tokenizer_class pd_module._init_worker("llama") mock_tokenizer_class.from_pretrained.assert_called_once_with( - "llama", resolve_alias=False + "llama", + trust_remote_code=False, + revision="main", + resolve_alias=False, ) assert pd_module._worker_tokenizer is new_tokenizer assert pd_module._worker_tokenizer_name == "llama" @@ -267,3 +275,66 @@ def test_decode_tokens_uses_worker_tokenizer(self): [1, 2, 3], skip_special_tokens=False ) assert result == "decoded text" + + @patch("aiperf.common.tokenizer.Tokenizer") + def test_init_worker_passes_trust_remote_code_and_revision( + self, mock_tokenizer_class + ): + 
"""Test that _init_worker forwards trust_remote_code and revision.""" + pd_module._worker_tokenizer = None + pd_module._worker_tokenizer_name = None + mock_tokenizer_class.from_pretrained.return_value = MagicMock() + + pd_module._init_worker("kimi-vl", trust_remote_code=True, revision="v1.2") + + mock_tokenizer_class.from_pretrained.assert_called_once_with( + "kimi-vl", + trust_remote_code=True, + revision="v1.2", + resolve_alias=False, + ) + + +class TestParallelDecodeTokenizerArgs: + """Test that parallel_decode passes tokenizer args through.""" + + @patch("aiperf.common.tokenizer.Tokenizer") + def test_small_batch_passes_trust_remote_code_and_revision( + self, mock_tokenizer_class + ): + """Test sequential path forwards trust_remote_code and revision.""" + mock_tokenizer = MagicMock() + mock_tokenizer.decode.return_value = "decoded" + mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer + + parallel_decode( + [[1, 2]], + "kimi-vl", + trust_remote_code=True, + revision="v1.2", + ) + + mock_tokenizer_class.from_pretrained.assert_called_once_with( + "kimi-vl", trust_remote_code=True, revision="v1.2" + ) + + @patch.object(pd_module, "ProcessPoolExecutor") + def test_large_batch_passes_trust_remote_code_and_revision_to_workers( + self, mock_executor_class + ): + """Test executor path forwards trust_remote_code and revision via initargs.""" + mock_executor = MagicMock() + mock_executor.__enter__ = MagicMock(return_value=mock_executor) + mock_executor.__exit__ = MagicMock(return_value=False) + mock_executor.map.return_value = ["decoded"] * 15 + mock_executor_class.return_value = mock_executor + + parallel_decode( + [[i] for i in range(15)], + "kimi-vl", + trust_remote_code=True, + revision="v1.2", + ) + + call_kwargs = mock_executor_class.call_args.kwargs + assert call_kwargs["initargs"] == ("kimi-vl", True, "v1.2")