Skip to content

Commit 0048d41

Browse files
Update dynamo headers to provide raw integer values by default (NVIDIA#1583)
Replaced categorical hints (LOW/MEDIUM/HIGH) with raw integers for Output Sequence Length (OSL) and Inter-Arrival Time (IAT). Added backward compatibility for string-based categorical values, which are now mapped to integers. Updated tests and logic to handle both raw and categorical inputs with a `use_raw_values` flag for flexibility.

## By Submitting this PR I confirm:

- I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing/index.md).
- We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license.
  - Any contribution which contains commits that are not Signed-Off will not be accepted.
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

## Summary by CodeRabbit

* **New Features**
  * Emit raw numeric OSL/IAT values by default, with a toggle to emit categorical labels instead.
  * Option to suppress HTTP header injection while still supplying agent hints.
* **Improvements**
  * Prefix configuration accepts integers with backward-compatible coercion from legacy strings.
  * LangChain/ADK integrations forward raw-mode and header-suppression settings.
* **Tests**
  * Expanded tests for raw vs. categorical modes, prediction overrides, and header-suppression.

Authors:
- Dhruv Nandakumar (https://github.com/dnandakumar-nv)
- Claude (https://github.com/claude)

Approvers:
- https://github.com/mnajafian-nv
- Will Killian (https://github.com/willkill07)

URL: NVIDIA#1583
1 parent e87e5e4 commit 0048d41

File tree

8 files changed

+482
-181
lines changed

8 files changed

+482
-181
lines changed

packages/nvidia_nat_adk/src/nat/plugins/adk/llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder):
179179
if config.base_url:
180180
config_dict["api_base"] = config.base_url
181181

182-
# Build Dynamo prefix headers if prefix_template is configured
183-
if config.prefix_template is not None:
182+
# Build Dynamo prefix headers if prefix_template is configured and headers are enabled
183+
if config.prefix_template is not None and not config.disable_headers:
184184
# Generate a static prefix ID for this LLM instance
185185
# For dynamic prefix IDs, users should use the LangChain client or manage sessions manually
186186
unique_id = uuid.uuid4().hex[:16]
@@ -189,8 +189,8 @@ async def dynamo_adk(config: DynamoModelConfig, _builder: Builder):
189189
extra_headers = {
190190
"x-prefix-id": prefix_id,
191191
"x-prefix-total-requests": str(config.prefix_total_requests),
192-
"x-prefix-osl": config.prefix_osl.upper(),
193-
"x-prefix-iat": config.prefix_iat.upper(),
192+
"x-prefix-osl": str(config.prefix_osl),
193+
"x-prefix-iat": str(config.prefix_iat),
194194
}
195195
config_dict["extra_headers"] = extra_headers
196196

packages/nvidia_nat_adk/tests/test_adk_llm.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,9 @@ def dynamo_cfg_with_prefix(self):
199199
base_url="http://localhost:8000/v1",
200200
prefix_template="session-{uuid}",
201201
prefix_total_requests=15,
202-
prefix_osl="HIGH",
203-
prefix_iat="LOW",
202+
prefix_osl=2048,
203+
prefix_iat=50,
204+
disable_headers=False,
204205
)
205206

206207
@patch('google.adk.models.lite_llm.LiteLlm')
@@ -238,8 +239,8 @@ async def test_creation_with_prefix_template(self, mock_litellm_class, dynamo_cf
238239
assert "x-prefix-id" in headers
239240
assert headers["x-prefix-id"].startswith("session-")
240241
assert headers["x-prefix-total-requests"] == "15"
241-
assert headers["x-prefix-osl"] == "HIGH"
242-
assert headers["x-prefix-iat"] == "LOW"
242+
assert headers["x-prefix-osl"] == "2048"
243+
assert headers["x-prefix-iat"] == "50"
243244

244245
assert client is mock_llm_instance
245246

@@ -268,6 +269,8 @@ async def test_excludes_dynamo_specific_fields(self, mock_litellm_class, dynamo_
268269
assert "prefix_total_requests" not in kwargs
269270
assert "prefix_osl" not in kwargs
270271
assert "prefix_iat" not in kwargs
272+
assert "prefix_use_raw_values" not in kwargs
273+
assert "disable_headers" not in kwargs
271274
assert "request_timeout" not in kwargs
272275

273276
@patch('google.adk.models.lite_llm.LiteLlm')
@@ -280,6 +283,7 @@ async def test_prefix_id_is_unique_per_instance(self, mock_litellm_class, mock_b
280283
config = DynamoModelConfig(
281284
model_name="test-model",
282285
prefix_template="session-{uuid}",
286+
disable_headers=False,
283287
)
284288

285289
prefix_ids = set()

packages/nvidia_nat_core/src/nat/llm/dynamo_llm.py

Lines changed: 129 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,28 @@
3939
-------------------------
4040
4141
prefix_osl (Output Sequence Length)
42-
Hint for expected response length:
42+
Expected output tokens for response length hinting. By default, the raw
43+
integer value is sent. When ``prefix_use_raw_values`` is False, values are
44+
converted to categories:
4345
44-
- LOW: decode_cost=1.0, short responses
45-
- MEDIUM: decode_cost=2.0, typical responses
46-
- HIGH: decode_cost=3.0, long responses
46+
- < 256 tokens: LOW (decode_cost=1.0, short responses)
47+
- < 1024 tokens: MEDIUM (decode_cost=2.0, typical responses)
48+
- >= 1024 tokens: HIGH (decode_cost=3.0, long responses)
49+
50+
Accepts categorical strings (LOW/MEDIUM/HIGH) for backward compatibility,
51+
which are converted to representative token counts (128/512/2048).
4752
4853
prefix_iat (Inter-Arrival Time)
49-
Hint for request pacing:
54+
Expected inter-arrival time in milliseconds. By default, the raw integer
55+
value is sent. When ``prefix_use_raw_values`` is False, values are converted
56+
to categories:
57+
58+
- < 100ms: LOW (iat_factor=1.5, rapid bursts, high worker stickiness)
59+
- < 500ms: MEDIUM (iat_factor=1.0, normal pacing)
60+
- >= 500ms: HIGH (iat_factor=0.6, slow requests, more exploration)
5061
51-
- LOW: iat_factor=1.5, rapid bursts -> high worker stickiness
52-
- MEDIUM: iat_factor=1.0, normal pacing
53-
- HIGH: iat_factor=0.6, slow requests -> more exploration
62+
Accepts categorical strings (LOW/MEDIUM/HIGH) for backward compatibility,
63+
which are converted to representative millisecond values (50/250/750).
5464
5565
prefix_total_requests
5666
Expected requests per conversation:
@@ -74,6 +84,7 @@
7484
from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup
7585

7686
from pydantic import Field
87+
from pydantic import field_validator
7788

7889
from nat.builder.builder import Builder
7990
from nat.builder.context import Context
@@ -90,6 +101,17 @@
90101
# Define valid prefix hint values
91102
PrefixLevel = Literal["LOW", "MEDIUM", "HIGH"]
92103

104+
# Representative token counts for categorical levels (one power-of-two value per range):
105+
# LOW: 128 tokens (midpoint of 0-256 range)
106+
# MEDIUM: 512 tokens (geometric midpoint of 256-1024 range; arithmetic midpoint would be 640)
107+
# HIGH: 2048 tokens (geometric midpoint of 1024-4096 range; arithmetic midpoint would be 2560)
108+
_OSL_CATEGORY_TO_INT: dict[str, int] = {"LOW": 128, "MEDIUM": 512, "HIGH": 2048}
109+
# Representative interarrival times for categorical levels (one value inside each range):
110+
# LOW: 50ms (midpoint of 0-100ms range)
111+
# MEDIUM: 250ms (representative value in 100-500ms range; arithmetic midpoint would be 300ms)
112+
# HIGH: 750ms (midpoint of 500-1000ms range)
113+
_IAT_CATEGORY_TO_INT: dict[str, int] = {"LOW": 50, "MEDIUM": 250, "HIGH": 750}
114+
93115
# =============================================================================
94116
# CATEGORY CONVERSION HELPERS
95117
# =============================================================================
@@ -314,22 +336,22 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"):
314336
"Lower values allow more load balancing across workers."),
315337
space=SearchSpace(low=1, high=20, step=5))
316338

317-
prefix_osl: PrefixLevel = OptimizableField(
318-
default="MEDIUM",
319-
description="Output Sequence Length hint for the Dynamo router. "
320-
"LOW means short responses (decode_cost=1.0), "
321-
"MEDIUM means typical (decode_cost=2.0), "
322-
"HIGH means long responses (decode_cost=3.0).",
323-
space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]),
339+
prefix_osl: int = OptimizableField(
340+
default=512,
341+
ge=1,
342+
description="Expected output tokens for response length hinting (Output Sequence Length). "
343+
"Raw integer value is sent by default. Accepts categorical strings "
344+
"(LOW/MEDIUM/HIGH) for backward compatibility (mapped to 128/512/2048).",
345+
space=SearchSpace(low=64, high=4096, step=64),
324346
)
325347

326-
prefix_iat: PrefixLevel = OptimizableField(
327-
default="MEDIUM",
328-
description="Inter-Arrival Time hint for the Dynamo router. "
329-
"LOW means rapid bursts (iat_factor=1.5, high stickiness), "
330-
"MEDIUM means normal (iat_factor=1.0), "
331-
"HIGH means slow requests (iat_factor=0.6, more exploration).",
332-
space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]),
348+
prefix_iat: int = OptimizableField(
349+
default=250,
350+
ge=1,
351+
description="Expected inter-arrival time in milliseconds for request pacing. "
352+
"Raw integer value is sent by default. Accepts categorical strings "
353+
"(LOW/MEDIUM/HIGH) for backward compatibility (mapped to 50/250/750).",
354+
space=SearchSpace(low=10, high=1000, step=50),
333355
)
334356

