diff --git a/src/aiperf/common/hash_id_random_generator.py b/src/aiperf/common/hash_id_random_generator.py new file mode 100644 index 000000000..d8a3d5380 --- /dev/null +++ b/src/aiperf/common/hash_id_random_generator.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Hash-ID-based random generator for parallel processing with reproducibility. + +Enables parallel processing of traces with hash_ids while maintaining +reproducibility. Each (trace_id, hash_id) pair produces a deterministic random +sequence regardless of worker count or processing order. + +Architecture: + Global Seed -> Base RNG -> (trace_id, hash_id) -> Deterministic tokens + +The trace_id (typically a content hash of the trace file) ensures that different +trace files with overlapping hash_id values produce different content, while the +same trace file always produces identical results. +""" + +import hashlib + +from aiperf.common.random_generator import RandomGenerator + +__all__ = ["HashIdRandomGenerator"] + + +class _DisabledNumpyRNG: + """Raises on any attribute access to prevent NumPy RNG usage.""" + + def __getattr__(self, name): + raise RuntimeError( + "HashIdRandomGenerator does not support NumPy RNG operations. " + "Use Python RNG methods (randrange, choice, etc.) instead." + ) + + +class HashIdRandomGenerator(RandomGenerator): + """RandomGenerator that re-seeds deterministically per (trace_id, hash_id). + + Designed for parallel processing where multiple workers need to generate + identical content for the same hash_id within a trace file. + + Thread Safety: + NOT thread-safe. Each worker process must have its own instance. + """ + + @classmethod + def from_base_rng(cls, base_rng: RandomGenerator) -> "HashIdRandomGenerator": + """Create from a base RandomGenerator (typically from rng.derive()).""" + base_seed = base_rng.seed or base_rng.randrange(0, 2**64) + return cls(base_seed, _internal=True) + + def __init__(self, base_seed: int, *, _internal: bool = False): + super().__init__(base_seed, _internal=_internal) + self._numpy_rng = _DisabledNumpyRNG() + self._trace_id: str = "" + + def set_trace_id(self, trace_id: str) -> None: + """Set trace identifier to scope hash_ids to a specific trace file. + + Args: + trace_id: Content hash or unique identifier for the trace file. + Different trace files must use different trace_ids. + """ + self._trace_id = trace_id + + def reseed_for_hash_id(self, hash_id: int) -> None: + """Re-seed RNG deterministically for a specific hash_id. + + After calling, all random operations use the derived seed until + the next reseed_for_hash_id call. + + Args: + hash_id: KV block hash ID from trace data. 
+ """ + seed_bytes = hashlib.sha256( + f"{self.seed}:{self._trace_id}:{hash_id}".encode() + ).digest() + self._python_rng.seed(int.from_bytes(seed_bytes[:8], "big")) diff --git a/src/aiperf/common/tokenizer.py b/src/aiperf/common/tokenizer.py index 51a07975a..fb553a560 100644 --- a/src/aiperf/common/tokenizer.py +++ b/src/aiperf/common/tokenizer.py @@ -4,6 +4,7 @@ """HuggingFace tokenizer wrapper with sensible defaults.""" import contextlib +import inspect import io import logging import os @@ -47,6 +48,14 @@ def __init__(self, name: str, suggestions: list[tuple[str, int]]) -> None: ) +def _supports_kwarg(obj: object, method_name: str, kwarg: str) -> bool: + """Check if a method on an object accepts a specific keyword argument.""" + method = getattr(obj, method_name, None) + if method is None: + return False + return kwarg in inspect.signature(method).parameters + + def _is_offline_mode() -> bool: """Check if HuggingFace offline mode is enabled via environment variables.""" return bool(os.environ.get("HF_HUB_OFFLINE", "")) or bool( @@ -147,6 +156,16 @@ def _require_init(self) -> None: if self._tokenizer is None: raise NotInitializedError("Tokenizer is not initialized.") + def _apply_kwarg_overrides(self) -> None: + """Override default args for tokenizers that use non-standard kwargs (e.g. Kimi).""" + if self._tokenizer is None: + return + if _supports_kwarg(self._tokenizer, "encode", "allow_special_tokens"): + self._call_args = {"allow_special_tokens": False} + self._encode_args = {"allow_special_tokens": False} + if not _supports_kwarg(self._tokenizer, "decode", "skip_special_tokens"): + self._decode_args = {} + @staticmethod def resolve_alias(name: str) -> AliasResolutionResult: """Resolve a tokenizer name alias to its canonical repository ID.""" @@ -208,6 +227,7 @@ def from_pretrained( revision=revision, ) tokenizer_cls._resolved_name = resolved_name + tokenizer_cls._apply_kwarg_overrides() except AmbiguousTokenizerNameError: raise except Exception as e: @@ -285,6 +305,7 @@ class _OfflineModelInfo: revision=revision, local_files_only=True, ) + tokenizer_cls._apply_kwarg_overrides() return tokenizer_cls finally: huggingface_hub.model_info = _original_model_info diff --git a/src/aiperf/dataset/composer/base.py b/src/aiperf/dataset/composer/base.py index 9b12bd8ef..3aa0b4cb6 100644 --- a/src/aiperf/dataset/composer/base.py +++ b/src/aiperf/dataset/composer/base.py @@ -3,6 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterator from aiperf.common import random_generator as rng from aiperf.common.config import UserConfig @@ -41,12 +42,11 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None, **kwargs): self._turn_sequence_cache: dict[int, tuple[int, int]] = {} @abstractmethod - def create_dataset(self) -> list[Conversation]: - """ - Create a set of conversation objects from the given configuration. + def create_dataset(self) -> Iterator[Conversation]: + """Create conversation objects from the given configuration. - Returns: - list[Conversation]: A list of conversation objects. + Yields Conversation objects one at a time so callers can stream + them directly to the backing store without materializing the full list. """ ... @@ -151,63 +151,38 @@ def prefix_prompt_enabled(self) -> bool: and self.config.input.prompt.prefix_prompt.length > 0 ) - def _finalize_conversations(self, conversations: list[Conversation]) -> None: - """Finalize conversations by adding conversation-level context prompts. 
- - Injects shared system prompts and per-conversation user context prompts. - Note: Turn-level finalization (_finalize_turn) is handled by each composer - according to its needs (eager in synthetic, lazy in custom). - - Args: - conversations: List of conversations to finalize - """ - self._inject_context_prompts(conversations) + def _finalize_conversation( + self, conversation: Conversation, session_index: int + ) -> None: + """Inject context prompts into a single conversation. - def _inject_context_prompts(self, conversations: list[Conversation]) -> None: - """Inject shared system and user context prompts into conversations. - - Sets the system_message and context_message fields on Conversation objects, - which endpoint formatters will prepend to the first turn when creating payloads. + Sets the system_message and user_context_message fields, which + endpoint formatters prepend to the first turn when creating payloads. Args: - conversations: List of conversations to inject prompts into + conversation: Conversation to finalize. + session_index: Position of this conversation in the dataset + (used for per-session user context prompt generation). """ if self.prompt_generator is None: return config = self.config.input.prompt.prefix_prompt - has_shared_system = config.shared_system_prompt_length is not None - has_user_context = config.user_context_prompt_length is not None - if not (has_shared_system or has_user_context): - return + if config.shared_system_prompt_length is not None: + prompt = self._get_shared_system_prompt() + if prompt: + conversation.system_message = prompt - self.debug( - lambda: f"Injecting context prompts into {len(conversations)} conversations" - ) - - # Get shared system prompt once (same for all sessions) - shared_system_prompt = None - if has_shared_system: - shared_system_prompt = self.prompt_generator.get_shared_system_prompt() - - # Iterate through conversations and set conversation-level fields - for session_index, conversation in enumerate(conversations): - # Set shared system prompt - if shared_system_prompt: - conversation.system_message = shared_system_prompt - self.trace( - lambda conv=conversation: f"Set system_message on conversation {conv.session_id}" - ) + if config.user_context_prompt_length is not None: + conversation.user_context_message = ( + self.prompt_generator.generate_user_context_prompt(session_index) + ) - # Set user context prompt (unique per session) - if has_user_context: - user_context = self.prompt_generator.generate_user_context_prompt( - session_index - ) - conversation.user_context_message = user_context - self.trace( - lambda idx=session_index, - conv=conversation: f"Set user_context_message for session {idx} " - f"(conversation {conv.session_id})" - ) + def _get_shared_system_prompt(self) -> str | None: + """Return the shared system prompt, computing and caching on first call.""" + if not hasattr(self, "_shared_system_prompt_cache"): + self._shared_system_prompt_cache: str | None = ( + self.prompt_generator.get_shared_system_prompt() + ) + return self._shared_system_prompt_cache diff --git a/src/aiperf/dataset/composer/custom.py b/src/aiperf/dataset/composer/custom.py index 8306a41ce..c452e4a31 100644 --- a/src/aiperf/dataset/composer/custom.py +++ b/src/aiperf/dataset/composer/custom.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator from pathlib import Path from typing import Any @@ -19,11 +20,12 @@ class CustomDatasetComposer(BaseDatasetComposer): def 
__init__(self, config: UserConfig, tokenizer: Tokenizer | None): super().__init__(config, tokenizer) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Create conversations from a file or directory. - Returns: - list[Conversation]: A list of conversation objects. + Yields conversations one at a time, finalizing each inline so + the caller can stream them directly to the backing store without + materializing the full list. """ # TODO: (future) for K8s, we need to transfer file data from SC (across node) check_file_exists(self.config.input.file) @@ -44,14 +46,13 @@ def create_dataset(self) -> list[Conversation]: dataset = self.loader.load_dataset() conversations = self.loader.convert_to_conversations(dataset) - # Finalize all turns with metadata (custom datasets need this) - for conversation in conversations: + for session_index, conversation in enumerate(conversations): + # Finalize all turns with metadata (custom datasets need this) for turn in conversation.turns: self._finalize_turn(turn) - - # Finalize conversation-level context prompts - self._finalize_conversations(conversations) - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _infer_dataset_type(self, file_path: str) -> CustomDatasetType: """Infer the custom dataset type from the input file. diff --git a/src/aiperf/dataset/composer/synthetic.py b/src/aiperf/dataset/composer/synthetic.py index b8991aea2..6ce5895f2 100644 --- a/src/aiperf/dataset/composer/synthetic.py +++ b/src/aiperf/dataset/composer/synthetic.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator + from aiperf.common import random_generator as rng from aiperf.common.config import UserConfig from aiperf.common.config.config_defaults import InputDefaults @@ -39,17 +41,13 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None): "setting the mean to a positive value." ) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Create a synthetic conversation dataset from the given configuration. - It generates a set of conversations with a varying number of turns, - where each turn contains synthetic text, image, and audio payloads. - - Returns: - list[Conversation]: A list of conversation objects. + Yields conversations one at a time, where each conversation contains + a varying number of turns with synthetic text, image, and audio payloads. """ - conversations = [] - for _ in range(self.config.input.conversation.num_dataset_entries): + for session_index in range(self.config.input.conversation.num_dataset_entries): conversation = Conversation(session_id=self.session_id_generator.next()) num_turns = self._turn_sampler_rng.sample_positive_normal_integer( @@ -61,11 +59,10 @@ def create_dataset(self) -> list[Conversation]: for turn_idx in range(num_turns): turn = self._create_turn(is_first=(turn_idx == 0)) conversation.turns.append(turn) - conversations.append(conversation) - # Finalize all conversations (turn metadata + context prompts) - self._finalize_conversations(conversations) - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _create_turn(self, is_first: bool) -> Turn: """Create a turn object that contains synthetic payloads to send. 
diff --git a/src/aiperf/dataset/composer/synthetic_rankings.py b/src/aiperf/dataset/composer/synthetic_rankings.py index ab90c9f07..c7acffb63 100644 --- a/src/aiperf/dataset/composer/synthetic_rankings.py +++ b/src/aiperf/dataset/composer/synthetic_rankings.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Iterator + from aiperf.common import random_generator as rng from aiperf.common.config import InputDefaults, UserConfig from aiperf.common.models import Conversation, Text, Turn @@ -33,26 +35,25 @@ def __init__(self, config: UserConfig, tokenizer: Tokenizer | None): f"Using default sampling strategy for synthetic rankings dataset: {InputDefaults.DATASET_SAMPLING_STRATEGY}" ) - def create_dataset(self) -> list[Conversation]: + def create_dataset(self) -> Iterator[Conversation]: """Generate synthetic dataset for the rankings endpoint. Each conversation contains one turn with one query and multiple passages. """ - conversations: list[Conversation] = [] num_entries = self.config.input.conversation.num_dataset_entries num_passages_mean = self.config.input.rankings.passages.mean num_passages_std = self.config.input.rankings.passages.stddev - for _ in range(num_entries): + for session_index in range(num_entries): num_passages = self._passages_rng.sample_positive_normal_integer( num_passages_mean, num_passages_std ) conversation = Conversation(session_id=self.session_id_generator.next()) turn = self._create_turn(num_passages=num_passages) conversation.turns.append(turn) - conversations.append(conversation) - - return conversations + # Finalize conversation-level context prompts + self._finalize_conversation(conversation, session_index) + yield conversation def _create_turn(self, num_passages: int) -> Turn: """Create a single ranking turn with one synthetic query and multiple synthetic passages. diff --git a/src/aiperf/dataset/dataset_manager.py b/src/aiperf/dataset/dataset_manager.py index e16339d0b..e9ac8ae26 100644 --- a/src/aiperf/dataset/dataset_manager.py +++ b/src/aiperf/dataset/dataset_manager.py @@ -5,6 +5,7 @@ import asyncio import gc import time +from collections.abc import Iterable from typing import TYPE_CHECKING import orjson @@ -31,6 +32,7 @@ from aiperf.common.mixins import ReplyClientMixin from aiperf.common.models import ( Conversation, + ConversationMetadata, DatasetClientMetadata, DatasetMetadata, InputsFile, @@ -86,11 +88,9 @@ def __init__( ) self.user_config = user_config self.tokenizer: Tokenizer | None = None - self.dataset: dict[ - str, Conversation - ] = {} # conversation ID -> Conversation mapping self.dataset_metadata: DatasetMetadata | None = None self._conversation_ids_cache: list[str] = [] + self._conversation_count: int = 0 self.dataset_configured = asyncio.Event() # In Kubernetes mode, use compress_only to stream directly to compressed files. 
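Note (illustration, not part of the diff): with create_dataset() now returning an iterator in all composers above, callers are expected to consume conversations one at a time rather than materialize a list. A hedged sketch of that consumption pattern; `store` and its `add_conversation` method are synchronous stand-ins for the backing store that dataset_manager.py below awaits:

# Sketch only: stream conversations so at most one is held in memory at a time.
def stream_to_store(composer, store) -> int:
    count = 0
    for conversation in composer.create_dataset():  # generator, yields lazily
        store.add_conversation(conversation.session_id, conversation)
        count += 1
    return count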
@@ -132,7 +132,6 @@ async def _profile_configure_command( self.info(lambda: f"Configuring dataset for {self.service_id}") begin = time.perf_counter() await self._configure_dataset() - await self._generate_inputs_json_file() await self._configure_dataset_client_and_free_memory() duration = time.perf_counter() - begin @@ -140,8 +139,6 @@ async def _profile_configure_command( async def _configure_dataset_client_and_free_memory(self) -> None: """Configure the dataset client for serving fallback requests, then free memory.""" - conversation_count = len(self.dataset) - if not self._compress_only: client_metadata = self._backing_store.get_client_metadata() ClientStoreClass = plugins.get_class( @@ -149,25 +146,22 @@ async def _configure_dataset_client_and_free_memory(self) -> None: ) self._dataset_client = ClientStoreClass(client_metadata=client_metadata) await self._dataset_client.initialize() + await self._generate_inputs_json_file() self.dataset_configured.set() - # Reassign to new empty containers (not .clear()) to release object references, - # then run gc.collect() twice to ensure circular references are cleaned up. - self.dataset = {} - self._conversation_ids_cache = [] + # Run gc.collect() twice to ensure circular references are cleaned up. gc.collect() gc.collect() if self._compress_only: self.info( - f"Kubernetes mode: skipped local client, freed {conversation_count} " - "conversations from memory (workers handle all requests)" + f"Kubernetes mode: skipped local client, compressed {self._conversation_count} " + "conversations into backing store)" ) else: self.info( - f"Dataset client initialized and freed {conversation_count} " - "conversations from memory" + f"Dataset client initialized, {self._conversation_count} conversations in backing store" ) async def _configure_tokenizer(self) -> None: @@ -185,7 +179,7 @@ async def _configure_tokenizer(self) -> None: resolve_alias=tokenizer_config.should_resolve_alias, ) - def _generate_input_payloads( + async def _generate_input_payloads( self, model_endpoint: ModelEndpointInfo, ) -> InputsFile: @@ -201,7 +195,10 @@ def _generate_input_payloads( f"class: {endpoint.__class__.__name__}", ) session_payloads_map: dict[str, list] = {} - for conversation in self.dataset.values(): + for conv_metadata in self.dataset_metadata.conversations: + conversation = await self._dataset_client.get_conversation( + conv_metadata.conversation_id + ) session_id = conversation.session_id if session_id not in session_payloads_map: session_payloads_map[session_id] = [] @@ -245,7 +242,7 @@ async def _generate_inputs_json_file(self) -> None: file_path.parent.mkdir(parents=True, exist_ok=True) model_endpoint = ModelEndpointInfo.from_user_config(self.user_config) - inputs = self._generate_input_payloads(model_endpoint) + inputs = await self._generate_input_payloads(model_endpoint) temp_file_path.write_bytes( orjson.dumps( @@ -293,7 +290,7 @@ async def _load_public_dataset(self) -> list[Conversation]: ) return await loader.convert_to_conversations(dataset) - def _load_custom_dataset(self) -> list[Conversation]: + def _load_custom_dataset(self) -> Iterable[Conversation]: ComposerClass = plugins.get_class( PluginType.DATASET_COMPOSER, ComposerType.CUSTOM ) @@ -303,7 +300,7 @@ def _load_custom_dataset(self) -> list[Conversation]: def _is_rankings_endpoint(self, endpoint_type: str) -> bool: return "rankings" in endpoint_type.lower() - def _load_synthetic_dataset(self) -> list[Conversation]: + def _load_synthetic_dataset(self) -> Iterable[Conversation]: endpoint_type = 
self.user_config.endpoint.type if self._is_rankings_endpoint(endpoint_type): @@ -334,16 +331,17 @@ async def _configure_dataset(self) -> None: else: conversations = self._load_synthetic_dataset() - self.dataset = {conv.session_id: conv for conv in conversations} - self._conversation_ids_cache = [ - conversation.session_id for conversation in conversations - ] - - # Initialize backing store and stream conversations to mmap files - # Workers read directly from these files + # Stream conversations to backing store and collect metadata on the fly. + # Each conversation is written to mmap and then can be GC'd — we never + # hold the full dataset in memory. await self._backing_store.initialize() - conversations_dict = {conv.session_id: conv for conv in conversations} - await self._backing_store.add_conversations(conversations_dict) + metadata_list: list[ConversationMetadata] = [] + for conversation in conversations: + await self._backing_store.add_conversation( + conversation.session_id, conversation + ) + metadata_list.append(conversation.metadata()) + self._conversation_count = len(metadata_list) await self._backing_store.finalize() # In Kubernetes mode (compress_only=True), files are already compressed # during finalize(). In local mode, uncompressed files are used directly. @@ -362,7 +360,7 @@ async def _configure_dataset(self) -> None: ) self.dataset_metadata = DatasetMetadata( - conversations=[conversation.metadata() for conversation in conversations], + conversations=metadata_list, sampling_strategy=self.user_config.input.dataset_sampling_strategy, ) self.info( diff --git a/src/aiperf/dataset/generator/parallel_decode.py b/src/aiperf/dataset/generator/parallel_decode.py deleted file mode 100644 index f8558261c..000000000 --- a/src/aiperf/dataset/generator/parallel_decode.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Parallel decode utilities for batch tokenizer operations. - -This module provides functions to decode multiple token sequences in parallel -using ProcessPoolExecutor, bypassing Python's GIL for CPU-bound tokenizer -operations. - -The daemon flag on the current process is temporarily cleared because Python's -multiprocessing refuses to spawn children from daemon processes, and AIPerf -services run as daemons. -""" - -import multiprocessing as mp -import os -from concurrent.futures import ProcessPoolExecutor -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from aiperf.common.tokenizer import Tokenizer - -# Module-level tokenizer for worker processes (initialized once per worker) -_worker_tokenizer: "Tokenizer | None" = None -_worker_tokenizer_name: str | None = None - - -def _init_worker(tokenizer_name: str) -> None: - """Initialize tokenizer in worker process. - - This function is called once per worker process when the ProcessPoolExecutor - starts. It loads the tokenizer so subsequent decode calls don't need to reload it. - - Args: - tokenizer_name: Name or path of the pretrained tokenizer to load. - """ - global _worker_tokenizer, _worker_tokenizer_name - if _worker_tokenizer is None or _worker_tokenizer_name != tokenizer_name: - # The main process already downloaded and cached the tokenizer, so force - # offline mode to skip network requests and alias resolution. 
- os.environ["HF_HUB_OFFLINE"] = "1" - os.environ["TRANSFORMERS_OFFLINE"] = "1" - - from aiperf.common.tokenizer import Tokenizer - - _worker_tokenizer = Tokenizer.from_pretrained( - tokenizer_name, resolve_alias=False - ) - _worker_tokenizer_name = tokenizer_name - - -def _decode_tokens(token_ids: list[int]) -> str: - """Decode tokens using worker's tokenizer. - - Args: - token_ids: List of token IDs to decode. - - Returns: - Decoded string. - - Raises: - RuntimeError: If worker tokenizer is not initialized. - """ - if _worker_tokenizer is None: - raise RuntimeError("Worker tokenizer not initialized") - return _worker_tokenizer.decode(token_ids, skip_special_tokens=False) - - -def parallel_decode( - token_sequences: list[list[int]], - tokenizer_name: str, - max_workers: int | None = None, - chunksize: int = 50, -) -> list[str]: - """Decode multiple token sequences in parallel using ProcessPoolExecutor. - - This function is optimized for batch decoding of many token sequences. - For small batches (< 10 sequences), it falls back to sequential decoding - to avoid process spawn overhead. - - Args: - token_sequences: List of token ID lists to decode. - tokenizer_name: Name or path of the pretrained tokenizer to use in workers. - max_workers: Number of worker processes. Defaults to min(cpu_count, 8). - chunksize: Number of items per worker batch for map(). - - Returns: - List of decoded strings in the same order as input. - """ - if not token_sequences: - return [] - - # For small batches, sequential is faster (avoid process overhead) - if len(token_sequences) < 10: - from aiperf.common.tokenizer import Tokenizer - - tokenizer = Tokenizer.from_pretrained(tokenizer_name) - return [ - tokenizer.decode(tokens, skip_special_tokens=False) - for tokens in token_sequences - ] - - num_workers = max_workers or min(mp.cpu_count() or 4, 8) - - # Temporarily clear the daemon flag so ProcessPoolExecutor can spawn workers. - # Python's multiprocessing refuses to spawn children from daemon processes, - # and AIPerf services run as daemons. - # - # Alternatives considered: - # - billiard: bypasses the daemon restriction natively, but crashes with - # BrokenProcessPool on macOS due to terminal FD inheritance issues. - # - loky: robust reusable executor, but still requires the same daemon flag - # hack, so no advantage over stdlib. 
- was_daemon = mp.current_process().daemon - try: - if was_daemon: - _set_daemon(False) - with ProcessPoolExecutor( - max_workers=num_workers, - initializer=_init_worker, - initargs=(tokenizer_name,), - ) as executor: - results = list( - executor.map(_decode_tokens, token_sequences, chunksize=chunksize) - ) - finally: - if was_daemon: - _set_daemon(True) - - return results - - -def _set_daemon(daemon: bool) -> None: - """Set the daemon flag on the current process.""" - try: - mp.current_process().daemon = daemon - except AssertionError: - # Fallback to using the internal _config dictionary if assertions are enabled - mp.current_process()._config["daemon"] = daemon diff --git a/src/aiperf/dataset/generator/prompt.py b/src/aiperf/dataset/generator/prompt.py index 9c76d4a32..f3ada6133 100644 --- a/src/aiperf/dataset/generator/prompt.py +++ b/src/aiperf/dataset/generator/prompt.py @@ -13,12 +13,49 @@ InvalidStateError, NotInitializedError, ) +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator from aiperf.common.tokenizer import Tokenizer from aiperf.dataset.generator.base import BaseGenerator DEFAULT_CORPUS_FILE = "assets/shakespeare.txt" +def sample_tokens_from_corpus( + corpus: list[int], + num_tokens: int, + rng_to_use: rng.RandomGenerator, + sep_token: int | None = None, +) -> list[int]: + """Sample tokens from a corpus with optional separator token. + + Args: + corpus: Token corpus as a list of token IDs. + num_tokens: Number of tokens to sample. + rng_to_use: RandomGenerator for sampling start position. + sep_token: Optional separator token to prepend (BOS/EOS). + + Returns: + List of sampled token IDs. + """ + corpus_len = len(corpus) + tokens: list[int] = [] + + if sep_token is not None: + tokens.append(sep_token) + num_tokens -= 1 + + start = rng_to_use.randrange(corpus_len) + end = start + num_tokens + + if end <= corpus_len: + tokens.extend(corpus[start:end]) + else: + tokens.extend(corpus[start:]) + tokens.extend(corpus[: end - corpus_len]) + + return tokens + + class PromptGenerator(BaseGenerator): """A class for generating synthetic prompts from a text corpus. @@ -47,15 +84,16 @@ def __init__(self, config: PromptConfig, tokenizer: Tokenizer, **kwargs): self._corpus_rng = rng.derive("dataset.prompt.corpus") self._prefix_rng = rng.derive("dataset.prompt.prefix") + # Hash-ID-based RNG for deterministic per-hash_id generation. + # Re-seeds itself for each hash_id, enabling identical random + # sequences per hash block regardless of processing order or workers. + self._hash_id_corpus_rng = HashIdRandomGenerator.from_base_rng(self._corpus_rng) + super().__init__(config=config, tokenizer=tokenizer, **kwargs) # Cached prompts: block ID -> list of tokens self._cache: dict[int, list[int]] = {} - # Decoded string cache: (hash_ids tuple, num_tokens, block_size) -> decoded string - # This avoids redundant tokenizer.decode() calls for repeated hash_id combinations - self._decoded_cache: dict[tuple[tuple[int, ...], int, int], str] = {} - # TODO: move this under initialize() method # Initialize corpus if not already done if self._tokenized_corpus is None: @@ -154,6 +192,7 @@ def generate( mean: int | None = None, stddev: int | None = None, hash_ids: list[int] | None = None, + block_size: int | None = None, ) -> str: """Generate a synthetic prompt with the configuration parameters. Serves as a wrapper around other internal methods to provide a unified interface. @@ -162,6 +201,8 @@ def generate( mean: The mean of the normal distribution. 
stddev: The standard deviation of the normal distribution. hash_ids: A list of hash indices used for token reuse. + block_size: Override block size for hash-ID generation. + Defaults to config value or :data:`InputTokensDefaults.BLOCK_SIZE`. Returns: A synthetic prompt as a string. @@ -169,10 +210,12 @@ def generate( if hash_ids: if mean is None: raise ValueError("mean must be provided when hash_ids is set.") - block_size = ( - self.config.input_tokens.block_size or InputTokensDefaults.BLOCK_SIZE + effective_block_size = ( + block_size + or self.config.input_tokens.block_size + or InputTokensDefaults.BLOCK_SIZE ) - return self._generate_cached_prompt(mean, hash_ids, block_size) + return self._generate_cached_prompt(mean, hash_ids, effective_block_size) num_tokens = self.calculate_num_tokens(mean, stddev) return self.generate_prompt(num_tokens) @@ -208,48 +251,12 @@ def _generate_cached_prompt( hash_ids: list[int], block_size: int, ) -> str: - """ - Generate a prompt containing exactly `num_tokens` by reusing previously generated prompts - stored in `_cache`. Each hash index in `hash_ids` corresponds to a block of - `block_size` tokens. If a hash index is found in `_cache`, its stored prompt is reused. - Otherwise, a new prompt is generated using `generate_prompt()` and stored in `_cache`. + """Generate a prompt by reusing previously generated token blocks. - Args: - num_tokens: The number of tokens required in the prompt. - hash_ids: A list of hash IDs to use for token reuse. - block_size: The number of tokens allocated per hash block. - - Returns: - str: A synthetic prompt as a string. - - Raises: - ConfigurationError: If the input parameters are not compatible. - """ - # Check decoded string cache first to avoid redundant decode calls - cache_key = (tuple(hash_ids), num_tokens, block_size) - if cache_key in self._decoded_cache: - return self._decoded_cache[cache_key] - - # Build token sequence using _build_token_sequence (shared logic) - final_prompt = self._build_token_sequence(num_tokens, hash_ids, block_size) - - # Decode and cache the result - decoded = self.tokenizer.decode(final_prompt, skip_special_tokens=False) - self._decoded_cache[cache_key] = decoded - return decoded - - def _build_token_sequence( - self, - num_tokens: int, - hash_ids: list[int], - block_size: int, - ) -> list[int]: - """ - Build a token sequence without decoding. Used for batch parallel decode. - - Each hash index in `hash_ids` corresponds to a block of `block_size` tokens. - If a hash index is found in `_cache`, its stored tokens are reused. - Otherwise, new tokens are sampled and stored in `_cache`. + Each hash_id in `hash_ids` corresponds to a block of `block_size` tokens. + If a hash_id is found in `_cache`, its stored tokens are reused. Otherwise, + tokens are generated deterministically using HashIdRandomGenerator re-seeding + and stored in `_cache`. Args: num_tokens: The number of tokens required in the prompt. @@ -257,7 +264,7 @@ def _build_token_sequence( block_size: The number of tokens allocated per hash block. Returns: - list[int]: A list of token IDs. + str: A synthetic prompt as a string. Raises: ConfigurationError: If the input parameters are not compatible. @@ -280,21 +287,17 @@ def _build_token_sequence( current_block_size = final_block_size if hash_id not in self._cache: - # To ensure that the prompt doesn't merge chunks, we insert a BOS or EOS token - # at the beginning. Length is maintained and the prompt generates the expected - # number of tokens. 
If no BOS or EOS token is available, we don't insert one. - prompt_tokens: list[int] = [] - if self.tokenizer.block_separation_token_id is not None: - prompt_tokens += [self.tokenizer.block_separation_token_id] - prompt_tokens += self._sample_tokens(current_block_size - 1) - else: - prompt_tokens += self._sample_tokens(current_block_size) - - self._cache[hash_id] = prompt_tokens # store to cache + self._hash_id_corpus_rng.reseed_for_hash_id(hash_id) + self._cache[hash_id] = sample_tokens_from_corpus( + self._tokenized_corpus, + current_block_size, + self._hash_id_corpus_rng, + self.tokenizer.block_separation_token_id, + ) final_prompt.extend(self._cache[hash_id]) - return final_prompt + return self.tokenizer.decode(final_prompt, skip_special_tokens=False) def _sample_tokens(self, num_tokens: int) -> list[int]: """Generate a list of token IDs containing exactly `num_tokens` number of tokens diff --git a/src/aiperf/dataset/loader/base_trace_loader.py b/src/aiperf/dataset/loader/base_trace_loader.py index 5acb14509..b97045097 100644 --- a/src/aiperf/dataset/loader/base_trace_loader.py +++ b/src/aiperf/dataset/loader/base_trace_loader.py @@ -1,29 +1,50 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import hashlib from abc import abstractmethod +from collections.abc import Iterator from typing import Any, Generic, TypeVar from aiperf.common.config.config_defaults import InputTokensDefaults from aiperf.common.config.user_config import UserConfig from aiperf.common.models import Conversation, Text, Turn -from aiperf.dataset.generator.parallel_decode import parallel_decode from aiperf.dataset.generator.prompt import PromptGenerator from aiperf.dataset.loader.base_loader import BaseFileLoader +from aiperf.dataset.loader.parallel_convert import parallel_convert from aiperf.dataset.synthesis.models import SynthesisParams from aiperf.dataset.synthesis.synthesizer import Synthesizer from aiperf.plugin.enums import DatasetSamplingStrategy TraceT = TypeVar("TraceT") +_MIN_TRACES_FOR_PARALLEL = 10 + + +def _compute_file_hash(filepath: str) -> str: + """Compute SHA256 hash of file content (first 16 hex chars). + + Falls back to hashing the filepath string if the file cannot be read. + """ + try: + hasher = hashlib.sha256() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest()[:16] + except (OSError, TypeError): + return hashlib.sha256(filepath.encode()).hexdigest()[:16] + class BaseTraceDatasetLoader(BaseFileLoader, Generic[TraceT]): """Base class for trace dataset loaders with hash_ids-based prompt generation. Provides common infrastructure for loading trace-format datasets (Mooncake, Bailian, etc.) including shared initialization, timestamp - filtering, 3-phase prompt generation with parallel decode, and - synthesis integration. + filtering, parallel prompt generation with deterministic per-hash_id + re-seeding, and synthesis integration. Subclasses must implement: - `can_load`: data format detection @@ -50,14 +71,9 @@ def __init__( self._end_offset = user_config.input.fixed_schedule_end_offset self._max_isl = user_config.input.synthesis.max_isl self._max_osl = user_config.input.synthesis.max_osl - - # Use the resolved tokenizer name so worker processes can load from cache - # without needing alias resolution or network access. 
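Note (illustration, not part of the diff): two details of the prompt changes above benefit from a worked example. Each hash_id maps to a full block of block_size tokens except the last, which receives the remainder, and sample_tokens_from_corpus wraps around the corpus when the sampled window runs past the end. The numbers below are illustrative only:

# With input_length=100, block_size=32 and hash_ids=[7, 8, 9, 10], the first three
# blocks hold 32 tokens each and the final block holds 100 - 3 * 32 = 4 tokens.
# Wraparound sketch, mirroring sample_tokens_from_corpus:
corpus = list(range(10))       # stand-in for the tokenized Shakespeare corpus
num_tokens, start = 4, 8       # window starts near the end of the corpus
end = start + num_tokens
if end <= len(corpus):
    tokens = corpus[start:end]
else:
    tokens = corpus[start:] + corpus[: end - len(corpus)]
assert tokens == [8, 9, 0, 1]  # wraps to the beginning of the corpus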
- self._tokenizer_name = ( - prompt_generator.tokenizer.resolved_name - or user_config.tokenizer.name - or user_config.endpoint.model_names[0] - ) + self._trace_id: str = "" + self._trust_remote_code = user_config.tokenizer.trust_remote_code + self._revision = user_config.tokenizer.revision # Precedence: user CLI --isl-block-size > plugin metadata default > hardcoded fallback user_block_size = user_config.input.prompt.input_tokens.block_size @@ -171,6 +187,10 @@ def load_dataset(self) -> dict[str, list[TraceT]]: self._skipped_traces = 0 self._skipped_max_isl = 0 self._capped_max_osl = 0 + + self._trace_id = _compute_file_hash(self.filename) + self.prompt_generator._hash_id_corpus_rng.set_trace_id(self._trace_id) + self.debug(lambda: f"Trace ID: {self._trace_id} for {self.filename}") items: list[TraceT] = [] with open(self.filename) as f: @@ -202,7 +222,7 @@ def load_dataset(self) -> dict[str, list[TraceT]]: return data # ------------------------------------------------------------------ - # convert_to_conversations — 3-phase prompt generation + # convert_to_conversations # ------------------------------------------------------------------ def _get_text_input(self, trace: TraceT) -> str | None: @@ -227,77 +247,72 @@ def _build_turn(self, trace: TraceT, prompt: str) -> Turn: ) def convert_to_conversations( - self, data: dict[str, list[TraceT]] - ) -> list[Conversation]: - """Convert trace sessions to :class:`Conversation` objects. + self, + data: dict[str, list[TraceT]], + num_workers: int | None = None, + batch_size: int = 100, + ) -> Iterator[Conversation]: + """Convert trace sessions to conversations using parallel workers. - Uses a three-phase approach for optimal performance: + Uses multiprocessing Pool with shared memory for the token corpus. + Each worker gets its own HashIdRandomGenerator to produce deterministic + token sequences per hash_id regardless of worker count or order. - 1. Build token sequences, checking the string cache first. - 2. Batch parallel decode for all cache misses. - 3. Assemble final :class:`Conversation` objects. + Falls back to single-threaded conversion for small datasets. + + Yields: + Conversation objects in session order. 
""" - # Phase 1: Build token sequences and identify cache misses - pending_decodes: list[tuple[str, int, list[int], tuple]] = [] - conversations_data: dict[str, list[tuple[TraceT, str | None]]] = {} + sessions = list(data.items()) + if not sessions: + return + + total_traces = sum(len(traces) for _, traces in sessions) + if total_traces < _MIN_TRACES_FOR_PARALLEL: + yield from self._convert_single_threaded(sessions) + return + + pg = self.prompt_generator + serialized = [ + (sid, [t.model_dump() for t in traces]) # type: ignore[union-attr] + for sid, traces in sessions + ] + + yield from parallel_convert( + sessions=serialized, + tokenizer_name=pg.tokenizer.resolved_name, + corpus=pg._tokenized_corpus, + base_seed=pg._hash_id_corpus_rng.seed, + block_size=self._block_size, + sep_token=pg.tokenizer.block_separation_token_id, + trace_id=self._trace_id, + trust_remote_code=self._trust_remote_code, + revision=self._revision, + num_workers=num_workers, + batch_size=batch_size, + ) - for session_id, traces in data.items(): - conversations_data[session_id] = [] - for idx, trace in enumerate(traces): + def _convert_single_threaded( + self, sessions: list[tuple[str, list[TraceT]]] + ) -> Iterator[Conversation]: + """Fallback single-threaded conversion for small datasets.""" + for session_id, traces in sessions: + conversation = Conversation(session_id=session_id) + for trace in traces: text_input = self._get_text_input(trace) if text_input is not None: - conversations_data[session_id].append((trace, text_input)) - continue - - hash_ids: list[int] = getattr(trace, "hash_ids", None) or [] - input_length: int = getattr(trace, "input_length", 0) - - if hash_ids: - cache_key = ( - tuple(hash_ids), - input_length, - self._block_size, - ) - if cache_key in self.prompt_generator._decoded_cache: - prompt = self.prompt_generator._decoded_cache[cache_key] - conversations_data[session_id].append((trace, prompt)) - else: - tokens = self.prompt_generator._build_token_sequence( - input_length, hash_ids, self._block_size - ) - pending_decodes.append((session_id, idx, tokens, cache_key)) - conversations_data[session_id].append((trace, None)) + prompt = text_input else: + hash_ids: list[int] = getattr(trace, "hash_ids", None) or [] + input_length: int = getattr(trace, "input_length", 0) prompt = self.prompt_generator.generate( - mean=input_length, stddev=0, hash_ids=[] + mean=input_length, + stddev=0, + hash_ids=hash_ids, + block_size=self._block_size, ) - conversations_data[session_id].append((trace, prompt)) - - # Phase 2: Batch parallel decode for all cache misses - if pending_decodes: - self.debug( - lambda: f"Parallel decoding {len(pending_decodes)} prompts " - f"({len(data)} conversations)" - ) - token_sequences = [p[2] for p in pending_decodes] - decoded_prompts = parallel_decode(token_sequences, self._tokenizer_name) - - for (session_id, idx, _, cache_key), prompt in zip( - pending_decodes, decoded_prompts, strict=True - ): - self.prompt_generator._decoded_cache[cache_key] = prompt - trace, _ = conversations_data[session_id][idx] - conversations_data[session_id][idx] = (trace, prompt) - - # Phase 3: Build final conversation objects - conversations: list[Conversation] = [] - for session_id, trace_prompt_pairs in conversations_data.items(): - conversation = Conversation(session_id=session_id) - for trace, prompt in trace_prompt_pairs: conversation.turns.append(self._build_turn(trace, prompt)) - conversations.append(conversation) - - return conversations + yield conversation # 
------------------------------------------------------------------ # Synthesis — shared orchestration with subclass hooks diff --git a/src/aiperf/dataset/loader/parallel_convert.py b/src/aiperf/dataset/loader/parallel_convert.py new file mode 100644 index 000000000..032bfb534 --- /dev/null +++ b/src/aiperf/dataset/loader/parallel_convert.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Parallel conversion of trace sessions to conversations. + +Uses multiprocessing Pool with shared memory for the token corpus. Each worker +gets its own HashIdRandomGenerator to produce deterministic token sequences per +hash_id regardless of worker count or processing order. + +The daemon flag on the current process is temporarily cleared because Python's +multiprocessing refuses to spawn children from daemon processes, and AIPerf +services run as daemons. +""" + +from __future__ import annotations + +import multiprocessing as mp +import os +import sys +from collections.abc import Callable, Iterator +from dataclasses import dataclass, field +from multiprocessing import Pool, shared_memory + +import numpy as np + +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator +from aiperf.common.models import Conversation, Text, Turn +from aiperf.common.tokenizer import Tokenizer + + +@dataclass(slots=True) +class _WorkerInitArgs: + """Arguments passed to each worker process via Pool initargs.""" + + shm_name: str + corpus_len: int + tokenizer_name: str + base_seed: int + block_size: int + sep_token: int | None + trace_id: str + trust_remote_code: bool = False + revision: str = "main" + + +@dataclass(slots=True) +class _WorkerState: + """Per-worker process state, initialized once via _init_worker.""" + + tokenizer: Tokenizer + corpus: np.ndarray + shm: shared_memory.SharedMemory # prevent GC from unmapping corpus buffer + hash_rng: HashIdRandomGenerator + block_size: int + sep_token: int | None + sample_tokens: Callable[..., list[int]] + block_cache: dict[int, list[int]] = field(default_factory=dict) + + +# Set once per worker process by _init_worker; read by _process_batch. +_worker_state: _WorkerState | None = None + + +def _init_worker(args: _WorkerInitArgs) -> None: + """Initialize worker process with shared corpus and tokenizer. + + Called once per worker when the Pool is created. Attaches to the + shared-memory corpus, creates a per-worker HashIdRandomGenerator + (seeded by trace_id for file-level determinism), and loads the + tokenizer from local cache (offline mode). + """ + global _worker_state + + from aiperf.dataset.generator.prompt import sample_tokens_from_corpus + + # The main process already downloaded and cached the tokenizer, so force + # offline mode to skip network requests and alias resolution. + os.environ["HF_HUB_OFFLINE"] = "1" + os.environ["TRANSFORMERS_OFFLINE"] = "1" + + shm = shared_memory.SharedMemory(name=args.shm_name) + + # Each worker gets its own RNG so reseed_for_hash_id calls are independent. 
+ hash_rng = HashIdRandomGenerator(args.base_seed, _internal=True) + hash_rng.set_trace_id(args.trace_id) + + _worker_state = _WorkerState( + tokenizer=Tokenizer.from_pretrained( + args.tokenizer_name, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + resolve_alias=False, + ), + corpus=np.ndarray((args.corpus_len,), dtype=np.int32, buffer=shm.buf), + shm=shm, + hash_rng=hash_rng, + block_size=args.block_size, + sep_token=args.sep_token, + sample_tokens=sample_tokens_from_corpus, + ) + + +def _process_batch( + batch: list[tuple[str, list[dict]]], +) -> list[tuple[str, list[tuple]]]: + """Process a batch of sessions, converting hash_ids to prompts. + + Each trace dict must have 'input_length', 'output_length', 'timestamp', + 'delay', and optionally 'hash_ids' and 'text_input'. + """ + assert _worker_state is not None + hash_rng = _worker_state.hash_rng + corpus = _worker_state.corpus + block_size = _worker_state.block_size + sep_token = _worker_state.sep_token + decode = _worker_state.tokenizer.decode + sample_tokens = _worker_state.sample_tokens + block_cache = _worker_state.block_cache + + def get_block_tokens(hash_id: int, size: int) -> list[int]: + if hash_id in block_cache: + return block_cache[hash_id] + hash_rng.reseed_for_hash_id(hash_id) + tokens = sample_tokens(corpus, size, hash_rng, sep_token) + block_cache[hash_id] = tokens + return tokens + + results = [] + for session_id, traces in batch: + turns = [] + for trace in traces: + if trace.get("text_input"): + # Literal prompt provided by the trace (no generation needed). + prompt = trace["text_input"] + elif trace.get("hash_ids"): + # Generate prompt from hash_id blocks. All blocks are full-sized + # except the last, which gets the remainder tokens. + hash_ids = trace["hash_ids"] + input_length = trace["input_length"] + final_block_size = input_length - (len(hash_ids) - 1) * block_size + + tokens: list[int] = [] + for i, hid in enumerate(hash_ids): + size = final_block_size if i == len(hash_ids) - 1 else block_size + tokens.extend(get_block_tokens(hid, size)) + prompt = decode(tokens, skip_special_tokens=False) + else: + prompt = "" + + turns.append( + ( + trace.get("timestamp"), + trace.get("delay"), + prompt, + trace.get("output_length"), + ) + ) + results.append((session_id, turns)) + + return results + + +def _has_broken_stdio() -> bool: + """Check if any stdio stream has an invalid file descriptor.""" + for stream in (sys.stdin, sys.stdout, sys.stderr): + try: + fd = stream.fileno() + if fd < 0: + return True + os.fstat(fd) + except (OSError, ValueError, AttributeError): + return True + return False + + +def _ensure_valid_stdio_fds() -> None: + """Redirect broken stdio to /dev/null before spawning Pool workers. + + Under the Textual terminal UI, child service processes inherit + Textual-managed sys.stdin/stdout/stderr objects whose fileno() may + return -1. When Pool workers fork and call util._close_stdin(), the + invalid FD propagates to _posixsubprocess.fork_exec causing + "bad value(s) in fds_to_keep". Only redirects when a problem is + detected so non-dashboard modes keep normal stdio. + """ + if not _has_broken_stdio(): + return + + devnull = os.open(os.devnull, os.O_RDWR) + for fd in (0, 1, 2): + os.dup2(devnull, fd) + if devnull > 2: + os.close(devnull) + sys.stdin = os.fdopen(0, "r", closefd=False) + sys.stdout = os.fdopen(1, "w", closefd=False) + sys.stderr = os.fdopen(2, "w", closefd=False) + + +def _set_daemon(daemon: bool) -> None: + """Set the daemon flag on the current process. 
+ + Python's multiprocessing refuses to spawn children from daemon processes, + and AIPerf services run as daemons. This temporarily clears the flag. + """ + try: + mp.current_process().daemon = daemon + except AssertionError: + mp.current_process()._config["daemon"] = daemon + + +def parallel_convert( + sessions: list[tuple[str, list[dict]]], + *, + tokenizer_name: str, + corpus: list[int], + base_seed: int, + block_size: int, + sep_token: int | None, + trace_id: str, + trust_remote_code: bool = False, + revision: str = "main", + num_workers: int | None = None, + batch_size: int = 100, +) -> Iterator[Conversation]: + """Convert trace sessions to conversations using parallel workers. + + Yields Conversation objects one at a time as batches complete, using + ``pool.imap`` to preserve insertion order while avoiding materializing + all results in memory at once. + + Args: + sessions: List of (session_id, [trace_dict, ...]) tuples. + tokenizer_name: HuggingFace tokenizer name (already cached locally). + corpus: Tokenized corpus as a list of token IDs. + base_seed: Base seed for HashIdRandomGenerator. + block_size: Number of tokens per hash block. + sep_token: Optional separator token prepended to each block. + trace_id: File-derived trace ID for deterministic per-file seeding. + num_workers: Number of worker processes. Defaults to min(cpu_count, 16). + batch_size: Number of sessions per worker batch. + + Yields: + Conversation objects in the same order as the input sessions. + """ + _ensure_valid_stdio_fds() + + corpus_len = len(corpus) + shm = shared_memory.SharedMemory( + create=True, size=corpus_len * np.dtype(np.int32).itemsize + ) + + try: + np.ndarray((corpus_len,), dtype=np.int32, buffer=shm.buf)[:] = corpus + + batches = [ + sessions[i : i + batch_size] for i in range(0, len(sessions), batch_size) + ] + + workers = num_workers or min(os.cpu_count() or 4, 16) + + was_daemon = mp.current_process().daemon + try: + if was_daemon: + _set_daemon(False) + init_args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=corpus_len, + tokenizer_name=tokenizer_name, + base_seed=base_seed, + block_size=block_size, + sep_token=sep_token, + trace_id=trace_id, + trust_remote_code=trust_remote_code, + revision=revision, + ) + with Pool(workers, _init_worker, (init_args,)) as pool: + # imap preserves submission order (unlike imap_unordered) + for batch_result in pool.imap(_process_batch, batches): + for sid, turns in batch_result: + yield Conversation( + session_id=sid, + turns=[ + Turn( + timestamp=ts, + delay=delay, + texts=[Text(name="text", contents=[prompt])], + max_tokens=max_tokens, + ) + for ts, delay, prompt, max_tokens in turns + ], + ) + finally: + if was_daemon: + _set_daemon(True) + finally: + shm.close() + shm.unlink() diff --git a/src/aiperf/dataset/protocols.py b/src/aiperf/dataset/protocols.py index abf738328..a48c9f3fd 100644 --- a/src/aiperf/dataset/protocols.py +++ b/src/aiperf/dataset/protocols.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from aiperf.common.models import Conversation @@ -47,7 +48,7 @@ def load_dataset(self) -> dict[str, list["CustomDatasetT"]]: ... def convert_to_conversations( self, custom_data: dict[str, list["CustomDatasetT"]] - ) -> list[Conversation]: ... + ) -> Iterable[Conversation]: ... 
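Note (illustration, not part of the diff): parallel_convert above publishes the tokenized corpus once through multiprocessing.shared_memory so each Pool worker maps the same int32 buffer instead of receiving a pickled copy. A minimal sketch of that create/attach round trip; the attach happens in the same process here purely for illustration, whereas in the diff it happens inside _init_worker:

# Sketch: create a shared int32 corpus, attach to it by name, view it without copying.
import numpy as np
from multiprocessing import shared_memory

corpus = [101, 102, 103, 104]
shm = shared_memory.SharedMemory(create=True, size=len(corpus) * np.dtype(np.int32).itemsize)
try:
    np.ndarray((len(corpus),), dtype=np.int32, buffer=shm.buf)[:] = corpus

    attached = shared_memory.SharedMemory(name=shm.name)   # what a worker would do
    view = np.ndarray((len(corpus),), dtype=np.int32, buffer=attached.buf)
    assert view.tolist() == corpus
    del view                      # drop the exported buffer before closing
    attached.close()
finally:
    shm.close()
    shm.unlink()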
@runtime_checkable diff --git a/src/aiperf/dataset/synthesis/rolling_hasher.py b/src/aiperf/dataset/synthesis/rolling_hasher.py index c1fc1ae28..5beb55008 100644 --- a/src/aiperf/dataset/synthesis/rolling_hasher.py +++ b/src/aiperf/dataset/synthesis/rolling_hasher.py @@ -225,7 +225,9 @@ def hashes_to_texts( if hash_ids: # Use PromptGenerator to generate text from hash_ids # This uses the Shakespeare corpus and caches blocks by hash_id - text = prompt_generator.generate(mean=input_len, hash_ids=hash_ids) + text = prompt_generator.generate( + mean=input_len, hash_ids=hash_ids, block_size=block_size + ) else: # No hash_ids - generate plain text of target length text = prompt_generator.generate(mean=input_len) diff --git a/tests/unit/common/test_hash_id_random_generator.py b/tests/unit/common/test_hash_id_random_generator.py new file mode 100644 index 000000000..02afaee6d --- /dev/null +++ b/tests/unit/common/test_hash_id_random_generator.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for HashIdRandomGenerator parallel processing with reproducibility.""" + +import pytest + +from aiperf.common import random_generator as rng +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator + + +class TestHashIdRandomGenerator: + """Test HashIdRandomGenerator for parallel processing reproducibility.""" + + @pytest.fixture(autouse=True) + def setup_rng(self): + """Initialize global RNG before each test.""" + rng.reset() + rng.init(42) + yield + rng.reset() + + def test_deterministic_seeding_per_hash_id(self): + """Test that the same hash_id always produces the same random sequence.""" + base_rng = rng.derive("test.base") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Generate values for hash_id=123 + hash_rng.reseed_for_hash_id(123) + values_1 = [hash_rng.randrange(1000) for _ in range(10)] + + # Generate values for hash_id=456 + hash_rng.reseed_for_hash_id(456) + values_2 = [hash_rng.randrange(1000) for _ in range(10)] + + # Generate values for hash_id=123 again (should match values_1) + hash_rng.reseed_for_hash_id(123) + values_3 = [hash_rng.randrange(1000) for _ in range(10)] + + # Same hash_id produces same sequence + assert values_1 == values_3 + # Different hash_ids produce different sequences + assert values_1 != values_2 + + def test_independence_across_workers(self): + """Test that different worker instances produce identical results for same hash_id.""" + # Simulate Worker 1 + base_rng_1 = rng.derive("worker.corpus") + hash_rng_1 = HashIdRandomGenerator.from_base_rng(base_rng_1) + + # Reset and re-init to simulate Worker 2 with same global seed + rng.reset() + rng.init(42) + + # Simulate Worker 2 + base_rng_2 = rng.derive("worker.corpus") + hash_rng_2 = HashIdRandomGenerator.from_base_rng(base_rng_2) + + # Both workers process hash_id=789 + hash_rng_1.reseed_for_hash_id(789) + worker1_values = [hash_rng_1.randrange(1000) for _ in range(10)] + + hash_rng_2.reseed_for_hash_id(789) + worker2_values = [hash_rng_2.randrange(1000) for _ in range(10)] + + # Workers produce identical results + assert worker1_values == worker2_values + + def test_order_independence(self): + """Test that processing order doesn't affect reproducibility.""" + base_rng = rng.derive("test.order") + + # Process in order: 100, 200, 300 + hash_rng_1 = HashIdRandomGenerator.from_base_rng(base_rng) + results_order1 = {} + for hash_id in [100, 200, 300]: + 
hash_rng_1.reseed_for_hash_id(hash_id) + results_order1[hash_id] = [hash_rng_1.randrange(1000) for _ in range(5)] + + # Process in different order: 300, 100, 200 + hash_rng_2 = HashIdRandomGenerator.from_base_rng(base_rng) + results_order2 = {} + for hash_id in [300, 100, 200]: + hash_rng_2.reseed_for_hash_id(hash_id) + results_order2[hash_id] = [hash_rng_2.randrange(1000) for _ in range(5)] + + # Same hash_id produces same results regardless of order + assert results_order1[100] == results_order2[100] + assert results_order1[200] == results_order2[200] + assert results_order1[300] == results_order2[300] + + def test_parallel_cache_simulation(self): + """Simulate parallel workers with individual caches.""" + # Worker 1 processes hash_ids: 1, 2, 3 + base_rng_w1 = rng.derive("worker.corpus") + hash_rng_w1 = HashIdRandomGenerator.from_base_rng(base_rng_w1) + cache_w1 = {} + + for hash_id in [1, 2, 3]: + hash_rng_w1.reseed_for_hash_id(hash_id) + cache_w1[hash_id] = [hash_rng_w1.randrange(1000) for _ in range(5)] + + # Reset for Worker 2 simulation + rng.reset() + rng.init(42) + + # Worker 2 processes hash_ids: 3, 4, 5 (overlapping hash_id=3) + base_rng_w2 = rng.derive("worker.corpus") + hash_rng_w2 = HashIdRandomGenerator.from_base_rng(base_rng_w2) + cache_w2 = {} + + for hash_id in [3, 4, 5]: + hash_rng_w2.reseed_for_hash_id(hash_id) + cache_w2[hash_id] = [hash_rng_w2.randrange(1000) for _ in range(5)] + + # Both workers produce identical results for hash_id=3 + assert cache_w1[3] == cache_w2[3] + + # Different hash_ids produce different results + assert cache_w1[1] != cache_w1[2] + assert cache_w2[4] != cache_w2[5] + + def test_non_deterministic_mode(self): + """Test that non-deterministic mode works (seed=None).""" + rng.reset() + rng.init(None) + + base_rng = rng.derive("test.nondeterministic") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Should not raise errors + hash_rng.reseed_for_hash_id(123) + values = [hash_rng.randrange(1000) for _ in range(10)] + + assert len(values) == 10 + # Ensure that an actual seed is created for the HashIdRandomGenerator + assert hash_rng.seed is not None + + def test_multiple_random_operations(self): + """Test various random operations after reseeding.""" + base_rng = rng.derive("test.operations") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Test for hash_id=555 + hash_rng.reseed_for_hash_id(555) + int_val = hash_rng.randrange(100, 200) + float_val = hash_rng.uniform(0.0, 1.0) + choice_val = hash_rng.choice([10, 20, 30, 40]) + + # Re-seed with same hash_id and verify reproducibility + hash_rng.reseed_for_hash_id(555) + int_val_2 = hash_rng.randrange(100, 200) + float_val_2 = hash_rng.uniform(0.0, 1.0) + choice_val_2 = hash_rng.choice([10, 20, 30, 40]) + + assert int_val == int_val_2 + assert float_val == float_val_2 + assert choice_val == choice_val_2 + + def test_hash_collision_independence(self): + """Test that different hash_ids produce independent sequences.""" + base_rng = rng.derive("test.collision") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_ids = [1, 12345, 999999, 7777777, 123456789] + sequences = {} + + for hash_id in hash_ids: + hash_rng.reseed_for_hash_id(hash_id) + sequences[hash_id] = [hash_rng.randrange(1000) for _ in range(20)] + + # All sequences should be different + unique_sequences = set(tuple(seq) for seq in sequences.values()) + assert len(unique_sequences) == len(hash_ids) + + def test_reseed_state_isolation(self): + """Test that reseeding properly isolates state between 
hash_ids.""" + base_rng = rng.derive("test.isolation") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + # Generate partial sequence for hash_id=111 + hash_rng.reseed_for_hash_id(111) + partial_seq_111 = [hash_rng.randrange(1000) for _ in range(3)] + + # Switch to hash_id=222 and generate values + hash_rng.reseed_for_hash_id(222) + _ = [hash_rng.randrange(1000) for _ in range(5)] + + # Return to hash_id=111 and continue - should continue from fresh state + hash_rng.reseed_for_hash_id(111) + full_seq_111 = [hash_rng.randrange(1000) for _ in range(10)] + + # The first 3 values should match the partial sequence + assert full_seq_111[:3] == partial_seq_111 + + def test_trace_id_isolation(self): + """Test that different trace_ids produce different sequences for same hash_id.""" + base_rng = rng.derive("test.trace_id") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.set_trace_id("trace_file_a") + hash_rng.reseed_for_hash_id(100) + values_a = [hash_rng.randrange(1000) for _ in range(10)] + + hash_rng.set_trace_id("trace_file_b") + hash_rng.reseed_for_hash_id(100) + values_b = [hash_rng.randrange(1000) for _ in range(10)] + + assert values_a != values_b + + +class TestHashIdRandomGeneratorEdgeCases: + """Test edge cases and error conditions.""" + + @pytest.fixture(autouse=True) + def setup_rng(self): + """Initialize global RNG before each test.""" + rng.reset() + rng.init(42) + yield + rng.reset() + + def test_zero_hash_id(self): + """Test with hash_id=0.""" + base_rng = rng.derive("test.zero") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.reseed_for_hash_id(0) + values = [hash_rng.randrange(1000) for _ in range(5)] + + assert len(values) == 5 + + def test_negative_hash_id(self): + """Test with negative hash_id.""" + base_rng = rng.derive("test.negative") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + hash_rng.reseed_for_hash_id(-123) + values = [hash_rng.randrange(1000) for _ in range(5)] + + # Should not produce the same as positive 123 + hash_rng.reseed_for_hash_id(123) + values_positive = [hash_rng.randrange(1000) for _ in range(5)] + + assert values != values_positive + + def test_large_hash_id(self): + """Test with very large hash_id.""" + base_rng = rng.derive("test.large") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + + large_hash_id = 999999999999999999 + hash_rng.reseed_for_hash_id(large_hash_id) + values = [hash_rng.randrange(1000) for _ in range(5)] + + assert len(values) == 5 + + @pytest.mark.parametrize( + "operation", + [ + lambda rng: rng.integers(0, 100, size=2), + lambda rng: rng.random_batch(2), + lambda rng: rng.shuffle([1, 2, 3, 4, 5]), + lambda rng: rng.numpy_choice([1, 2, 3, 4, 5], size=2), + lambda rng: rng.normal(0, 1, size=2), + ], + ) + def test_numpy_rng_raises_exception(self, operation): + """Test that using NumPy RNG operations raises an exception.""" + base_rng = rng.derive("test.numpy") + hash_rng = HashIdRandomGenerator.from_base_rng(base_rng) + hash_rng.reseed_for_hash_id(123) + + with pytest.raises( + RuntimeError, match="HashIdRandomGenerator does not support NumPy RNG" + ): + operation(hash_rng) diff --git a/tests/unit/common/test_tokenizer_kwarg_overrides.py b/tests/unit/common/test_tokenizer_kwarg_overrides.py new file mode 100644 index 000000000..819070d67 --- /dev/null +++ b/tests/unit/common/test_tokenizer_kwarg_overrides.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for tokenizer kwarg override detection and application. + +Tokenizers like Kimi use non-standard kwargs (e.g. `allow_special_tokens` +instead of `add_special_tokens`). Passing unsupported kwargs triggers the +slow `PreTrainedTokenizer.super()` fallback. These tests verify that +`_supports_kwarg` and `_apply_kwarg_overrides` correctly detect and adapt. +""" + +import pytest + +from aiperf.common.tokenizer import Tokenizer, _supports_kwarg + +# -- Fake tokenizer backends for testing -- + + +class StandardTokenizerBackend: + """Mimics a standard HuggingFace tokenizer (e.g. Qwen).""" + + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> list[int]: + return list(range(len(text.split()))) + + def decode( + self, token_ids: list[int], skip_special_tokens: bool = False, **kwargs + ) -> str: + return " ".join(f"t{i}" for i in token_ids) + + def __call__(self, text: str, add_special_tokens: bool = True, **kwargs) -> dict: + return {"input_ids": self.encode(text, add_special_tokens=add_special_tokens)} + + bos_token_id = 1 + eos_token_id = 2 + + +class KimiLikeTokenizerBackend: + """Mimics Kimi's TikTokenTokenizer: uses allow_special_tokens, no skip_special_tokens.""" + + def encode( + self, text: str, allow_special_tokens: bool = True, **kwargs + ) -> list[int]: + if kwargs: + raise TypeError( + f"Unexpected kwargs would trigger slow super().encode: {kwargs}" + ) + return list(range(len(text.split()))) + + def decode(self, token_ids: list[int] | int, **kwargs) -> str: + if kwargs: + raise TypeError( + f"Unexpected kwargs would trigger slow super().decode: {kwargs}" + ) + if isinstance(token_ids, int): + token_ids = [token_ids] + return " ".join(f"t{i}" for i in token_ids) + + def __call__(self, text: str, allow_special_tokens: bool = True, **kwargs) -> dict: + return { + "input_ids": self.encode(text, allow_special_tokens=allow_special_tokens) + } + + bos_token_id = 1 + eos_token_id = 2 + + +class MinimalDecodeTokenizerBackend: + """Tokenizer with standard encode but minimal decode (no skip_special_tokens).""" + + def encode(self, text: str, add_special_tokens: bool = True, **kwargs) -> list[int]: + return list(range(len(text.split()))) + + def decode(self, token_ids: list[int], **kwargs) -> str: + if kwargs: + raise TypeError(f"Unexpected kwargs: {kwargs}") + return " ".join(f"t{i}" for i in token_ids) + + bos_token_id = 1 + eos_token_id = 2 + + +class KwargsOnlyTokenizerBackend: + """Tokenizer that only accepts **kwargs (no named params beyond self/text).""" + + def encode(self, text, **kwargs): + return [0] + + def decode(self, token_ids, **kwargs): + return "decoded" + + bos_token_id = 0 + eos_token_id = 0 + + +# -- _supports_kwarg tests -- + + +class TestSupportsKwarg: + def test_detects_named_param(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is True + + def test_rejects_missing_param(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is False + + def test_detects_allow_special_tokens(self): + backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is True + + def test_rejects_add_special_tokens_on_kimi(self): + backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is False + + def test_detects_skip_special_tokens_on_standard(self): + backend = StandardTokenizerBackend() + assert 
_supports_kwarg(backend, "decode", "skip_special_tokens") is True + + def test_rejects_skip_special_tokens_on_kimi(self): + backend = KimiLikeTokenizerBackend() + assert _supports_kwarg(backend, "decode", "skip_special_tokens") is False + + def test_missing_method_returns_false(self): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, "nonexistent_method", "anything") is False + + def test_kwargs_only_method(self): + backend = KwargsOnlyTokenizerBackend() + assert _supports_kwarg(backend, "encode", "add_special_tokens") is False + assert _supports_kwarg(backend, "encode", "allow_special_tokens") is False + + @pytest.mark.parametrize( + ("method", "kwarg", "expected"), + [ + ("encode", "text", True), + ("encode", "add_special_tokens", True), + ("decode", "token_ids", True), + ("decode", "skip_special_tokens", True), + ("encode", "nonexistent", False), + ], + ) + def test_standard_backend_parametrized(self, method, kwarg, expected): + backend = StandardTokenizerBackend() + assert _supports_kwarg(backend, method, kwarg) is expected + + +# -- _apply_kwarg_overrides tests -- + + +class TestApplyKwargOverrides: + @staticmethod + def _make_tokenizer(backend) -> Tokenizer: + tok = Tokenizer() + tok._tokenizer = backend + tok._apply_kwarg_overrides() + return tok + + def test_standard_tokenizer_keeps_defaults(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + assert tok._encode_args == {"add_special_tokens": False} + assert tok._call_args == {"add_special_tokens": False} + assert tok._decode_args == {"skip_special_tokens": True} + + def test_kimi_like_overrides_encode_and_call_args(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + assert tok._encode_args == {"allow_special_tokens": False} + assert tok._call_args == {"allow_special_tokens": False} + + def test_kimi_like_clears_decode_args(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + assert tok._decode_args == {} + + def test_minimal_decode_clears_decode_args(self): + tok = self._make_tokenizer(MinimalDecodeTokenizerBackend()) + assert tok._encode_args == {"add_special_tokens": False} + assert tok._decode_args == {} + + def test_none_tokenizer_is_noop(self): + tok = Tokenizer() + tok._apply_kwarg_overrides() + assert tok._encode_args == {"add_special_tokens": False} + assert tok._call_args == {"add_special_tokens": False} + assert tok._decode_args == {"skip_special_tokens": True} + + +# -- End-to-end: encode/decode through Tokenizer wrapper -- + + +class TestKwargOverridesEndToEnd: + @staticmethod + def _make_tokenizer(backend) -> Tokenizer: + tok = Tokenizer() + tok._tokenizer = backend + tok._apply_kwarg_overrides() + return tok + + def test_standard_encode_passes_correct_kwargs(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.encode("hello world") + assert isinstance(result, list) + + def test_standard_decode_passes_correct_kwargs(self): + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.decode([0, 1, 2]) + assert isinstance(result, str) + + def test_kimi_encode_does_not_raise(self): + """Kimi backend raises TypeError if unexpected kwargs are passed.""" + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.encode("hello world") + assert isinstance(result, list) + + def test_kimi_decode_does_not_raise(self): + """Kimi backend raises TypeError if unexpected kwargs are passed.""" + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.decode([0, 1, 2]) + assert isinstance(result, str) + + def 
test_kimi_call_does_not_raise(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok("hello world") + assert "input_ids" in result + + def test_standard_encode_without_override_would_fail_on_kimi(self): + """Verify that without overrides, Kimi backend rejects add_special_tokens.""" + tok = Tokenizer() + tok._tokenizer = KimiLikeTokenizerBackend() + # Don't call _apply_kwarg_overrides - defaults still have add_special_tokens + with pytest.raises(TypeError, match="Unexpected kwargs"): + tok.encode("hello world") + + def test_standard_decode_without_override_would_fail_on_kimi(self): + """Verify that without overrides, Kimi backend rejects skip_special_tokens.""" + tok = Tokenizer() + tok._tokenizer = KimiLikeTokenizerBackend() + with pytest.raises(TypeError, match="Unexpected kwargs"): + tok.decode([0, 1, 2]) + + def test_user_kwargs_override_defaults(self): + """User-provided kwargs should override the defaults.""" + tok = self._make_tokenizer(StandardTokenizerBackend()) + result = tok.encode("hello", add_special_tokens=True) + assert isinstance(result, list) + + def test_kimi_user_kwargs_override_defaults(self): + tok = self._make_tokenizer(KimiLikeTokenizerBackend()) + result = tok.encode("hello", allow_special_tokens=True) + assert isinstance(result, list) diff --git a/tests/unit/dataset/composer/test_base_composer.py b/tests/unit/dataset/composer/test_base_composer.py index a889be039..da8422733 100644 --- a/tests/unit/dataset/composer/test_base_composer.py +++ b/tests/unit/dataset/composer/test_base_composer.py @@ -245,10 +245,10 @@ def test_prefix_prompt_enabled_property(self, base_config, mock_tokenizer): composer2 = ConcreteBaseComposer(base_config, mock_tokenizer) assert composer2.prefix_prompt_enabled is False - def test_inject_context_prompts_with_shared_system_prompt( + def test_finalize_conversation_with_shared_system_prompt( self, base_config, mock_tokenizer ): - """Test _inject_context_prompts with shared system prompt.""" + """Test _finalize_conversation with shared system prompt.""" base_config.input.prompt.prefix_prompt.shared_system_prompt_length = 50 base_config.input.prompt.prefix_prompt.length = 0 base_config.input.conversation.num = 3 @@ -259,7 +259,6 @@ def test_inject_context_prompts_with_shared_system_prompt( ): composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -268,13 +267,13 @@ def test_inject_context_prompts_with_shared_system_prompt( Conversation(session_id="conv_2"), ] - # Mock the prompt generator method with patch.object( composer.prompt_generator, "get_shared_system_prompt", return_value="shared system prompt text", ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # All conversations should have the same system message assert conversations[0].system_message == "shared system prompt text" @@ -285,17 +284,16 @@ def test_inject_context_prompts_with_shared_system_prompt( assert conversations[1].user_context_message is None assert conversations[2].user_context_message is None - def test_inject_context_prompts_with_user_context_prompt( + def test_finalize_conversation_with_user_context_prompt( self, base_config, mock_tokenizer ): - """Test _inject_context_prompts with user context prompts.""" + """Test _finalize_conversation with user context prompts.""" base_config.input.prompt.prefix_prompt.user_context_prompt_length = 30 
base_config.input.prompt.prefix_prompt.length = 0 base_config.input.conversation.num = 3 composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -304,7 +302,6 @@ def test_inject_context_prompts_with_user_context_prompt( Conversation(session_id="conv_2"), ] - # Mock the prompt generator method def mock_generate_user_context(index): return f"user context {index}" @@ -313,7 +310,8 @@ def mock_generate_user_context(index): "generate_user_context_prompt", side_effect=mock_generate_user_context, ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # Each conversation should have unique user context assert conversations[0].user_context_message == "user context 0" @@ -324,10 +322,8 @@ def mock_generate_user_context(index): assert conversations[1].system_message is None assert conversations[2].system_message is None - def test_inject_context_prompts_with_both_prompts( - self, base_config, mock_tokenizer - ): - """Test _inject_context_prompts with both shared system and user context prompts.""" + def test_finalize_conversation_with_both_prompts(self, base_config, mock_tokenizer): + """Test _finalize_conversation with both shared system and user context prompts.""" base_config.input.prompt.prefix_prompt.shared_system_prompt_length = 50 base_config.input.prompt.prefix_prompt.user_context_prompt_length = 30 base_config.input.prompt.prefix_prompt.length = 0 @@ -339,7 +335,6 @@ def test_inject_context_prompts_with_both_prompts( ): composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -347,7 +342,6 @@ def test_inject_context_prompts_with_both_prompts( Conversation(session_id="conv_1"), ] - # Mock both prompt generator methods def mock_generate_user_context(index): return f"user context {index}" @@ -363,7 +357,8 @@ def mock_generate_user_context(index): side_effect=mock_generate_user_context, ), ): - composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # Both conversations should have system message assert conversations[0].system_message == "shared system prompt" @@ -372,8 +367,8 @@ def mock_generate_user_context(index): assert conversations[0].user_context_message == "user context 0" assert conversations[1].user_context_message == "user context 1" - def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenizer): - """Test _inject_context_prompts when no context prompts are configured.""" + def test_finalize_conversation_with_no_prompts(self, base_config, mock_tokenizer): + """Test _finalize_conversation when no context prompts are configured.""" base_config.input.prompt.prefix_prompt.length = 0 base_config.input.prompt.prefix_prompt.shared_system_prompt_length = None base_config.input.prompt.prefix_prompt.user_context_prompt_length = None @@ -381,7 +376,6 @@ def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenize composer = ConcreteBaseComposer(base_config, mock_tokenizer) - # Create mock conversations from aiperf.common.models import Conversation conversations = [ @@ -389,8 +383,8 @@ def test_inject_context_prompts_with_no_prompts(self, base_config, mock_tokenize Conversation(session_id="conv_1"), ] - # Should not call any prompt generator methods - 
composer._inject_context_prompts(conversations) + for i, conv in enumerate(conversations): + composer._finalize_conversation(conv, i) # No messages should be set assert conversations[0].system_message is None diff --git a/tests/unit/dataset/composer/test_custom_composer.py b/tests/unit/dataset/composer/test_custom_composer.py index 75ec844b5..f861df296 100644 --- a/tests/unit/dataset/composer/test_custom_composer.py +++ b/tests/unit/dataset/composer/test_custom_composer.py @@ -64,34 +64,26 @@ def test_create_loader_instance_dataset_types( composer._create_loader_instance(dataset_type) assert isinstance(composer.loader, expected_instance) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) - def test_create_dataset_trace( - self, mock_check_file, mock_parallel_decode, trace_config, mock_tokenizer - ): + def test_create_dataset_trace(self, mock_check_file, trace_config, mock_tokenizer): """Test that create_dataset returns correct type.""" - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] composer = CustomDatasetComposer(trace_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 3 assert all(isinstance(c, Conversation) for c in conversations) assert all(isinstance(turn, Turn) for c in conversations for turn in c.turns) assert all(len(turn.texts) == 1 for c in conversations for turn in c.turns) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) - def test_max_tokens_config( - self, mock_check_file, mock_parallel_decode, trace_config, mock_tokenizer - ): - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] + def test_max_tokens_config(self, mock_check_file, trace_config, mock_tokenizer): trace_config.input.prompt.output_tokens.mean = 120 trace_config.input.prompt.output_tokens.stddev = 8.0 composer = CustomDatasetComposer(trace_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) > 0 # With global RNG, verify max_tokens is set to a positive integer @@ -104,7 +96,6 @@ def test_max_tokens_config( # Should be roughly around the mean of 120 (within 3 stddev) assert 96 < turn.max_tokens < 144 - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") @patch("aiperf.dataset.composer.custom.check_file_exists") @patch("builtins.open", mock_open(read_data=MOCK_TRACE_CONTENT)) @patch("pathlib.Path.iterdir", return_value=[]) @@ -112,17 +103,15 @@ def test_max_tokens_mooncake( self, mock_iterdir, mock_check_file, - mock_parallel_decode, custom_config, mock_tokenizer, ): """Test that max_tokens can be set from the custom file""" - mock_parallel_decode.return_value = ["decoded 1", "decoded 2", "decoded 3"] mock_check_file.return_value = None custom_config.input.custom_dataset_type = CustomDatasetType.MOONCAKE_TRACE composer = CustomDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: for turn in conversation.turns: @@ -151,9 +140,8 @@ def test_create_dataset_empty_result( mock_get_class.return_value = mock_loader_class composer = CustomDatasetComposer(custom_config, 
mock_tokenizer) - result = composer.create_dataset() + result = list(composer.create_dataset()) - assert isinstance(result, list) assert len(result) == 0 diff --git a/tests/unit/dataset/composer/test_synthetic_composer.py b/tests/unit/dataset/composer/test_synthetic_composer.py index 0c9e1ea85..c333bfe9c 100644 --- a/tests/unit/dataset/composer/test_synthetic_composer.py +++ b/tests/unit/dataset/composer/test_synthetic_composer.py @@ -97,7 +97,7 @@ def test_initialization_with_all_zero_mean(self, mock_tokenizer): def test_create_dataset_basic(self, synthetic_config, mock_tokenizer): """Test basic dataset creation with text-only conversations.""" composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test create_dataset returns correct number of conversations assert len(conversations) == 5 # num_conversations @@ -119,7 +119,7 @@ def test_create_dataset_basic(self, synthetic_config, mock_tokenizer): def test_create_dataset_with_images(self, image_config, mock_tokenizer): """Test dataset creation with image generation enabled.""" composer = SyntheticDatasetComposer(image_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations include image payloads assert len(conversations) == 3 @@ -140,7 +140,7 @@ def test_create_dataset_with_images(self, image_config, mock_tokenizer): def test_create_dataset_with_audio(self, audio_config, mock_tokenizer): """Test dataset creation with audio generation enabled.""" composer = SyntheticDatasetComposer(audio_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations include audio payloads assert len(conversations) == 3 @@ -160,7 +160,7 @@ def test_create_dataset_with_audio(self, audio_config, mock_tokenizer): def test_create_dataset_multimodal(self, multimodal_config, mock_tokenizer): """Test dataset creation with both image and audio enabled.""" composer = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations include both image and audio payloads assert ( @@ -183,7 +183,7 @@ def test_create_dataset_with_prefix_prompts( ): """Test dataset creation with prefix prompts enabled.""" composer = SyntheticDatasetComposer(prefix_prompt_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 5 for conversation in conversations: @@ -198,7 +198,7 @@ def test_create_dataset_with_prefix_prompts( def test_create_dataset_multiple_turns(self, multiturn_config, mock_tokenizer): """Test dataset creation with multiple turns and delays.""" composer = SyntheticDatasetComposer(multiturn_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test conversations have multiple turns assert len(conversations) == 4 @@ -303,7 +303,7 @@ def test_turn_delays_from_config_options(self, mock_tokenizer): rng.reset() rng.init(42) # Set seed for reproducibility composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Verify conversations were created assert len(conversations) == 5 @@ -329,7 +329,7 @@ def test_turn_delays_from_config_options(self, 
mock_tokenizer): rng.reset() rng.init(42) # Reset seed composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: # First turn should still have no delay @@ -365,7 +365,7 @@ def test_turn_delays_with_zero_mean(self, mock_tokenizer): rng.reset() rng.init(42) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: # All turns should have None delay when mean=0 @@ -511,7 +511,7 @@ def test_zero_conversations(self, synthetic_config, mock_tokenizer): synthetic_config.input.conversation.num_dataset_entries = 0 composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) assert len(conversations) == 0 @@ -536,7 +536,7 @@ def test_edge_case_statistical_parameters(self, mock_tokenizer): ) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Test with very small/large mean and stddev values assert len(conversations) == 2 @@ -555,7 +555,7 @@ def test_multi_turn_does_not_control_dataset_entries(self, mock_tokenizer): ) composer = SyntheticDatasetComposer(config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Verify that num_dataset_entries controls the number of conversations generated assert len(conversations) == 10 @@ -568,7 +568,7 @@ def test_different_conversation_counts( synthetic_config.input.conversation.num_dataset_entries = num_conversations composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Parametrized test for different num_conversations values assert len(conversations) == num_conversations @@ -579,7 +579,7 @@ def test_different_batch_sizes(self, synthetic_config, batch_size, mock_tokenize synthetic_config.input.prompt.batch_size = batch_size composer = SyntheticDatasetComposer(synthetic_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Parametrized test for different batch_size values assert len(conversations) > 0 @@ -605,7 +605,7 @@ def test_missing_required_generators(self, synthetic_config, mock_tokenizer): with pytest.raises( ValueError, match="Text prompt generation requires a tokenizer" ): - composer.create_dataset() + list(composer.create_dataset()) def test_reproducibility_with_fixed_seed(self, multimodal_config, mock_tokenizer): """Test that dataset generation is reproducible with fixed random seed.""" @@ -620,12 +620,12 @@ def test_reproducibility_with_fixed_seed(self, multimodal_config, mock_tokenizer rng.reset() rng.init(42) composer1 = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations1 = composer1.create_dataset() + conversations1 = list(composer1.create_dataset()) rng.reset() rng.init(42) composer2 = SyntheticDatasetComposer(multimodal_config, mock_tokenizer) - conversations2 = composer2.create_dataset() + conversations2 = list(composer2.create_dataset()) # Basic structure should be the same assert len(conversations1) == len(conversations2) @@ -653,7 +653,7 @@ def test_model_selection_random(self, custom_config, mock_tokenizer): 
custom_config.endpoint.model_names = ["test-model-1", "test-model-2"] composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # With random selection, verify models are from the valid set for conversation in conversations: @@ -665,7 +665,7 @@ def test_model_selection_round_robin(self, custom_config, mock_tokenizer): custom_config.endpoint.model_names = ["test-model-1", "test-model-2"] composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # Check that models are selected in round-robin fashion for i, conversation in enumerate(conversations): @@ -682,7 +682,7 @@ def test_max_tokens_integration_with_mean(self, custom_config, mock_tokenizer): custom_config.input.prompt.output_tokens.stddev = 5.0 composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) # With global RNG, verify max_tokens is set to a positive integer # around the mean of 100 @@ -699,7 +699,7 @@ def test_max_tokens_not_set_when_mean_none(self, custom_config, mock_tokenizer): custom_config.input.prompt.output_tokens.stddev = None composer = SyntheticDatasetComposer(custom_config, mock_tokenizer) - conversations = composer.create_dataset() + conversations = list(composer.create_dataset()) for conversation in conversations: for turn in conversation.turns: diff --git a/tests/unit/dataset/composer/test_synthetic_rankings_composer.py b/tests/unit/dataset/composer/test_synthetic_rankings_composer.py index f86b64d86..422adf7c0 100644 --- a/tests/unit/dataset/composer/test_synthetic_rankings_composer.py +++ b/tests/unit/dataset/composer/test_synthetic_rankings_composer.py @@ -18,7 +18,7 @@ def test_create_dataset_structure(synthetic_config, mock_tokenizer): synthetic_config.input.rankings.passages.stddev = 1 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) assert len(dataset) == synthetic_config.input.conversation.num_dataset_entries for conv in dataset: @@ -42,7 +42,7 @@ def test_passage_count_distribution(synthetic_config, mock_tokenizer): synthetic_config.input.rankings.passages.stddev = 2 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) passage_counts = [len(conv.turns[0].texts[1].contents) for conv in dataset] assert all(1 <= c <= 10 for c in passage_counts) @@ -56,10 +56,10 @@ def test_reproducibility_fixed_seed(synthetic_config, mock_tokenizer): synthetic_config.input.random_seed = 42 composer1 = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - data1 = composer1.create_dataset() + data1 = list(composer1.create_dataset()) composer2 = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - data2 = composer2.create_dataset() + data2 = list(composer2.create_dataset()) # Session IDs differ (fresh), but text contents should match for c1, c2 in zip(data1, data2, strict=True): @@ -78,7 +78,7 @@ def test_rankings_specific_token_options(synthetic_config, mock_tokenizer): synthetic_config.input.random_seed = 42 composer = SyntheticRankingsDatasetComposer(synthetic_config, mock_tokenizer) - dataset = composer.create_dataset() + dataset = list(composer.create_dataset()) # Verify that 
data was generated assert len(dataset) > 0 diff --git a/tests/unit/dataset/conftest.py b/tests/unit/dataset/conftest.py index d5fabd459..704eb8994 100644 --- a/tests/unit/dataset/conftest.py +++ b/tests/unit/dataset/conftest.py @@ -5,16 +5,16 @@ """ from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest import aiperf.endpoints # noqa: F401 # Import to register endpoints import aiperf.transports # noqa: F401 # Import to register transports from aiperf.common.config import EndpointConfig, OutputConfig, ServiceConfig, UserConfig -from aiperf.common.models import Conversation +from aiperf.common.models import Conversation, DatasetMetadata from aiperf.dataset.dataset_manager import DatasetManager -from aiperf.plugin.enums import EndpointType +from aiperf.plugin.enums import DatasetSamplingStrategy, EndpointType @pytest.fixture @@ -32,37 +32,64 @@ def user_config(tmp_path: Path) -> UserConfig: @pytest.fixture -def empty_dataset_manager( +def dataset_manager( user_config: UserConfig, ) -> DatasetManager: - """Create a DatasetManager instance with empty dataset.""" - manager = DatasetManager( + """Create a DatasetManager instance.""" + return DatasetManager( service_config=ServiceConfig(), user_config=user_config, service_id="test_dataset_manager", ) - manager.dataset = {} - return manager -@pytest.fixture -def populated_dataset_manager( +def _create_dataset_manager_with_client( user_config: UserConfig, - sample_conversations: dict[str, Conversation], + conversations: dict[str, Conversation], ) -> DatasetManager: - """Create a DatasetManager instance with sample data.""" + """Create a DatasetManager with a mock dataset client backed by conversations.""" manager = DatasetManager( service_config=ServiceConfig(), user_config=user_config, service_id="test_dataset_manager", ) - manager.dataset = sample_conversations + + async def mock_get_conversation(conversation_id: str) -> Conversation: + if conversation_id not in conversations: + raise KeyError(conversation_id) + return conversations[conversation_id] + + mock_client = AsyncMock() + mock_client.get_conversation = AsyncMock(side_effect=mock_get_conversation) + manager._dataset_client = mock_client + + manager.dataset_metadata = DatasetMetadata( + conversations=[conv.metadata() for conv in conversations.values()], + sampling_strategy=DatasetSamplingStrategy.RANDOM, + ) return manager +@pytest.fixture +def populated_dataset_manager( + user_config: UserConfig, + sample_conversations: dict[str, Conversation], +) -> DatasetManager: + """Create a DatasetManager with a mock dataset client for payload tests.""" + return _create_dataset_manager_with_client(user_config, sample_conversations) + + +@pytest.fixture +def empty_dataset_manager( + user_config: UserConfig, +) -> DatasetManager: + """Create a DatasetManager with an empty dataset client.""" + return _create_dataset_manager_with_client(user_config, {}) + + @pytest.fixture def capture_file_writes(): - """Provide a fixture to capture file write operations for testing purposes.""" + """Capture file write operations for testing.""" class FileWriteCapture: def __init__(self): diff --git a/tests/unit/dataset/generator/test_parallel_decode.py b/tests/unit/dataset/generator/test_parallel_decode.py deleted file mode 100644 index a6b0cca79..000000000 --- a/tests/unit/dataset/generator/test_parallel_decode.py +++ /dev/null @@ -1,269 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -"""Unit tests for parallel_decode module.""" - -import importlib -import os -from unittest.mock import MagicMock, call, patch - -import pytest - -from aiperf.dataset.generator.parallel_decode import _set_daemon, parallel_decode - -# Import the module directly (not through __init__.py which exports the function) -pd_module = importlib.import_module("aiperf.dataset.generator.parallel_decode") - - -class TestParallelDecode: - """Test suite for parallel_decode module.""" - - def test_parallel_decode_empty_list(self): - """Test parallel_decode with empty input returns empty list.""" - result = parallel_decode([], "gpt2") - assert result == [] - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_parallel_decode_small_batch_sequential(self, mock_tokenizer_class): - """Test that small batches (< 10) use sequential decoding.""" - mock_tokenizer = MagicMock() - mock_tokenizer.decode.return_value = "decoded" - mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer - - token_sequences = [[1, 2, 3], [4, 5, 6]] # Less than 10 - result = parallel_decode(token_sequences, "gpt2") - - # Should use sequential decoding (Tokenizer.from_pretrained called once) - mock_tokenizer_class.from_pretrained.assert_called_once_with("gpt2") - assert mock_tokenizer.decode.call_count == 2 - assert result == ["decoded", "decoded"] - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_large_batch_uses_executor(self, mock_executor_class): - """Test that large batches (>= 10) use ProcessPoolExecutor.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] # 15 sequences - result = parallel_decode(token_sequences, "gpt2") - - # Should use ProcessPoolExecutor - mock_executor_class.assert_called_once() - mock_executor.map.assert_called_once() - assert len(result) == 15 - - @patch.object(pd_module, "mp") - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_respects_max_workers(self, mock_executor_class, mock_mp): - """Test that max_workers parameter is respected.""" - mock_mp.cpu_count.return_value = 16 - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] - parallel_decode(token_sequences, "gpt2", max_workers=4) - - # Should be called with max_workers=4 - call_kwargs = mock_executor_class.call_args.kwargs - assert call_kwargs["max_workers"] == 4 - - @patch.object(pd_module, "mp") - @patch.object(pd_module, "ProcessPoolExecutor") - def test_parallel_decode_default_max_workers_capped_at_8( - self, mock_executor_class, mock_mp - ): - """Test that default max_workers is capped at 8.""" - mock_mp.cpu_count.return_value = 64 # Lots of CPUs - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - token_sequences = [[i] for i in range(15)] - parallel_decode(token_sequences, "gpt2") - - # Should be capped at 8 - call_kwargs = 
mock_executor_class.call_args.kwargs - assert call_kwargs["max_workers"] == 8 - - -class TestSetDaemon: - """Test suite for _set_daemon helper.""" - - def test_set_daemon_uses_property(self): - """_set_daemon sets daemon via the public property when possible.""" - mock_proc = MagicMock() - mock_proc.daemon = True - with patch.object(pd_module.mp, "current_process", return_value=mock_proc): - _set_daemon(False) - assert mock_proc.daemon is False - - def test_set_daemon_falls_back_to_config_on_assertion_error(self): - """_set_daemon falls back to _config when property raises AssertionError.""" - mock_proc = MagicMock() - type(mock_proc).daemon = property( - fget=lambda self: self._config.get("daemon"), - fset=MagicMock(side_effect=AssertionError), - ) - mock_proc._config = {"daemon": True} - with patch.object(pd_module.mp, "current_process", return_value=mock_proc): - _set_daemon(False) - assert mock_proc._config["daemon"] is False - - -class TestParallelDecodeDaemonFlag: - """Test that parallel_decode properly manages the daemon flag.""" - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_cleared_before_executor_and_restored_after( - self, mock_executor_class - ): - """Daemon flag is cleared before spawning and restored after.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = True - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - assert mock_set.call_args_list == [call(False), call(True)] - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_not_toggled_for_non_daemon_process(self, mock_executor_class): - """Daemon flag is not toggled when the process is not a daemon.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.return_value = ["decoded"] * 15 - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = False - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - mock_set.assert_not_called() - - @patch.object(pd_module, "ProcessPoolExecutor") - def test_daemon_flag_restored_on_executor_error(self, mock_executor_class): - """Daemon flag is restored even when the executor raises.""" - mock_executor = MagicMock() - mock_executor.__enter__ = MagicMock(return_value=mock_executor) - mock_executor.__exit__ = MagicMock(return_value=False) - mock_executor.map.side_effect = RuntimeError("boom") - mock_executor_class.return_value = mock_executor - - mock_process = MagicMock() - mock_process.daemon = True - with ( - patch.object(pd_module.mp, "current_process", return_value=mock_process), - patch.object(pd_module, "_set_daemon") as mock_set, - pytest.raises(RuntimeError, match="boom"), - ): - parallel_decode([[i] for i in range(15)], "gpt2") - - assert mock_set.call_args_list == [call(False), call(True)] - - -class TestWorkerFunctions: - """Test suite for worker functions.""" - - def test_decode_tokens_raises_without_init(self): - """Test 
that _decode_tokens raises if worker not initialized.""" - pd_module._worker_tokenizer = None - - with pytest.raises(RuntimeError, match="not initialized"): - pd_module._decode_tokens([1, 2, 3]) - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_loads_tokenizer(self, mock_tokenizer_class): - """Test that _init_worker loads the tokenizer.""" - pd_module._worker_tokenizer = None - pd_module._worker_tokenizer_name = None - - mock_tokenizer = MagicMock() - mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer - - pd_module._init_worker("gpt2") - - mock_tokenizer_class.from_pretrained.assert_called_once_with( - "gpt2", resolve_alias=False - ) - assert pd_module._worker_tokenizer is mock_tokenizer - assert pd_module._worker_tokenizer_name == "gpt2" - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_sets_offline_mode(self, mock_tokenizer_class, monkeypatch): - """Test that _init_worker enables HuggingFace offline mode.""" - monkeypatch.delenv("HF_HUB_OFFLINE", raising=False) - monkeypatch.delenv("TRANSFORMERS_OFFLINE", raising=False) - pd_module._worker_tokenizer = None - pd_module._worker_tokenizer_name = None - - mock_tokenizer_class.from_pretrained.return_value = MagicMock() - - pd_module._init_worker("gpt2") - - assert os.environ["HF_HUB_OFFLINE"] == "1" - assert os.environ["TRANSFORMERS_OFFLINE"] == "1" - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_reuses_tokenizer_same_name(self, mock_tokenizer_class): - """Test that _init_worker reuses tokenizer if same name.""" - mock_tokenizer = MagicMock() - pd_module._worker_tokenizer = mock_tokenizer - pd_module._worker_tokenizer_name = "gpt2" - - pd_module._init_worker("gpt2") - - # Should NOT call from_pretrained again - mock_tokenizer_class.from_pretrained.assert_not_called() - assert pd_module._worker_tokenizer is mock_tokenizer - - @patch("aiperf.common.tokenizer.Tokenizer") - def test_init_worker_reloads_tokenizer_different_name(self, mock_tokenizer_class): - """Test that _init_worker reloads tokenizer if different name.""" - old_tokenizer = MagicMock() - pd_module._worker_tokenizer = old_tokenizer - pd_module._worker_tokenizer_name = "gpt2" - - new_tokenizer = MagicMock() - mock_tokenizer_class.from_pretrained.return_value = new_tokenizer - - pd_module._init_worker("llama") - - mock_tokenizer_class.from_pretrained.assert_called_once_with( - "llama", resolve_alias=False - ) - assert pd_module._worker_tokenizer is new_tokenizer - assert pd_module._worker_tokenizer_name == "llama" - - def test_decode_tokens_uses_worker_tokenizer(self): - """Test that _decode_tokens uses the worker tokenizer.""" - mock_tokenizer = MagicMock() - mock_tokenizer.decode.return_value = "decoded text" - pd_module._worker_tokenizer = mock_tokenizer - - result = pd_module._decode_tokens([1, 2, 3]) - - mock_tokenizer.decode.assert_called_once_with( - [1, 2, 3], skip_special_tokens=False - ) - assert result == "decoded text" diff --git a/tests/unit/dataset/generator/test_prompt_generator.py b/tests/unit/dataset/generator/test_prompt_generator.py index 0bd151bad..bdb2e33d7 100644 --- a/tests/unit/dataset/generator/test_prompt_generator.py +++ b/tests/unit/dataset/generator/test_prompt_generator.py @@ -661,125 +661,30 @@ def test_generate_user_context_prompt_corpus_not_initialized(self, mock_tokenize assert "corpus" in str(exc_info.value).lower() # ============================================================================ - # Decoded String Cache Tests + # HashIdRandomGenerator Integration Tests # 
============================================================================ - def test_decoded_cache_initialized_empty(self, basic_config): - """Test that decoded cache is initialized as empty dict.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - assert hasattr(generator, "_decoded_cache") - assert isinstance(generator._decoded_cache, dict) - assert len(generator._decoded_cache) == 0 + def test_hash_id_rng_initialized(self, basic_config): + """Test that HashIdRandomGenerator is initialized in PromptGenerator.""" + from aiperf.common.hash_id_random_generator import HashIdRandomGenerator - def test_decoded_cache_populated_on_first_call(self, basic_config): - """Test that decoded cache is populated after first call.""" tokenizer, config = basic_config generator = PromptGenerator(config, tokenizer) - _ = generator._generate_cached_prompt(10, [1, 2], 5) + assert hasattr(generator, "_hash_id_corpus_rng") + assert isinstance(generator._hash_id_corpus_rng, HashIdRandomGenerator) - # Should have one entry in decoded cache - expected_key = ((1, 2), 10, 5) - assert expected_key in generator._decoded_cache - assert isinstance(generator._decoded_cache[expected_key], str) - - def test_decoded_cache_hit_on_repeated_call(self, basic_config): - """Test that decoded cache is hit on repeated calls with same params.""" + def test_generate_cached_prompt_deterministic_per_hash_id(self, basic_config): + """Test that same hash_ids produce identical prompts across calls.""" tokenizer, config = basic_config generator = PromptGenerator(config, tokenizer) - # First call - should populate cache result1 = generator._generate_cached_prompt(10, [1, 2], 5) - # Second call with same params - should hit cache - with patch.object(generator.tokenizer, "decode") as mock_decode: - result2 = generator._generate_cached_prompt(10, [1, 2], 5) - mock_decode.assert_not_called() # Decode should NOT be called - - assert result1 == result2 - - def test_decoded_cache_miss_different_hash_ids(self, basic_config): - """Test that different hash_ids create different cache entries.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._generate_cached_prompt(10, [1, 2], 5) - _ = generator._generate_cached_prompt(10, [3, 4], 5) - - # Both should be cached separately - assert ((1, 2), 10, 5) in generator._decoded_cache - assert ((3, 4), 10, 5) in generator._decoded_cache - assert len(generator._decoded_cache) == 2 + # Clear block cache to force regeneration + generator._cache.clear() - def test_decoded_cache_miss_different_num_tokens(self, basic_config): - """Test that different num_tokens creates different cache entry.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._generate_cached_prompt(10, [1, 2], 5) - _ = generator._generate_cached_prompt(8, [1, 2], 5) # Different final block - - # Should have two separate entries - assert ((1, 2), 10, 5) in generator._decoded_cache - assert ((1, 2), 8, 5) in generator._decoded_cache - assert len(generator._decoded_cache) == 2 - - def test_decoded_cache_key_structure(self, basic_config): - """Test that cache key is (tuple(hash_ids), num_tokens, block_size).""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - # 12 tokens = 5 + 5 + 2 (valid final block size) - generator._generate_cached_prompt(12, [1, 2, 3], 5) - - expected_key = ((1, 2, 3), 12, 5) - assert expected_key in generator._decoded_cache - - # 
============================================================================ - # _build_token_sequence Method Tests - # ============================================================================ + result2 = generator._generate_cached_prompt(10, [1, 2], 5) - def test_build_token_sequence_returns_tokens(self, basic_config): - """Test that _build_token_sequence returns a list of token IDs.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - tokens = generator._build_token_sequence(10, [1, 2], 5) - - assert isinstance(tokens, list) - assert all(isinstance(t, int) for t in tokens) - assert len(tokens) == 10 - - def test_build_token_sequence_populates_cache(self, basic_config): - """Test that _build_token_sequence populates the token block cache.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._build_token_sequence(10, [1, 2], 5) - - # Token block cache should be populated - assert 1 in generator._cache - assert 2 in generator._cache - - def test_build_token_sequence_does_not_populate_decoded_cache(self, basic_config): - """Test that _build_token_sequence does NOT populate decoded cache.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - _ = generator._build_token_sequence(10, [1, 2], 5) - - # Decoded cache should remain empty - assert len(generator._decoded_cache) == 0 - - def test_build_token_sequence_same_validation_as_generate_cached( - self, basic_config - ): - """Test that _build_token_sequence has same validation as _generate_cached_prompt.""" - tokenizer, config = basic_config - generator = PromptGenerator(config, tokenizer) - - # This should raise same error as _generate_cached_prompt - with pytest.raises(ConfigurationError): - generator._build_token_sequence(10, [1, 2, 3], 5) # final_block_size = 0 + # Same hash_ids should produce identical results due to HashIdRandomGenerator + assert result1 == result2 diff --git a/tests/unit/dataset/loader/test_bailian_trace.py b/tests/unit/dataset/loader/test_bailian_trace.py index 97885a993..7fb1e021d 100644 --- a/tests/unit/dataset/loader/test_bailian_trace.py +++ b/tests/unit/dataset/loader/test_bailian_trace.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 import logging -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from pydantic import ValidationError @@ -112,8 +112,7 @@ class TestBailianTraceDatasetLoader: def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -415,12 +414,7 @@ def test_can_load(self, data, expected): # ---- convert_to_conversations ---- - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_convert_to_conversations( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - mock_parallel_decode.return_value = ["decoded prompt 1", "decoded prompt 2"] - + def test_convert_to_conversations(self, mock_prompt_generator, default_user_config): trace_data = { "100": [ BailianTrace( @@ -447,7 +441,7 @@ def test_convert_to_conversations( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 2 assert conversations[0].session_id == "100" @@ -460,7 +454,7 @@ def test_convert_empty_data(self, mock_prompt_generator, default_user_config): user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - assert loader.convert_to_conversations({}) == [] + assert list(loader.convert_to_conversations({})) == [] def test_convert_without_hash_ids(self, mock_prompt_generator, default_user_config): """When hash_ids is empty, falls back to normal prompt generation.""" @@ -480,62 +474,18 @@ def test_convert_without_hash_ids(self, mock_prompt_generator, default_user_conf user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 mock_prompt_generator.generate.assert_called_once_with( - mean=100, stddev=0, hash_ids=[] + mean=100, stddev=0, hash_ids=[], block_size=512 ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_parallel_decode_length_mismatch_raises( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - """strict=True in zip guards against silent data loss.""" - mock_parallel_decode.return_value = ["only one"] # expecting 2 - - trace_data = { - "1": [ - BailianTrace( - chat_id=1, - timestamp=1.0, - input_length=10, - output_length=5, - hash_ids=[1], - ) - ], - "2": [ - BailianTrace( - chat_id=2, - timestamp=2.0, - input_length=20, - output_length=10, - hash_ids=[2], - ) - ], - } - - loader = BailianTraceDatasetLoader( - filename="dummy.jsonl", - user_config=default_user_config, - prompt_generator=mock_prompt_generator, - ) - - with pytest.raises(ValueError, match="zip"): - loader.convert_to_conversations(trace_data) - # ---- multi-turn conversation conversion ---- - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_multi_turn_conversation_ordering( - self, mock_parallel_decode, mock_prompt_generator, default_user_config + self, mock_prompt_generator, default_user_config ): - mock_parallel_decode.return_value = [ - "prompt turn 1", - "prompt turn 2", - "prompt turn 3", - ] - trace_data = { "100": [ BailianTrace( @@ -572,7 +522,7 @@ def test_multi_turn_conversation_ordering( 
user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 conv = conversations[0] @@ -613,8 +563,7 @@ class TestBailianTraceSynthesisIntegration: def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator def test_speedup_ratio_scales_timestamps(self, mock_prompt_generator): diff --git a/tests/unit/dataset/loader/test_base_trace_loader.py b/tests/unit/dataset/loader/test_base_trace_loader.py new file mode 100644 index 000000000..bbc36e19a --- /dev/null +++ b/tests/unit/dataset/loader/test_base_trace_loader.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for BaseTraceDatasetLoader. + +Covers file hashing, trace_id lifecycle, parallel vs single-threaded threshold, +convert_to_conversations dispatching, and integration with parallel_convert. +""" + +import hashlib +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from aiperf.common.config import ( + EndpointConfig, + InputConfig, + InputTokensConfig, + PromptConfig, + UserConfig, +) +from aiperf.common.config.config_defaults import InputTokensDefaults +from aiperf.common.models import Conversation +from aiperf.dataset.loader.base_trace_loader import ( + _MIN_TRACES_FOR_PARALLEL, + _compute_file_hash, +) +from aiperf.dataset.loader.mooncake_trace import MooncakeTraceDatasetLoader + +# ----------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------- + + +@pytest.fixture +def create_jsonl_file(): + """Create a temporary JSONL file with custom content.""" + filenames = [] + + def _create_file(content_lines): + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for line in content_lines: + f.write(line + "\n") + filenames.append(f.name) + return f.name + + yield _create_file + + for fn in filenames: + Path(fn).unlink(missing_ok=True) + + +@pytest.fixture +def default_user_config() -> UserConfig: + return UserConfig(endpoint=EndpointConfig(model_names=["test-model"])) + + +@pytest.fixture +def mock_prompt_generator(): + """Mock PromptGenerator with required attributes for BaseTraceDatasetLoader.""" + generator = Mock() + generator.generate.return_value = "Generated prompt text" + generator._cache = {} + generator._tokenized_corpus = list(range(100, 200)) + generator._hash_id_corpus_rng = Mock() + generator._hash_id_corpus_rng.seed = 42 + generator.tokenizer = Mock() + generator.tokenizer.resolved_name = "test-model" + generator.tokenizer.block_separation_token_id = None + return generator + + +# ----------------------------------------------------------------------- +# _compute_file_hash +# ----------------------------------------------------------------------- + + +class TestComputeFileHash: + """Tests for _compute_file_hash function.""" + + def test_hash_is_16_hex_chars(self, create_jsonl_file): + """Hash should be 16 hex characters.""" + filename = create_jsonl_file(["test content"]) + result = _compute_file_hash(filename) + assert len(result) == 16 + assert all(c in 
"0123456789abcdef" for c in result) + + def test_deterministic_for_same_content(self, create_jsonl_file): + """Same file content should produce same hash.""" + f1 = create_jsonl_file(["line 1", "line 2"]) + f2 = create_jsonl_file(["line 1", "line 2"]) + assert _compute_file_hash(f1) == _compute_file_hash(f2) + + def test_different_content_different_hash(self, create_jsonl_file): + """Different file content should produce different hash.""" + f1 = create_jsonl_file(["line 1"]) + f2 = create_jsonl_file(["line 2"]) + assert _compute_file_hash(f1) != _compute_file_hash(f2) + + def test_fallback_on_file_error(self): + """Non-existent file should fall back to hashing the filepath string.""" + result = _compute_file_hash("/nonexistent/path/file.jsonl") + expected = hashlib.sha256(b"/nonexistent/path/file.jsonl").hexdigest()[:16] + assert result == expected + + def test_fallback_on_type_error(self): + """TypeError (e.g. from mock_open) should fall back to filepath hash.""" + with patch("builtins.open") as mock_open: + mock_file = MagicMock() + mock_file.__enter__ = Mock(return_value=mock_file) + mock_file.__exit__ = Mock(return_value=False) + mock_file.read.return_value = "string not bytes" + mock_open.return_value = mock_file + + result = _compute_file_hash("test.jsonl") + expected = hashlib.sha256(b"test.jsonl").hexdigest()[:16] + assert result == expected + + def test_hash_matches_sha256(self, create_jsonl_file): + """Hash should match first 16 chars of SHA-256.""" + content = ["test line one", "test line two"] + filename = create_jsonl_file(content) + + with open(filename, "rb") as f: + expected = hashlib.sha256(f.read()).hexdigest()[:16] + + assert _compute_file_hash(filename) == expected + + +# ----------------------------------------------------------------------- +# load_dataset — trace_id lifecycle +# ----------------------------------------------------------------------- + + +class TestLoadDatasetTraceId: + """Tests that load_dataset computes trace_id and sets it on the RNG.""" + + def test_load_dataset_sets_trace_id( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """load_dataset should compute file hash and set trace_id.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader.load_dataset() + + assert loader._trace_id != "" + assert len(loader._trace_id) == 16 + mock_prompt_generator._hash_id_corpus_rng.set_trace_id.assert_called_once_with( + loader._trace_id + ) + + def test_trace_id_matches_file_hash( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """trace_id should match the SHA-256 hash of the file.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + expected_hash = _compute_file_hash(filename) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader.load_dataset() + + assert loader._trace_id == expected_hash + + def test_different_files_different_trace_ids( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Different files should produce different trace_ids.""" + f1 = create_jsonl_file(['{"input_length": 100, "hash_ids": [1]}']) + f2 = create_jsonl_file(['{"input_length": 200, "hash_ids": [2]}']) + + loader1 = 
MooncakeTraceDatasetLoader( + filename=f1, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader1.load_dataset() + + loader2 = MooncakeTraceDatasetLoader( + filename=f2, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + loader2.load_dataset() + + assert loader1._trace_id != loader2._trace_id + + +# ----------------------------------------------------------------------- +# convert_to_conversations — threshold dispatching +# ----------------------------------------------------------------------- + + +class TestConvertToConversationsDispatching: + """Tests for parallel vs single-threaded dispatching.""" + + def test_empty_data_returns_empty(self, mock_prompt_generator, default_user_config): + """Empty data dict should return empty list.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + result = list(loader.convert_to_conversations({})) + assert result == [] + + def test_small_dataset_uses_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Datasets with fewer than _MIN_TRACES_FOR_PARALLEL traces use single-threaded.""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert" + ) as mock_parallel: + result = list(loader.convert_to_conversations(data)) + + mock_parallel.assert_not_called() + assert len(result) == 1 + assert isinstance(result[0], Conversation) + + def test_large_dataset_uses_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Datasets with >= _MIN_TRACES_FOR_PARALLEL traces use parallel conversion.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + mock_conversations = [Conversation(session_id=f"s{i}") for i in range(3)] + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=mock_conversations, + ) as mock_parallel: + result = list(loader.convert_to_conversations(data)) + + mock_parallel.assert_called_once() + assert result == mock_conversations + + def test_parallel_convert_receives_correct_args( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """parallel_convert should receive the right parameters from the loader.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data, num_workers=8, batch_size=50)) + + call_kwargs = mock_parallel.call_args[1] + assert call_kwargs["tokenizer_name"] == "test-model" + assert 
call_kwargs["base_seed"] == 42 + assert call_kwargs["block_size"] == InputTokensDefaults.BLOCK_SIZE + assert call_kwargs["sep_token"] is None + assert call_kwargs["trace_id"] == loader._trace_id + assert call_kwargs["num_workers"] == 8 + assert call_kwargs["batch_size"] == 50 + + def test_exactly_threshold_uses_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Exactly _MIN_TRACES_FOR_PARALLEL traces should use parallel.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + mock_parallel.assert_called_once() + + def test_below_threshold_uses_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """One below _MIN_TRACES_FOR_PARALLEL uses single-threaded.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL - 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert" + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + mock_parallel.assert_not_called() + + +# ----------------------------------------------------------------------- +# _convert_single_threaded +# ----------------------------------------------------------------------- + + +class TestConvertSingleThreaded: + """Tests for the single-threaded conversion fallback.""" + + def test_text_input_used_directly( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces with text_input should use the literal text.""" + content = ['{"text_input": "Hello world", "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + assert len(conversations) == 1 + assert conversations[0].turns[0].texts[0].contents == ["Hello world"] + mock_prompt_generator.generate.assert_not_called() + + def test_hash_ids_calls_generate( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces with hash_ids should call prompt_generator.generate().""" + content = ['{"input_length": 100, "hash_ids": [1, 2], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + assert len(conversations) == 1 + mock_prompt_generator.generate.assert_called_once_with( + mean=100, + stddev=0, + hash_ids=[1, 2], + block_size=InputTokensDefaults.BLOCK_SIZE, + ) + + def test_no_input_calls_generate_with_empty_hash_ids( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + 
"""Traces with input_length but no hash_ids still call generate.""" + content = ['{"input_length": 50, "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + list(loader.convert_to_conversations(data)) + + mock_prompt_generator.generate.assert_called_once_with( + mean=50, + stddev=0, + hash_ids=[], + block_size=InputTokensDefaults.BLOCK_SIZE, + ) + + def test_turn_fields_populated( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Turn objects should have correct timestamp, delay, max_tokens.""" + content = [ + '{"input_length": 100, "hash_ids": [1], "timestamp": 5000, "output_length": 42}' + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + turn = conversations[0].turns[0] + assert turn.timestamp == 5000 + assert turn.max_tokens == 42 + + def test_multi_session_conversion( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Multiple sessions each produce a separate Conversation.""" + content = [ + '{"session_id": "s1", "input_length": 100, "hash_ids": [1], "timestamp": 1000}', + '{"session_id": "s2", "input_length": 200, "hash_ids": [2], "timestamp": 2000}', + '{"session_id": "s1", "input_length": 150, "hash_ids": [3], "delay": 50}', + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + conversations = list(loader.convert_to_conversations(data)) + + session_ids = {c.session_id for c in conversations} + assert "s1" in session_ids + assert "s2" in session_ids + + s1_conv = next(c for c in conversations if c.session_id == "s1") + assert len(s1_conv.turns) == 2 + + +# ----------------------------------------------------------------------- +# Block size precedence +# ----------------------------------------------------------------------- + + +class TestBlockSizePrecedence: + """Tests that block size follows the correct precedence chain.""" + + def test_default_block_size(self, mock_prompt_generator, default_user_config): + """Without any overrides, block_size should be InputTokensDefaults.BLOCK_SIZE.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + assert loader._block_size == InputTokensDefaults.BLOCK_SIZE + + def test_plugin_default_block_size( + self, mock_prompt_generator, default_user_config + ): + """Plugin metadata default should override the hardcoded fallback.""" + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + assert loader._block_size == 16 + + def test_user_cli_overrides_plugin_default(self, mock_prompt_generator): + """User CLI --isl-block-size should override plugin metadata default.""" + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + input=InputConfig( + prompt=PromptConfig( + input_tokens=InputTokensConfig(block_size=64), + ), + ), + ) + loader = MooncakeTraceDatasetLoader( + 
filename="dummy.jsonl", + user_config=user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + assert loader._block_size == 64 + + def test_block_size_passed_to_single_threaded( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Single-threaded path should pass block_size to generate().""" + content = ['{"input_length": 100, "hash_ids": [1], "timestamp": 1000}'] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + data = loader.load_dataset() + list(loader.convert_to_conversations(data)) + + mock_prompt_generator.generate.assert_called_once_with( + mean=100, + stddev=0, + hash_ids=[1], + block_size=16, + ) + + def test_block_size_passed_to_parallel( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Parallel path should pass block_size to parallel_convert.""" + content = [ + f'{{"input_length": 100, "hash_ids": [{i}], "timestamp": {i * 1000}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + default_block_size=16, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + assert mock_parallel.call_args[1]["block_size"] == 16 + + +# ----------------------------------------------------------------------- +# Serialization for parallel path +# ----------------------------------------------------------------------- + + +class TestSessionSerialization: + """Tests that traces are correctly serialized to dicts for parallel workers.""" + + def test_traces_serialized_via_model_dump( + self, create_jsonl_file, mock_prompt_generator, default_user_config + ): + """Traces should be serialized via model_dump() before passing to parallel_convert.""" + content = [ + f'{{"input_length": {100 + i}, "hash_ids": [{i}], "timestamp": {i * 1000}, "output_length": {20 + i}}}' + for i in range(_MIN_TRACES_FOR_PARALLEL + 1) + ] + filename = create_jsonl_file(content) + + loader = MooncakeTraceDatasetLoader( + filename=filename, + user_config=default_user_config, + prompt_generator=mock_prompt_generator, + ) + data = loader.load_dataset() + + with patch( + "aiperf.dataset.loader.base_trace_loader.parallel_convert", + return_value=[], + ) as mock_parallel: + list(loader.convert_to_conversations(data)) + + sessions = mock_parallel.call_args[1]["sessions"] + assert len(sessions) > 0 + + # Each session should be (str, list[dict]) + for sid, traces in sessions: + assert isinstance(sid, str) + for trace_dict in traces: + assert isinstance(trace_dict, dict) + assert "input_length" in trace_dict diff --git a/tests/unit/dataset/loader/test_parallel_convert.py b/tests/unit/dataset/loader/test_parallel_convert.py new file mode 100644 index 000000000..cc7edd3f8 --- /dev/null +++ b/tests/unit/dataset/loader/test_parallel_convert.py @@ -0,0 +1,1062 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for parallel_convert module. + +Covers worker initialization, batch processing, daemon flag handling, +shared memory management, and end-to-end parallel conversion. 
+""" + +import multiprocessing as mp +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest + +import aiperf.dataset.loader.parallel_convert as parallel_convert_mod +from aiperf.common.hash_id_random_generator import HashIdRandomGenerator +from aiperf.common.models import Conversation +from aiperf.dataset.generator.prompt import sample_tokens_from_corpus +from aiperf.dataset.loader.parallel_convert import ( + _init_worker, + _process_batch, + _set_daemon, + _WorkerInitArgs, + _WorkerState, + parallel_convert, +) + +# ----------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------- + + +@pytest.fixture +def sample_corpus(): + """A small corpus of token IDs for testing.""" + return list(range(100, 200)) + + +@pytest.fixture +def sample_corpus_array(sample_corpus): + """Corpus as numpy int32 array.""" + return np.array(sample_corpus, dtype=np.int32) + + +@pytest.fixture +def mock_tokenizer(): + """Mock tokenizer that decodes tokens into readable strings.""" + tok = MagicMock() + tok.decode.side_effect = lambda ids, **kw: " ".join(f"t{i}" for i in ids) + return tok + + +@pytest.fixture +def setup_worker(sample_corpus_array, mock_tokenizer): + """Set up module-level _worker_state for _process_batch tests.""" + seed = 42 + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id("test_trace") + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tokenizer, + corpus=sample_corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=10, + sep_token=None, + sample_tokens=sample_tokens_from_corpus, + ) + yield + parallel_convert_mod._worker_state = None + + +@pytest.fixture +def setup_worker_with_sep(sample_corpus_array, mock_tokenizer): + """Set up worker with a separator token (BOS/EOS).""" + seed = 42 + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id("test_trace") + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tokenizer, + corpus=sample_corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=10, + sep_token=1, + sample_tokens=sample_tokens_from_corpus, + ) + yield + parallel_convert_mod._worker_state = None + + +# ----------------------------------------------------------------------- +# _set_daemon +# ----------------------------------------------------------------------- + + +class TestSetDaemon: + """Tests for daemon flag manipulation.""" + + def test_set_daemon_true(self): + """Setting daemon to True should work on a non-daemon process.""" + original = mp.current_process().daemon + try: + _set_daemon(True) + assert mp.current_process().daemon is True + finally: + _set_daemon(original) + + def test_set_daemon_false(self): + """Setting daemon to False should work.""" + original = mp.current_process().daemon + try: + _set_daemon(False) + assert mp.current_process().daemon is False + finally: + _set_daemon(original) + + def test_set_daemon_fallback_on_assertion_error(self): + """If daemon= setter raises AssertionError, fallback to _config.""" + proc = mp.current_process() + original = proc.daemon + + with patch.object( + type(proc), "daemon", property(fset=Mock(side_effect=AssertionError)) + ): + _set_daemon(True) + assert proc._config["daemon"] is True + + # Restore + proc._config["daemon"] = original + + +# ----------------------------------------------------------------------- +# _process_batch +# 
----------------------------------------------------------------------- + + +class TestProcessBatch: + """Tests for _process_batch worker function.""" + + def test_text_input_traces(self, setup_worker): + """Traces with text_input should use the literal text.""" + batch = [ + ( + "session-1", + [ + { + "text_input": "Hello world", + "timestamp": 100, + "delay": None, + "output_length": 10, + }, + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 1 + sid, turns = results[0] + assert sid == "session-1" + assert len(turns) == 1 + ts, delay, prompt, max_tokens = turns[0] + assert prompt == "Hello world" + assert ts == 100 + assert max_tokens == 10 + + def test_hash_ids_traces(self, setup_worker, mock_tokenizer): + """Traces with hash_ids should generate tokens and decode.""" + batch = [ + ( + "session-1", + [ + { + "hash_ids": [1, 2], + "input_length": 15, + "timestamp": 200, + "delay": 5, + "output_length": 20, + }, + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 1 + sid, turns = results[0] + assert sid == "session-1" + assert len(turns) == 1 + ts, delay, prompt, max_tokens = turns[0] + assert ts == 200 + assert delay == 5 + assert max_tokens == 20 + # decode was called with generated tokens + mock_tokenizer.decode.assert_called_once() + assert isinstance(prompt, str) + assert len(prompt) > 0 + + def test_empty_trace_no_hash_ids_no_text(self, setup_worker): + """Traces without hash_ids or text_input produce empty prompt.""" + batch = [ + ( + "session-1", + [ + { + "timestamp": 300, + "delay": None, + "output_length": 5, + "input_length": 0, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + _, _, prompt, _ = turns[0] + assert prompt == "" + + def test_multiple_sessions_in_batch(self, setup_worker, mock_tokenizer): + """Multiple sessions in one batch are all processed.""" + batch = [ + ( + "s1", + [ + { + "text_input": "prompt A", + "timestamp": 1, + "delay": None, + "output_length": 10, + } + ], + ), + ( + "s2", + [ + { + "text_input": "prompt B", + "timestamp": 2, + "delay": None, + "output_length": 20, + } + ], + ), + ( + "s3", + [ + { + "text_input": "prompt C", + "timestamp": 3, + "delay": None, + "output_length": 30, + } + ], + ), + ] + results = _process_batch(batch) + + assert len(results) == 3 + assert results[0][0] == "s1" + assert results[1][0] == "s2" + assert results[2][0] == "s3" + assert results[0][1][0][2] == "prompt A" + assert results[1][1][0][2] == "prompt B" + assert results[2][1][0][2] == "prompt C" + + def test_multi_turn_session(self, setup_worker, mock_tokenizer): + """A session with multiple turns (traces) processes all turns.""" + batch = [ + ( + "session-1", + [ + { + "text_input": "turn 1", + "timestamp": 100, + "delay": None, + "output_length": 10, + }, + { + "text_input": "turn 2", + "timestamp": 200, + "delay": 50, + "output_length": 20, + }, + { + "text_input": "turn 3", + "timestamp": 300, + "delay": 100, + "output_length": 30, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + assert len(turns) == 3 + assert turns[0][2] == "turn 1" + assert turns[1][2] == "turn 2" + assert turns[2][2] == "turn 3" + + def test_hash_id_block_cache_reuse(self, setup_worker, mock_tokenizer): + """Same hash_id within a batch should reuse cached tokens.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [42], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ( + "s2", + [ + { + "hash_ids": [42], + "input_length": 10, + "timestamp": 
2, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + results = _process_batch(batch) + + # Both sessions with same hash_id should get same decoded output + prompt_1 = results[0][1][0][2] + prompt_2 = results[1][1][0][2] + assert prompt_1 == prompt_2 + + def test_different_hash_ids_produce_different_prompts( + self, setup_worker, mock_tokenizer + ): + """Different hash_ids should produce different token sequences.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [100], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ( + "s2", + [ + { + "hash_ids": [200], + "input_length": 10, + "timestamp": 2, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + results = _process_batch(batch) + + prompt_1 = results[0][1][0][2] + prompt_2 = results[1][1][0][2] + assert prompt_1 != prompt_2 + + def test_final_block_size_calculation(self, setup_worker, mock_tokenizer): + """Last hash block should get the remainder tokens.""" + # block_size=10, input_length=25, 3 hash_ids + # first two blocks: 10 tokens each, last block: 25 - 2*10 = 5 tokens + batch = [ + ( + "s1", + [ + { + "hash_ids": [1, 2, 3], + "input_length": 25, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + _process_batch(batch) + + # Verify decode was called (tokens were generated) + assert mock_tokenizer.decode.call_count == 1 + decoded_tokens = mock_tokenizer.decode.call_args[0][0] + # 10 + 10 + 5 = 25 total tokens + assert len(decoded_tokens) == 25 + + def test_separator_token_prepended(self, setup_worker_with_sep, mock_tokenizer): + """When sep_token is set, each block should have it prepended.""" + batch = [ + ( + "s1", + [ + { + "hash_ids": [1], + "input_length": 10, + "timestamp": 1, + "delay": None, + "output_length": 5, + }, + ], + ), + ] + _process_batch(batch) + + decoded_tokens = mock_tokenizer.decode.call_args[0][0] + # With sep_token=1, first token in block should be 1 + assert decoded_tokens[0] == 1 + + def test_mixed_text_input_and_hash_ids(self, setup_worker, mock_tokenizer): + """Session with both text_input and hash_id traces.""" + batch = [ + ( + "s1", + [ + { + "text_input": "literal text", + "timestamp": 1, + "delay": None, + "output_length": 10, + }, + { + "hash_ids": [5], + "input_length": 10, + "timestamp": 2, + "delay": None, + "output_length": 20, + }, + ], + ), + ] + results = _process_batch(batch) + + _, turns = results[0] + assert turns[0][2] == "literal text" + assert turns[1][2] != "literal text" # Generated prompt + + def test_none_fields_preserved(self, setup_worker): + """None values for timestamp, delay, output_length are preserved.""" + batch = [ + ( + "s1", + [ + { + "text_input": "test", + "timestamp": None, + "delay": None, + "output_length": None, + }, + ], + ), + ] + results = _process_batch(batch) + + ts, delay, _, max_tokens = results[0][1][0] + assert ts is None + assert delay is None + assert max_tokens is None + + +# ----------------------------------------------------------------------- +# _init_worker +# ----------------------------------------------------------------------- + + +class TestInitWorker: + """Tests for _init_worker function.""" + + def test_init_worker_sets_up_state(self, sample_corpus_array, tmp_path): + """_init_worker should populate _worker dict with all required fields.""" + from multiprocessing import shared_memory + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + 
dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + mock_tok = MagicMock() + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=1, + trace_id="abc123", + ) + with patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=mock_tok, + ): + _init_worker(args) + + state = parallel_convert_mod._worker_state + assert state is not None + assert state.tokenizer is mock_tok + assert state.block_size == 10 + assert state.sep_token == 1 + assert isinstance(state.hash_rng, HashIdRandomGenerator) + assert np.array_equal(state.corpus, sample_corpus_array) + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + + def test_init_worker_passes_tokenizer_config(self, sample_corpus_array): + """_init_worker should forward trust_remote_code and revision to Tokenizer.""" + from multiprocessing import shared_memory + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=None, + trace_id="abc", + trust_remote_code=True, + revision="v2.0", + ) + with patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=MagicMock(), + ) as mock_from_pretrained: + _init_worker(args) + + mock_from_pretrained.assert_called_once_with( + "test-model", + trust_remote_code=True, + revision="v2.0", + resolve_alias=False, + ) + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + + def test_init_worker_default_tokenizer_config(self, sample_corpus_array): + """_init_worker defaults to trust_remote_code=False and revision='main'.""" + from multiprocessing import shared_memory + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=None, + trace_id="abc", + ) + with patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=MagicMock(), + ) as mock_from_pretrained: + _init_worker(args) + + mock_from_pretrained.assert_called_once_with( + "test-model", + trust_remote_code=False, + revision="main", + resolve_alias=False, + ) + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + + def test_init_worker_sets_offline_env(self, sample_corpus_array): + """Worker should set HF offline environment variables.""" + import os + from multiprocessing import shared_memory + + original_hf = os.environ.get("HF_HUB_OFFLINE") + original_tf = os.environ.get("TRANSFORMERS_OFFLINE") + + shm = shared_memory.SharedMemory(create=True, size=sample_corpus_array.nbytes) + try: + np.copyto( + np.ndarray( + sample_corpus_array.shape, + dtype=sample_corpus_array.dtype, + buffer=shm.buf, + ), + sample_corpus_array, + ) + + args = _WorkerInitArgs( + shm_name=shm.name, + corpus_len=len(sample_corpus_array), + tokenizer_name="test-model", + base_seed=42, + block_size=10, + sep_token=None, + trace_id="abc", + ) + with 
patch( + "aiperf.common.tokenizer.Tokenizer.from_pretrained", + return_value=MagicMock(), + ): + _init_worker(args) + + assert os.environ.get("HF_HUB_OFFLINE") == "1" + assert os.environ.get("TRANSFORMERS_OFFLINE") == "1" + finally: + parallel_convert_mod._worker_state = None + shm.close() + shm.unlink() + # Restore env + if original_hf is None: + os.environ.pop("HF_HUB_OFFLINE", None) + else: + os.environ["HF_HUB_OFFLINE"] = original_hf + if original_tf is None: + os.environ.pop("TRANSFORMERS_OFFLINE", None) + else: + os.environ["TRANSFORMERS_OFFLINE"] = original_tf + + +# ----------------------------------------------------------------------- +# parallel_convert — end-to-end +# ----------------------------------------------------------------------- + + +class TestParallelConvert: + """Tests for the parallel_convert orchestration function.""" + + def test_empty_sessions_returns_empty(self, sample_corpus): + """Empty input returns empty output.""" + result = list( + parallel_convert( + sessions=[], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + assert result == [] + + def test_returns_conversation_objects(self, sample_corpus): + """Output should be a list of Conversation objects.""" + sessions = [ + ( + "s1", + [ + { + "text_input": "hello", + "timestamp": 1, + "delay": None, + "output_length": 5, + } + ], + ), + ] + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [("s1", [(1, None, "hello", 5)])], + ] + + result = list( + parallel_convert( + sessions=sessions, + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert len(result) == 1 + assert isinstance(result[0], Conversation) + assert result[0].session_id == "s1" + assert len(result[0].turns) == 1 + assert result[0].turns[0].timestamp == 1 + assert result[0].turns[0].max_tokens == 5 + + def test_batching_splits_sessions(self, sample_corpus): + """Sessions should be split into batches of batch_size.""" + sessions = [ + ( + f"s{i}", + [ + { + "text_input": f"p{i}", + "timestamp": i, + "delay": None, + "output_length": 5, + } + ], + ) + for i in range(5) + ] + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [(f"s{i}", [(i, None, f"p{i}", 5)]) for i in range(2)], + [(f"s{i}", [(i, None, f"p{i}", 5)]) for i in range(2, 5)], + ] + + list( + parallel_convert( + sessions=sessions, + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + batch_size=2, + ) + ) + + # map was called with batches of size 2 + batches = mock_pool_instance.imap.call_args[0][1] + assert len(batches) == 3 # 5 sessions / 2 batch_size = 3 batches + assert len(batches[0]) == 2 + assert len(batches[1]) == 2 + assert len(batches[2]) == 1 + + def test_daemon_flag_restored(self, sample_corpus): + """Daemon flag should be restored after Pool finishes.""" + original_daemon = mp.current_process().daemon + + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + 
mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.return_value = [] + + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert mp.current_process().daemon == original_daemon + + def test_shared_memory_cleanup(self, sample_corpus): + """Shared memory should be cleaned up even on errors.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.side_effect = RuntimeError("Pool error") + + with pytest.raises(RuntimeError, match="Pool error"): + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + # No leaked shared memory (if it leaked, subsequent tests would detect it) + + def test_multi_turn_conversations(self, sample_corpus): + """Sessions with multiple turns should produce multi-turn Conversations.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + + mock_pool_instance.imap.return_value = [ + [("s1", [(100, None, "turn 1", 10), (200, 50, "turn 2", 20)])], + ] + + result = list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "turn 1", + "timestamp": 100, + "delay": None, + "output_length": 10, + }, + { + "text_input": "turn 2", + "timestamp": 200, + "delay": 50, + "output_length": 20, + }, + ], + ) + ], + tokenizer_name="test", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + ) + ) + + assert len(result) == 1 + conv = result[0] + assert len(conv.turns) == 2 + assert conv.turns[0].timestamp == 100 + assert conv.turns[0].delay is None + assert conv.turns[0].max_tokens == 10 + assert conv.turns[1].timestamp == 200 + assert conv.turns[1].delay == 50 + assert conv.turns[1].max_tokens == 20 + + def test_pool_receives_correct_init_args(self, sample_corpus): + """Pool should be initialized with correct arguments.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.return_value = [] + + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="my-tokenizer", + corpus=sample_corpus, + base_seed=12345, + block_size=64, + sep_token=7, + trace_id="trace_abc", + num_workers=4, + ) + ) + + call_args = MockPool.call_args + assert call_args[0][0] == 4 # num_workers + assert call_args[0][1] is _init_worker + initargs = call_args[0][2] + assert len(initargs) == 1 + args = initargs[0] + assert isinstance(args, 
_WorkerInitArgs) + assert args.tokenizer_name == "my-tokenizer" + assert args.base_seed == 12345 + assert args.block_size == 64 + assert args.sep_token == 7 + assert args.trace_id == "trace_abc" + assert args.trust_remote_code is False + assert args.revision == "main" + + def test_pool_receives_tokenizer_config(self, sample_corpus): + """Pool init args should include trust_remote_code and revision.""" + with patch("aiperf.dataset.loader.parallel_convert.Pool") as MockPool: + mock_pool_instance = MagicMock() + MockPool.return_value.__enter__ = Mock(return_value=mock_pool_instance) + MockPool.return_value.__exit__ = Mock(return_value=False) + mock_pool_instance.imap.return_value = [] + + list( + parallel_convert( + sessions=[ + ( + "s1", + [ + { + "text_input": "t", + "timestamp": 1, + "delay": None, + "output_length": 1, + } + ], + ) + ], + tokenizer_name="kimi-model", + corpus=sample_corpus, + base_seed=42, + block_size=10, + sep_token=None, + trace_id="test", + trust_remote_code=True, + revision="v2.0", + ) + ) + + args = MockPool.call_args[0][2][0] + assert args.trust_remote_code is True + assert args.revision == "v2.0" + + +# ----------------------------------------------------------------------- +# Determinism: _process_batch produces identical results with same seed +# ----------------------------------------------------------------------- + + +class TestProcessBatchDeterminism: + """Tests that _process_batch is deterministic across invocations.""" + + def _setup_and_process( + self, corpus_array, hash_ids, input_length, block_size, trace_id, seed=42 + ): + """Helper: set up worker state and process a single batch.""" + hash_rng = HashIdRandomGenerator(seed, _internal=True) + hash_rng.set_trace_id(trace_id) + + mock_tok = MagicMock() + mock_tok.decode.side_effect = lambda ids, **kw: ",".join(str(i) for i in ids) + + parallel_convert_mod._worker_state = _WorkerState( + tokenizer=mock_tok, + corpus=corpus_array, + shm=MagicMock(), + hash_rng=hash_rng, + block_size=block_size, + sep_token=None, + sample_tokens=sample_tokens_from_corpus, + ) + + batch = [ + ( + "s1", + [ + { + "hash_ids": hash_ids, + "input_length": input_length, + "timestamp": 1, + "delay": None, + "output_length": 5, + } + ], + ), + ] + result = _process_batch(batch) + parallel_convert_mod._worker_state = None + return result[0][1][0][2] # prompt string + + def test_same_seed_same_trace_id_same_result(self, sample_corpus_array): + """Identical seed + trace_id + hash_ids = identical prompt.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + prompt_2 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + assert prompt_1 == prompt_2 + + def test_different_trace_id_different_result(self, sample_corpus_array): + """Different trace_ids produce different prompts.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_a" + ) + prompt_2 = self._setup_and_process( + sample_corpus_array, [1, 2], 15, 10, "trace_b" + ) + assert prompt_1 != prompt_2 + + def test_different_seed_different_result(self, sample_corpus_array): + """Different seeds produce different prompts.""" + prompt_1 = self._setup_and_process( + sample_corpus_array, [1], 10, 10, "trace_a", seed=42 + ) + prompt_2 = self._setup_and_process( + sample_corpus_array, [1], 10, 10, "trace_a", seed=99 + ) + assert prompt_1 != prompt_2 + + def test_different_hash_ids_different_result(self, sample_corpus_array): + """Different hash_ids produce different prompts.""" + prompt_1 = 
self._setup_and_process(sample_corpus_array, [10], 10, 10, "trace_a") + prompt_2 = self._setup_and_process(sample_corpus_array, [20], 10, 10, "trace_a") + assert prompt_1 != prompt_2 diff --git a/tests/unit/dataset/loader/test_trace.py b/tests/unit/dataset/loader/test_trace.py index 238dc9873..f3aa66f5e 100644 --- a/tests/unit/dataset/loader/test_trace.py +++ b/tests/unit/dataset/loader/test_trace.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import logging -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest from pydantic import ValidationError @@ -12,6 +12,7 @@ InputTokensConfig, PromptConfig, SynthesisConfig, + TokenizerConfig, UserConfig, ) from aiperf.dataset.loader.models import MooncakeTrace @@ -111,9 +112,7 @@ def mock_prompt_generator(self): generator = Mock() generator.generate.return_value = "Generated prompt text" # Required for convert_to_conversations() to check string cache - generator._decoded_cache = {} - # Mock _build_token_sequence to return a simple token list - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -356,18 +355,8 @@ def test_load_dataset_logs_skipped_traces( # Check that the skipped traces message is logged assert f"Skipped {expected_skipped:,} traces" in caplog.text - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_convert_to_conversations( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): + def test_convert_to_conversations(self, mock_prompt_generator, default_user_config): """Test conversion of trace data to conversations.""" - # Mock parallel_decode to return decoded prompts - mock_parallel_decode.return_value = [ - "decoded prompt 1", - "decoded prompt 2", - "decoded prompt 3", - ] - # Setup trace data trace_data = { "session-1": [ @@ -401,7 +390,7 @@ def test_convert_to_conversations( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 3 @@ -432,7 +421,7 @@ def test_convert_to_conversations_empty_data( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations({}) + conversations = list(loader.convert_to_conversations({})) assert len(conversations) == 0 @@ -453,7 +442,7 @@ def test_convert_to_conversations_with_text_input( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 # One conversation with multiple turns conversation = conversations[0] @@ -486,7 +475,7 @@ def test_convert_to_conversations_multi_turn_messages_on_turns( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 conversation = conversations[0] @@ -518,7 +507,7 @@ def test_convert_to_conversations_messages_with_tools( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = 
list(loader.convert_to_conversations(trace_data)) assert len(conversations) == 1 turn = conversations[0].turns[0] @@ -542,7 +531,7 @@ def test_convert_to_conversations_messages_without_tools( user_config=default_user_config, prompt_generator=mock_prompt_generator, ) - conversations = loader.convert_to_conversations(trace_data) + conversations = list(loader.convert_to_conversations(trace_data)) assert conversations[0].turns[0].raw_tools is None @@ -882,11 +871,52 @@ def test_load_dataset_max_isl_and_max_osl_combined( assert traces[2][0].output_length == 50 +class TestBaseTraceLoaderTokenizerConfig: + """Tests that tokenizer config options flow through to parallel_convert.""" + + def test_stores_trust_remote_code_and_revision(self): + """BaseTraceDatasetLoader stores tokenizer config from user_config.""" + generator = Mock() + generator.generate.return_value = "prompt" + generator._cache = {} + + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + tokenizer=TokenizerConfig(trust_remote_code=True, revision="v2.0"), + ) + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=generator, + ) + + assert loader._trust_remote_code is True + assert loader._revision == "v2.0" + + def test_default_tokenizer_config(self): + """Default tokenizer config should be trust_remote_code=False, revision='main'.""" + generator = Mock() + generator.generate.return_value = "prompt" + generator._cache = {} + + user_config = UserConfig( + endpoint=EndpointConfig(model_names=["test-model"]), + ) + loader = MooncakeTraceDatasetLoader( + filename="dummy.jsonl", + user_config=user_config, + prompt_generator=generator, + ) + + assert loader._trust_remote_code is False + assert loader._revision == "main" + + class TestMooncakeTraceReproducibility: """Tests for reproducibility of Mooncake trace prompt generation. - These tests verify that the two-phase Mooncake flow with parallel_decode - yields identical prompts across runs when the RNG is seeded consistently. + These tests verify that HashIdRandomGenerator-based generation yields + identical prompts across runs when the RNG is seeded consistently. """ @pytest.fixture @@ -894,8 +924,7 @@ def mock_prompt_generator(self): """Create a mock prompt generator for testing.""" generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture @@ -910,9 +939,8 @@ def user_config_for_reproducibility(self): ), ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") def test_mooncake_flow_reproducibility_with_same_seed( - self, mock_parallel_decode, mock_tokenizer_cls, user_config_for_reproducibility + self, mock_tokenizer_cls, user_config_for_reproducibility ): """Verify Mooncake flow produces identical prompts across runs with same seed. 
@@ -922,16 +950,7 @@ def test_mooncake_flow_reproducibility_with_same_seed( from aiperf.common import random_generator as rng from aiperf.dataset.generator import PromptGenerator - # Mock parallel_decode to return deterministic results based on input - def deterministic_decode(token_sequences, tokenizer_name): - return [ - f"decoded_prompt_{i}_{len(seq)}" - for i, seq in enumerate(token_sequences) - ] - - mock_parallel_decode.side_effect = deterministic_decode - - # Create trace data with hash_ids to exercise the two-phase flow + # Create trace data with hash_ids trace_data = { "session-1": [ MooncakeTrace( @@ -997,39 +1016,6 @@ def deterministic_decode(token_sequences, tokenizer_name): f"First run: {prompts1}, Second run: {prompts2}" ) - @patch("aiperf.dataset.loader.base_trace_loader.parallel_decode") - def test_parallel_decode_length_mismatch_raises( - self, mock_parallel_decode, mock_prompt_generator, default_user_config - ): - """Verify that length mismatch between pending_decodes and decoded_prompts raises. - - This tests the strict=True behavior in zip() that guards against silent data loss. - """ - # Mock parallel_decode to return FEWER results than expected - mock_parallel_decode.return_value = ["decoded prompt 1"] # Only 1, expecting 3 - - trace_data = { - "session-1": [ - MooncakeTrace(input_length=100, hash_ids=[1, 2], timestamp=1000), - ], - "session-2": [ - MooncakeTrace(input_length=200, hash_ids=[3, 4, 5], timestamp=2000), - ], - "session-3": [ - MooncakeTrace(input_length=150, hash_ids=[6], timestamp=3000), - ], - } - - loader = MooncakeTraceDatasetLoader( - filename="dummy.jsonl", - user_config=default_user_config, - prompt_generator=mock_prompt_generator, - ) - - # Should raise ValueError due to strict=True in zip - with pytest.raises(ValueError, match="zip"): - loader.convert_to_conversations(trace_data) - # ============================================================================ # Synthesis Integration Tests @@ -1070,8 +1056,7 @@ def mock_prompt_generator(self): """Create a mock prompt generator for testing.""" generator = Mock() generator.generate.return_value = "Generated prompt text" - generator._decoded_cache = {} - generator._build_token_sequence.return_value = [1, 2, 3, 4, 5] + generator._cache = {} return generator @pytest.fixture diff --git a/tests/unit/dataset/test_dataset_manager.py b/tests/unit/dataset/test_dataset_manager.py index 81f97d577..e54861276 100644 --- a/tests/unit/dataset/test_dataset_manager.py +++ b/tests/unit/dataset/test_dataset_manager.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from pydantic import ValidationError @@ -404,9 +404,8 @@ async def test_in_memory_dataset_freed_after_client_initialization( ProfileConfigureCommand(config=user_config, service_id="test_service") ) - # After configuration, in-memory dataset should be empty - assert dataset_manager.dataset == {} - assert dataset_manager._conversation_ids_cache == [] + # After configuration, conversation count should reflect dataset size + assert dataset_manager._conversation_count > 0 @pytest.mark.asyncio async def test_dataset_configured_event_set_after_client_initialization( @@ -463,8 +462,7 @@ async def test_handle_conversation_request_uses_dataset_client( 0 ].conversation_id - # Verify in-memory dataset is empty (freed) - assert dataset_manager.dataset == {} + # Verify in-memory dataset was not materialized # Request should 
still work via dataset client request = ConversationRequestMessage( @@ -489,8 +487,7 @@ async def test_handle_conversation_turn_request_uses_dataset_client( 0 ].conversation_id - # Verify in-memory dataset is empty (freed) - assert dataset_manager.dataset == {} + # Verify in-memory dataset was not materialized # Request should still work via dataset client request = ConversationTurnRequestMessage( @@ -572,17 +569,12 @@ async def test_configure_client_compress_only_skips_client_creation( """In compress_only mode, _configure_dataset_client_and_free_memory skips client creation.""" service_config = ServiceConfig(service_run_type=ServiceRunType.KUBERNETES) manager = DatasetManager(service_config, base_user_config) - # Simulate some dataset entries - manager.dataset = {"conv1": MagicMock(), "conv2": MagicMock()} - manager._conversation_ids_cache = ["conv1", "conv2"] + manager._conversation_count = 2 await manager._configure_dataset_client_and_free_memory() # Should have set dataset_configured event assert manager.dataset_configured.is_set() - # Should have freed memory (cleared dataset) - assert manager.dataset == {} - assert manager._conversation_ids_cache == [] # Should NOT have created a dataset client assert manager._dataset_client is None diff --git a/tests/unit/dataset/test_dataset_manager_inputs_json.py b/tests/unit/dataset/test_dataset_manager_inputs_json.py index 6ceb2beb1..2b92f2015 100644 --- a/tests/unit/dataset/test_dataset_manager_inputs_json.py +++ b/tests/unit/dataset/test_dataset_manager_inputs_json.py @@ -112,7 +112,10 @@ async def test_generate_inputs_json_session_order_preservation( written_json = json.loads(capture_file_writes.written_content) session_ids = [session["session_id"] for session in written_json["data"]] - expected_order = list(populated_dataset_manager.dataset.keys()) + expected_order = [ + conv.conversation_id + for conv in populated_dataset_manager.dataset_metadata.conversations + ] assert session_ids == expected_order @pytest.mark.asyncio