
Commit 27536c2

feat: add conversation context mode for multi-turn history control
Introduce ConversationContextMode enum (accumulate_all, drop_responses, standalone) to control how prior turns are accumulated in multi-turn conversations. Modes resolve with conversation > dataset default > accumulate_all precedence. Standalone replaces turn_list with only the current turn; drop_responses skips storing assistant responses.

Signed-off-by: Anthony Casagrande <acasagrande@nvidia.com>
1 parent e294a31 commit 27536c2

17 files changed (+484 / -13 lines)


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -74,6 +74,7 @@ aiperf profile \
 - [Reproducibility](docs/reproducibility.md) - Deterministic datasets with `--random-seed`
 - [Template Endpoint](docs/tutorials/template-endpoint.md) - Custom Jinja2 request templates
 - [Multi-Turn Conversations](docs/tutorials/multi-turn.md) - Multi-turn conversation benchmarking
+- [Conversation Context Mode](docs/tutorials/conversation-context-mode.md) - Control how conversation history accumulates
 - [Local Tokenizer](docs/tutorials/local-tokenizer.md) - Use local tokenizers without HuggingFace

 ### Endpoint Types
```
docs/tutorials/conversation-context-mode.md

Lines changed: 97 additions & 0 deletions (new file)

<!--
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-->

# Conversation Context Mode

Conversation context mode controls how prior turns are accumulated when building multi-turn chat requests. Different dataset formats imply different accumulation strategies, and AIPerf automatically selects the right one based on your data.

## Modes

### `accumulate_all`

Standard multi-turn chat. The live inference response is stored and included in subsequent requests.

**Dataset:**
```
Turn 1: {"role": "user", "content": "What is ML?"}
Turn 2: {"role": "user", "content": "Give an example"}
Turn 3: {"role": "user", "content": "How does it differ from traditional programming?"}
```

**Replay:**
```
Request 1: [User "What is ML?"]
→ Server responds with A1

Request 2: [User "What is ML?", Assistant A1, User "Give an example"]
→ Server responds with A2

Request 3: [User "What is ML?", Assistant A1, User "Give an example", Assistant A2, User "How does it differ..."]
→ Server responds with A3
```

Default for:
- Synthetic datasets
- Multi-turn JSONL
- ShareGPT
- Mooncake traces with `hash_ids`

### `drop_responses`

Delta-compressed prompts. Each dataset turn contains only the *new* messages since the previous turn. AIPerf accumulates these deltas to reconstruct the full conversation. The live inference response is used only for measurement and then discarded; the pre-canned assistant responses in the dataset are used instead.

**Dataset (each turn is a delta):**
```
Turn 1: [{"role": "user", "content": "What is ML?"}]
Turn 2: [{"role": "assistant", "content": "ML is..."}, {"role": "user", "content": "Give an example"}]
Turn 3: [{"role": "assistant", "content": "Sure..."}, {"role": "user", "content": "How does it differ..."}]
```

**Replay (deltas accumulated):**
```
Request 1: [User "What is ML?"]
→ Live response discarded

Request 2: [User "What is ML?"] + [Assistant "ML is...", User "Give an example"]
→ Live response discarded

Request 3: [User "What is ML?"] + [Assistant "ML is...", User "Give an example"] + [Assistant "Sure...", User "How does it differ..."]
→ Live response discarded
```

Default for:
- N/A (no built-in loader defaults to this mode yet)

### `standalone`

Self-contained prompts. Each turn already contains its full context. No session accumulation.

**Dataset:**
```
Turn 1: [{"role": "user", "content": "What is ML?"}]
Turn 2: [{"role": "user", "content": "What is ML?"}, {"role": "assistant", "content": "ML is..."}, {"role": "user", "content": "Give an example"}]
Turn 3: [{"role": "user", "content": "What is ML?"}, {"role": "assistant", "content": "ML is..."}, {"role": "user", "content": "Give an example"}, {"role": "assistant", "content": "Sure..."}, {"role": "user", "content": "How does it differ..."}]
```

**Replay:**
```
Request 1: sends Turn 1 as-is
Request 2: sends Turn 2 as-is
Request 3: sends Turn 3 as-is
```

Each turn is sent exactly as it appears in the dataset.

Default for:
- Mooncake traces with pre-built `messages` arrays

## How It Works

Context mode is resolved through a priority chain:

1. **Per-conversation override** -- A conversation in the dataset can specify its own `context_mode`
2. **Loader default** -- The dataset loader can declare a default based on dataset format semantics
3. **Global fallback** -- `accumulate_all`

This means most users never need to think about context mode. The loader picks the right default, and individual conversations can override it when needed.
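The `drop_responses` replay is plain list concatenation of dataset deltas. A minimal standalone sketch (message dicts as in the examples above; `accumulate_deltas` is an illustrative name, not AIPerf's actual API):

```python
def accumulate_deltas(turns: list[list[dict]], upto: int) -> list[dict]:
    """Build the message list for request number `upto` (1-based) by
    concatenating the dataset deltas for turns 1..upto.

    Live inference responses are never stored; the pre-canned assistant
    messages inside the deltas supply the history instead.
    """
    messages: list[dict] = []
    for delta in turns[:upto]:
        messages.extend(delta)
    return messages


# Deltas matching the drop_responses dataset example above.
turns = [
    [{"role": "user", "content": "What is ML?"}],
    [{"role": "assistant", "content": "ML is..."},
     {"role": "user", "content": "Give an example"}],
]
```

Request 2 under this mode is the first delta plus the second, i.e. three messages ending with the new user turn.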

docs/tutorials/multi-turn.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -442,3 +442,6 @@ The delays between turns are controlled by:
 - Consider using `--request-rate` to control conversation start rate for more predictable load
 - Use `--random-seed` for reproducible conversation patterns
+
+**See also:**
+- [Conversation Context Mode](conversation-context-mode.md) — Control how conversation history accumulates (delta-compressed, standalone, etc.)
```

src/aiperf/common/enums/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
     CommandResponseStatus,
     CommandType,
     ConnectionReuseStrategy,
+    ConversationContextMode,
     CreditPhase,
     ExportLevel,
     GPUTelemetryMode,
@@ -77,6 +78,7 @@
     "CommandResponseStatus",
     "CommandType",
     "ConnectionReuseStrategy",
+    "ConversationContextMode",
     "CreditPhase",
     "EnergyMetricUnit",
     "EnergyMetricUnitInfo",
```

src/aiperf/common/enums/enums.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -96,6 +96,23 @@ class CommandResponseStatus(CaseInsensitiveStrEnum):
     UNHANDLED = "unhandled"  # The command was received but not handled by any hook


+class ConversationContextMode(CaseInsensitiveStrEnum):
+    """Controls how prior turns are accumulated in multi-turn conversations.
+
+    The context mode is a property of how the dataset was constructed.
+    It determines what conversation history is included in each request.
+    """
+
+    ACCUMULATE_ALL = "accumulate_all"
+    """Standard multi-turn chat. Both user and assistant turns are kept in history."""
+
+    DROP_RESPONSES = "drop_responses"
+    """Delta-compressed prompts. Dataset turns accumulate but live inference responses are discarded."""
+
+    STANDALONE = "standalone"
+    """Self-contained prompts. Each turn already has full context; no prior turns included."""
+
+
 class ConnectionReuseStrategy(CaseInsensitiveStrEnum):
     """Transport connection reuse strategy. Controls how and when connections are reused across requests."""
```
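The new enum subclasses `CaseInsensitiveStrEnum`, so mode strings from datasets should parse regardless of case. A standalone sketch of that behavior; the `_missing_` implementation here is an assumption inferred from the base class name, not AIPerf's actual code:

```python
from enum import Enum


class CaseInsensitiveStrEnum(str, Enum):
    """Minimal stand-in for AIPerf's base class (assumed semantics:
    value lookup succeeds regardless of case)."""

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when the exact value is not found.
        if isinstance(value, str):
            for member in cls:
                if member.value == value.lower():
                    return member
        return None


class ConversationContextMode(CaseInsensitiveStrEnum):
    ACCUMULATE_ALL = "accumulate_all"
    DROP_RESPONSES = "drop_responses"
    STANDALONE = "standalone"
```

With this, `ConversationContextMode("STANDALONE")` and `ConversationContextMode("standalone")` resolve to the same member.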

src/aiperf/common/models/dataset_models.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -7,7 +7,7 @@

 from pydantic import Field

-from aiperf.common.enums import MediaType
+from aiperf.common.enums import ConversationContextMode, MediaType
 from aiperf.common.models.base_models import AIPerfBaseModel
 from aiperf.common.types import MediaTypeT
 from aiperf.plugin.enums import DatasetClientStoreType, DatasetSamplingStrategy
@@ -246,6 +246,12 @@ class DatasetMetadata(AIPerfBaseModel):
         default=False,
         description="Whether the dataset has timing data (timestamps/delays in turns).",
     )
+    default_context_mode: ConversationContextMode | None = Field(
+        default=None,
+        description="Dataset-level default for how prior turns are accumulated. "
+        "Set by the loader based on dataset format semantics. "
+        "Individual conversations can override this via their own context_mode field.",
+    )

     @cached_property
     def total_turn_count(self) -> int:
@@ -270,6 +276,11 @@ class Conversation(AIPerfBaseModel):
     session_id: str = Field(
         default="", description="Unique identifier for the conversation."
     )
+    context_mode: ConversationContextMode | None = Field(
+        default=None,
+        description="How prior turns are accumulated for this conversation. "
+        "When None, inherits the dataset-level default.",
+    )
     turns: list[Turn] = Field(
         default=[], description="List of turns in the conversation."
     )
```

src/aiperf/dataset/composer/custom.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 from aiperf.common.tokenizer import Tokenizer
 from aiperf.common.utils import load_json_str
 from aiperf.dataset.composer.base import BaseDatasetComposer
+from aiperf.dataset.loader.base_loader import BaseLoader
 from aiperf.dataset.utils import check_file_exists
 from aiperf.plugin import plugins
 from aiperf.plugin.enums import CustomDatasetType, PluginType
@@ -18,6 +19,7 @@
 class CustomDatasetComposer(BaseDatasetComposer):
     def __init__(self, config: UserConfig, tokenizer: Tokenizer | None):
         super().__init__(config, tokenizer)
+        self.loader: BaseLoader | None = None

     def create_dataset(self) -> list[Conversation]:
         """Create conversations from a file or directory.
```

src/aiperf/dataset/dataset_manager.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -14,6 +14,7 @@
 from aiperf.common.enums import (
     CommAddress,
     CommandType,
+    ConversationContextMode,
     CreditPhase,
     MessageType,
     PublicDatasetType,
@@ -108,6 +109,7 @@ def __init__(
             compress_only=self._compress_only,
         )
         self._dataset_client: DatasetClientStoreProtocol | None = None
+        self._default_context_mode: ConversationContextMode | None = None

     @on_command(CommandType.PROFILE_CONFIGURE)
     async def _profile_configure_command(
@@ -291,14 +293,21 @@ async def _load_public_dataset(self) -> list[Conversation]:
         self.user_config.input.dataset_sampling_strategy = (
             loader.get_recommended_sampling_strategy()
         )
+        self._default_context_mode = loader.get_default_context_mode()
         return await loader.convert_to_conversations(dataset)

     def _load_custom_dataset(self) -> list[Conversation]:
         ComposerClass = plugins.get_class(
             PluginType.DATASET_COMPOSER, ComposerType.CUSTOM
         )
         composer = ComposerClass(config=self.user_config, tokenizer=self.tokenizer)
-        return composer.create_dataset()
+        conversations = composer.create_dataset()
+        self._default_context_mode = (
+            composer.loader.get_default_context_mode()
+            if composer.loader is not None
+            else None
+        )
+        return conversations

     def _is_rankings_endpoint(self, endpoint_type: str) -> bool:
         return "rankings" in endpoint_type.lower()
@@ -321,6 +330,7 @@ async def _configure_dataset(self) -> None:

         self.dataset_configured.clear()

+        self._default_context_mode = None
         if self.user_config.input.public_dataset is not None:
             conversations = await self._load_public_dataset()
         elif (
@@ -364,6 +374,7 @@ async def _configure_dataset(self) -> None:
         self.dataset_metadata = DatasetMetadata(
             conversations=[conversation.metadata() for conversation in conversations],
             sampling_strategy=self.user_config.input.dataset_sampling_strategy,
+            default_context_mode=self._default_context_mode,
         )
         self.info(
             f"sampling strategy: {self.dataset_metadata.sampling_strategy}, "
```

src/aiperf/dataset/loader/base_loader.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod

 from aiperf.common.config.user_config import UserConfig
+from aiperf.common.enums import ConversationContextMode
 from aiperf.common.mixins import AIPerfLoggerMixin
 from aiperf.common.models import Conversation
 from aiperf.common.session_id_generator import SessionIDGenerator
@@ -31,6 +32,15 @@ def __init__(self, *, user_config: UserConfig, **kwargs):
             seed=user_config.input.random_seed
         )

+    @classmethod
+    def get_default_context_mode(cls) -> ConversationContextMode | None:
+        """Dataset-level default context mode for conversations without an explicit one.
+
+        Override in subclasses when the dataset format implies a specific mode.
+        Returns None to fall through to the global ACCUMULATE_ALL default.
+        """
+        return None
+
     @abstractmethod
     def load_dataset(self) -> dict[str, list[CustomDatasetT]]: ...
```

src/aiperf/dataset/loader/base_trace_loader.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -6,6 +6,7 @@

 from aiperf.common.config.config_defaults import InputTokensDefaults
 from aiperf.common.config.user_config import UserConfig
+from aiperf.common.enums import ConversationContextMode
 from aiperf.common.models import Conversation, Text, Turn
 from aiperf.dataset.generator.parallel_decode import parallel_decode
 from aiperf.dataset.generator.prompt import PromptGenerator
@@ -213,6 +214,16 @@ def _get_text_input(self, trace: TraceT) -> str | None:
         """
         return getattr(trace, "text_input", None)

+    def _infer_context_mode(
+        self, traces: list[TraceT]
+    ) -> ConversationContextMode | None:
+        """Infer context_mode from trace data when not explicitly set.
+
+        Override in subclasses to auto-detect based on trace content.
+        Default returns None (falls through to global ACCUMULATE_ALL default).
+        """
+        return None
+
     def _build_turn(self, trace: TraceT, prompt: str) -> Turn:
         """Build a :class:`Turn` from trace data and a generated prompt.

@@ -292,7 +303,12 @@ def convert_to_conversations(
         # Phase 3: Build final conversation objects
         conversations: list[Conversation] = []
         for session_id, trace_prompt_pairs in conversations_data.items():
-            conversation = Conversation(session_id=session_id)
+            traces_in_session = [trace for trace, _ in trace_prompt_pairs]
+            context_mode = self._infer_context_mode(traces_in_session)
+
+            conversation = Conversation(
+                session_id=session_id, context_mode=context_mode
+            )
             for trace, prompt in trace_prompt_pairs:
                 conversation.turns.append(self._build_turn(trace, prompt))
             conversations.append(conversation)
```