335357
request_timeout: float = Field(
@@ -338,12 +360,49 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"):
338360
description="HTTP request timeout in seconds for LLM requests.",
339361
)
340362

363+
prefix_use_raw_values: bool = Field(
364+
default=True,
365+
description="When True, send raw integer values for OSL (output tokens) and IAT (interarrival ms) "
366+
"in headers and nvext.agent_hints. When False, convert to categorical LOW/MEDIUM/HIGH.",
367+
)
368+
341369
prediction_trie_path: str | None = Field(
342370
default=None,
343371
description="Path to prediction_trie.json file. When set, predictions are "
344372
"looked up and used to override both HTTP headers and nvext.agent_hints for each LLM call.",
345373
)
346374

375+
disable_headers: bool = Field(
376+
default=True,
377+
description="If True, do not inject Dynamo prefix hints as HTTP headers. "
378+
"Hints will still be injected via nvext.agent_hints in the request body if prefix_template is set.",
379+
)
380+
381+
# =========================================================================
382+
# VALIDATORS (backward compatibility: categorical strings -> integers)
383+
# =========================================================================
384+
385+
@field_validator("prefix_osl", mode="before")
386+
@classmethod
387+
def _coerce_prefix_osl(cls, v: object) -> int:
388+
if isinstance(v, int):
389+
return v
390+
if isinstance(v, str):
391+
upper = v.upper()
392+
if upper in _OSL_CATEGORY_TO_INT:
393+
return _OSL_CATEGORY_TO_INT[upper]
394+
raise ValueError(f"Invalid OSL value '{v}'. Must be an integer >= 1 "
395+
f"or one of: {', '.join(_OSL_CATEGORY_TO_INT.keys())}")
396+
raise TypeError(f"prefix_osl must be int or str, got {type(v)}")
397+
398+
@field_validator("prefix_iat", mode="before")
399+
@classmethod
400+
def _coerce_prefix_iat(cls, v: object) -> object:
401+
"""Convert categorical IAT strings (LOW/MEDIUM/HIGH) to representative millisecond values."""
402+
if isinstance(v, str) and v.upper() in _IAT_CATEGORY_TO_INT:
403+
return _IAT_CATEGORY_TO_INT[v.upper()]
404+
return v
405+
347406
# =========================================================================
348407
# UTILITY METHODS
349408
# =========================================================================
@@ -371,8 +430,10 @@ def get_dynamo_field_names() -> frozenset[str]:
371430
"prefix_total_requests",
372431
"prefix_osl",
373432
"prefix_iat",
433+
"prefix_use_raw_values",
374434
"request_timeout",
375435
"prediction_trie_path",
436+
"disable_headers",
376437
})
377438

378439

@@ -397,15 +458,19 @@ def __init__(
397458
self,
398459
transport: httpx.AsyncBaseTransport,
399460
total_requests: int,
400-
osl: str,
401-
iat: str,
461+
osl: int,
462+
iat: int,
402463
prediction_lookup: "PredictionTrieLookup | None" = None,
464+
use_raw_values: bool = True,
465+
disable_headers: bool = True,
403466
):
404467
self._transport = transport
405468
self._total_requests = total_requests
406-
self._osl = osl.upper()
407-
self._iat = iat.upper()
469+
self._osl = osl
470+
self._iat = iat
408471
self._prediction_lookup = prediction_lookup
472+
self._use_raw_values = use_raw_values
473+
self._disable_headers = disable_headers
409474

410475
async def handle_async_request(self, request: "httpx.Request") -> "httpx.Response":
411476
# Get prefix ID from context (supports depth-awareness and overrides)
@@ -419,10 +484,10 @@ async def handle_async_request(self, request: "httpx.Request") -> "httpx.Respons
419484
# If context not available or latency_sensitivity not implemented yet, default to MEDIUM
420485
latency_sensitivity = "MEDIUM"
421486

422-
# Initialize with static config values
487+
# Initialize with static config values (always integers)
423488
total_requests = self._total_requests
424-
osl = self._osl
425-
iat = self._iat
489+
osl_raw = self._osl
490+
iat_raw = self._iat
426491

427492
# Check for prediction override
428493
if self._prediction_lookup is not None:
@@ -445,19 +510,17 @@ async def handle_async_request(self, request: "httpx.Request") -> "httpx.Respons
445510
if prediction:
446511
# Override with prediction-derived values
447512
total_requests = int(prediction.remaining_calls.mean)
448-
osl = _output_tokens_to_osl(prediction.output_tokens.p90)
449-
iat = _interarrival_ms_to_iat(prediction.interarrival_ms.mean)
513+
osl_raw = int(prediction.output_tokens.p90)
514+
iat_raw = int(prediction.interarrival_ms.mean)
450515

451516
logger.debug(
452517
"Overriding hints from prediction: path=%s, call_index=%d, "
453-
"total_requests=%d, osl=%s (tokens=%d), iat=%s (ms=%d)",
518+
"total_requests=%d, osl_raw=%d, iat_raw=%d",
454519
path,
455520
call_index,
456521
total_requests,
457-
osl,
458-
int(prediction.output_tokens.p90),
459-
iat,
460-
int(prediction.interarrival_ms.mean),
522+
osl_raw,
523+
iat_raw,
461524
)
462525
else:
463526
logger.debug(
@@ -469,26 +532,35 @@ async def handle_async_request(self, request: "httpx.Request") -> "httpx.Respons
469532
except Exception:
470533
logger.exception("Failed to lookup prediction")
471534

472-
# Inject HTTP headers
535+
# Compute final values for headers/body
536+
if self._use_raw_values:
537+
osl_value: int | str = osl_raw
538+
iat_value: int | str = iat_raw
539+
else:
540+
osl_value = _output_tokens_to_osl(osl_raw)
541+
iat_value = _interarrival_ms_to_iat(iat_raw)
542+
473543
headers = dict(request.headers)
474-
headers[f"{LLMHeaderPrefix.DYNAMO}-id"] = prefix_id
475-
headers[f"{LLMHeaderPrefix.DYNAMO}-total-requests"] = str(total_requests)
476-
headers[f"{LLMHeaderPrefix.DYNAMO}-osl"] = osl
477-
headers[f"{LLMHeaderPrefix.DYNAMO}-iat"] = iat
478-
headers[f"{LLMHeaderPrefix.DYNAMO}-latency-sensitivity"] = latency_sensitivity
544+
if not self._disable_headers:
545+
# Headers always need strings
546+
headers[f"{LLMHeaderPrefix.DYNAMO}-id"] = prefix_id
547+
headers[f"{LLMHeaderPrefix.DYNAMO}-total-requests"] = str(total_requests)
548+
headers[f"{LLMHeaderPrefix.DYNAMO}-osl"] = str(osl_value)
549+
headers[f"{LLMHeaderPrefix.DYNAMO}-iat"] = str(iat_value)
550+
headers[f"{LLMHeaderPrefix.DYNAMO}-latency-sensitivity"] = latency_sensitivity
479551

480552
# Modify body to inject nvext.agent_hints (if JSON POST request)
481553
content = request.content
482554
if request.method == "POST" and content:
483555
try:
484556
body = json.loads(content.decode("utf-8", errors="replace"))
485557
if isinstance(body, dict):
486-
# Build agent_hints dict
558+
# Build agent_hints dict (int or str depending on raw mode)
487559
agent_hints = {
488560
"prefix_id": prefix_id,
489561
"total_requests": total_requests,
490-
"osl": osl,
491-
"iat": iat,
562+
"osl": osl_value,
563+
"iat": iat_value,
492564
"latency_sensitivity": latency_sensitivity,
493565
}
494566

@@ -527,8 +599,8 @@ async def handle_async_request(self, request: "httpx.Request") -> "httpx.Respons
527599
logger.debug("Injected Dynamo hints: prefix_id=%s, total_requests=%d, osl=%s, iat=%s, latency_sensitivity=%s",
528600
prefix_id,
529601
total_requests,
530-
osl,
531-
iat,
602+
osl_value,
603+
iat_value,
532604
latency_sensitivity)
533605

534606
return await self._transport.handle_async_request(new_request)
@@ -546,10 +618,12 @@ async def aclose(self) -> None:
546618
def create_httpx_client_with_dynamo_hooks(
547619
prefix_template: str | None,
548620
total_requests: int,
549-
osl: str,
550-
iat: str,
621+
osl: int,
622+
iat: int,
551623
timeout: float = 600.0,
552624
prediction_lookup: "PredictionTrieLookup | None" = None,
625+
use_raw_values: bool = True,
626+
disable_headers: bool = True,
553627
) -> "httpx.AsyncClient":
554628
"""
555629
Create an httpx.AsyncClient with Dynamo hint injection via custom transport.
@@ -564,10 +638,12 @@ def create_httpx_client_with_dynamo_hooks(
564638
Args:
565639
prefix_template: Template string with {uuid} placeholder (unused, kept for API compat)
566640
total_requests: Expected number of requests for this prefix
567-
osl: Output sequence length hint (LOW/MEDIUM/HIGH)
568-
iat: Inter-arrival time hint (LOW/MEDIUM/HIGH)
641+
osl: Expected output tokens (raw integer value)
642+
iat: Expected inter-arrival time in milliseconds (raw integer value)
569643
timeout: HTTP request timeout in seconds
570644
prediction_lookup: Optional PredictionTrieLookup for dynamic hint injection
645+
use_raw_values: When True send raw integers; when False convert to LOW/MEDIUM/HIGH
646+
disable_headers: If True, do not inject hints as HTTP headers (still injects nvext.agent_hints)
571647
572648
Returns:
573649
An httpx.AsyncClient configured with Dynamo hint injection.
@@ -586,6 +662,8 @@ def create_httpx_client_with_dynamo_hooks(
586662
osl=osl,
587663
iat=iat,
588664
prediction_lookup=prediction_lookup,
665+
use_raw_values=use_raw_values,
666+
disable_headers=disable_headers,
589667
)
590668

591669
return httpx.AsyncClient(

0 commit comments

Comments
 (0)