Commit 639c939

[TRTC-1943][feat] Env vars override support in LLM API (#9104)
Signed-off-by: Venky Ganesh <[email protected]>
1 parent f61067c commit 639c939

File tree: 14 files changed, +187 -36 lines changed

tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
Lines changed: 13 additions & 6 deletions

@@ -1,6 +1,6 @@
 import torch

-from ..flashinfer_utils import ENABLE_PDL, IS_FLASHINFER_AVAILABLE
+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE, get_env_enable_pdl

 if IS_FLASHINFER_AVAILABLE:
     from flashinfer.activation import silu_and_mul
@@ -11,7 +11,7 @@
 # Warp this into custom op since flashinfer didn't warp it properly and we want to avoid graph break between mlp layer for user buffer optimization
 @torch.library.custom_op("trtllm::flashinfer_silu_and_mul", mutates_args=())
 def flashinfer_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    return silu_and_mul(x, enable_pdl=ENABLE_PDL)
+    return silu_and_mul(x, enable_pdl=get_env_enable_pdl())

 @flashinfer_silu_and_mul.register_fake
 def _(x: torch.Tensor) -> torch.Tensor:
@@ -21,7 +21,7 @@ def _(x: torch.Tensor) -> torch.Tensor:
 @torch.library.custom_op("trtllm::flashinfer_rmsnorm", mutates_args=())
 def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
                        eps: float) -> torch.Tensor:
-    return rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)
+    return rmsnorm(input, weight, eps, enable_pdl=get_env_enable_pdl())

 @flashinfer_rmsnorm.register_fake
 def _(input: torch.Tensor, weight: torch.Tensor,
@@ -32,7 +32,10 @@ def _(input: torch.Tensor, weight: torch.Tensor,
                          mutates_args=())
 def flashinfer_gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
                              eps: float) -> torch.Tensor:
-    return gemma_rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)
+    return gemma_rmsnorm(input,
+                         weight,
+                         eps,
+                         enable_pdl=get_env_enable_pdl())

 @flashinfer_gemma_rmsnorm.register_fake
 def _(input: torch.Tensor, weight: torch.Tensor,
@@ -44,7 +47,11 @@ def _(input: torch.Tensor, weight: torch.Tensor,
 def flashinfer_fused_add_rmsnorm(input: torch.Tensor,
                                  residual: torch.Tensor,
                                  weight: torch.Tensor, eps: float) -> None:
-    fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL)
+    fused_add_rmsnorm(input,
+                      residual,
+                      weight,
+                      eps,
+                      enable_pdl=get_env_enable_pdl())

 @torch.library.custom_op("trtllm::flashinfer_gemma_fused_add_rmsnorm",
                          mutates_args=("input", "residual"))
@@ -56,7 +63,7 @@ def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
                                        residual,
                                        weight,
                                        eps,
-                                       enable_pdl=ENABLE_PDL)
+                                       enable_pdl=get_env_enable_pdl())

 @torch.library.custom_op(
     "trtllm::flashinfer_apply_rope_with_cos_sin_cache_inplace",

tensorrt_llm/_torch/flashinfer_utils.py
Lines changed: 5 additions & 5 deletions

@@ -8,13 +8,13 @@


 def get_env_enable_pdl():
-    return os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
+    enabled = os.environ.get("TRTLLM_ENABLE_PDL", "0") == "1"
+    if enabled and not getattr(get_env_enable_pdl, "_printed", False):
+        logger.info("PDL enabled")
+        setattr(get_env_enable_pdl, "_printed", True)
+    return enabled


-ENABLE_PDL = get_env_enable_pdl()
-if ENABLE_PDL:
-    logger.info("PDL is enabled")
-
 if platform.system() != "Windows":
     try:
         import flashinfer
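
The behavioral shift in this file: the old module-level ENABLE_PDL froze the flag at import time, while get_env_enable_pdl() re-reads os.environ on every call, so overrides applied after import (for example via env_overrides below) still take effect, and the one-shot _printed attribute keeps the log line from repeating on hot paths. A standalone sketch of the same read-on-demand, log-once pattern (toy names, not the repo's):

import logging
import os

logger = logging.getLogger(__name__)


def get_env_flag() -> bool:
    # Re-read the variable on every call instead of caching it at import
    # time, so late changes to os.environ are still observed.
    enabled = os.environ.get("MY_FEATURE_FLAG", "0") == "1"
    if enabled and not getattr(get_env_flag, "_printed", False):
        logger.info("feature flag enabled")  # logged at most once
        get_env_flag._printed = True
    return enabled


os.environ["MY_FEATURE_FLAG"] = "1"  # set after import; still picked up
assert get_env_flag() is True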

tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
Lines changed: 2 additions & 2 deletions

@@ -30,7 +30,7 @@
 else:
     from typing_extensions import override

-from ..flashinfer_utils import ENABLE_PDL
+from ..flashinfer_utils import get_env_enable_pdl
 from .sampling_utils import (
     GREEDY,
     GroupedStrategySampler,
@@ -113,7 +113,7 @@ def _prepare_probs_with_temperature(
     probs = flashinfer.sampling.softmax(
         logits,
         temperature,
-        enable_pdl=ENABLE_PDL,
+        enable_pdl=get_env_enable_pdl(),
     )
     return probs
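
The call above computes a temperature-scaled softmax with a flashinfer kernel. As a point of reference, the math is equivalent to scaling logits by 1/temperature before a regular softmax; a plain-torch sketch (illustrative only, without the fused kernel or PDL):

import torch


def probs_with_temperature(logits: torch.Tensor,
                           temperature: torch.Tensor) -> torch.Tensor:
    # Divide logits by temperature, then normalize over the vocab dimension.
    return torch.softmax(logits / temperature, dim=-1)


probs = probs_with_temperature(torch.randn(2, 8), torch.tensor([[1.0], [0.7]]))
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))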

tensorrt_llm/bench/benchmark/low_latency.py
Lines changed: 16 additions & 11 deletions

@@ -1,7 +1,6 @@
 from __future__ import annotations

 import asyncio
-import os
 from functools import partial
 from pathlib import Path

@@ -46,12 +45,14 @@
     help="Path to a serialized TRT-LLM engine.",
 )
 @optgroup.option(
+    "--config",
     "--extra_llm_api_options",
+    "extra_llm_api_options",
     type=str,
     default=None,
     help=
-    "Path to a YAML file that overwrites the parameters specified by trtllm-bench."
-)
+    "Path to a YAML file that overwrites the parameters specified by trtllm-bench. "
+    "Can be specified as either --config or --extra_llm_api_options.")
 @optgroup.option(
     "--backend",
     type=click.Choice(ALL_SUPPORTED_BACKENDS),
@@ -192,6 +193,7 @@ def latency_command(
 ) -> None:
     """Run a latency test on a TRT-LLM engine."""
     logger.info("Preparing to run latency benchmark...")
+
     # Parameters from CLI
     # Model, experiment, and engine params
     options = get_general_cli_options(params, bench_env)
@@ -263,14 +265,6 @@ def latency_command(
         exec_settings["settings_config"][
             "scheduler_policy"] = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT

-    # Set environment variables for setting runtime options.
-    # TODO: Once passing of variables is fixed, these should work
-    # when using MPI in C++ runtime.
-    os.environ["TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"] = "1"
-    os.environ["TRTLLM_MMHA_KERNEL_BLOCK_SIZE"] = "256"
-    os.environ["FORCE_MULTI_BLOCK_MODE"] = "1"
-    os.environ["TRTLLM_ENABLE_PDL"] = "1"
-
     # Performance options
     exec_settings["performance_options"]["cuda_graphs"] = True
     exec_settings["performance_options"]["multi_block_mode"] = True
@@ -290,6 +284,17 @@ def latency_command(
     kwargs = kwargs | runtime_config.get_llm_args()
     kwargs['backend'] = options.backend

+    # Set environment variables for setting runtime options.
+    default_env_overrides = {
+        "TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG": "1",
+        "TRTLLM_MMHA_KERNEL_BLOCK_SIZE": "256",
+        "FORCE_MULTI_BLOCK_MODE": "1",
+        "TRTLLM_ENABLE_PDL": "1",
+    }
+    # Update defaults with existing overrides (user preference takes priority)
+    default_env_overrides.update(kwargs.get("env_overrides", {}))
+    kwargs["env_overrides"] = default_env_overrides
+
     try:
         logger.info("Setting up latency benchmark.")
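
The merge order here matters: the benchmark builds its defaults first and then update()s them with any env_overrides already present in kwargs, so user-supplied values win. A small sketch of the semantics (values illustrative):

default_env_overrides = {
    "TRTLLM_ENABLE_PDL": "1",
    "TRTLLM_MMHA_KERNEL_BLOCK_SIZE": "256",
}
user_overrides = {"TRTLLM_ENABLE_PDL": "0"}  # e.g. from the user's YAML config

default_env_overrides.update(user_overrides)
assert default_env_overrides["TRTLLM_ENABLE_PDL"] == "0"  # user wins
assert default_env_overrides["TRTLLM_MMHA_KERNEL_BLOCK_SIZE"] == "256"  # default kept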

tensorrt_llm/bench/benchmark/throughput.py
Lines changed: 5 additions & 2 deletions

@@ -61,12 +61,14 @@
     help="Paths to custom module directories to import.",
 )
 @optgroup.option(
+    "--config",
     "--extra_llm_api_options",
+    "extra_llm_api_options",
     type=str,
     default=None,
     help=
-    "Path to a YAML file that overwrites the parameters specified by trtllm-bench."
-)
+    "Path to a YAML file that overwrites the parameters specified by trtllm-bench. "
+    "Can be specified as either --config or --extra_llm_api_options.")
 @optgroup.option("--sampler_options",
                  type=click.Path(exists=True,
                                  readable=True,
@@ -293,6 +295,7 @@ def throughput_command(
 ) -> None:
     """Run a throughput test on a TRT-LLM engine."""
     logger.info("Preparing to run throughput benchmark...")
+
     # Parameters from CLI
     image_data_format: str = params.get("image_data_format", "pt")
     data_device: str = params.get("data_device", "cpu")

tensorrt_llm/commands/eval.py
Lines changed: 5 additions & 2 deletions

@@ -97,10 +97,13 @@
               default=None,
               help="The revision to use for the HuggingFace model "
               "(branch name, tag name, or commit id).")
-@click.option("--extra_llm_api_options",
+@click.option("--config",
+              "--extra_llm_api_options",
+              "extra_llm_api_options",
               type=str,
               default=None,
-              help="Path to a YAML file that overwrites the parameters")
+              help="Path to a YAML file that overwrites the parameters. "
+              "Can be specified as either --config or --extra_llm_api_options.")
 @click.option("--disable_kv_cache_reuse",
               is_flag=True,
               default=False,

tensorrt_llm/commands/serve.py
Lines changed: 4 additions & 2 deletions

@@ -342,12 +342,14 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
     help="The revision to use for the HuggingFace model "
     "(branch name, tag name, or commit id).")
 @click.option(
+    "--config",
     "--extra_llm_api_options",
+    "extra_llm_api_options",
     type=str,
     default=None,
     help=
-    "Path to a YAML file that overwrites the parameters specified by trtllm-serve."
-)
+    "Path to a YAML file that overwrites the parameters specified by trtllm-serve. "
+    "Can be specified as either --config or --extra_llm_api_options.")
 @click.option(
     "--reasoning_parser",
     type=click.Choice(ReasoningParserFactory.parsers.keys()),
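
The --config alias added across these CLIs relies on click's multiple parameter declarations: listing two flags plus a bare name makes both flags feed the same Python argument. A minimal standalone sketch (toy command, not the repo's):

import click


@click.command()
@click.option(
    "--config",
    "--extra_llm_api_options",  # second flag is an alias for the first
    "extra_llm_api_options",  # bare name becomes the Python parameter
    type=str,
    default=None,
    help="Path to a YAML file that overwrites the parameters.")
def main(extra_llm_api_options):
    click.echo(f"config file: {extra_llm_api_options}")


if __name__ == "__main__":
    main()  # `prog --config x.yaml` == `prog --extra_llm_api_options x.yaml`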

tensorrt_llm/executor/worker.py
Lines changed: 9 additions & 0 deletions

@@ -241,8 +241,17 @@ def worker_main(
     tokenizer: Optional[TokenizerBase] = None,
     llm_args: Optional[BaseLlmArgs] = None,
 ) -> None:
+
     mpi_comm().barrier()

+    if llm_args is not None and llm_args.env_overrides:
+        # this is needed because MPI_Init seems to cache the env at import time.
+        # The cached env snapshot is used to spawn workers.
+        # Any env overrides to the main process after tensorrt_llm import
+        # may not get reflected in the spawned worker process, no matter how early,
+        # unless we update it explicitly here.
+        os.environ.update(llm_args.env_overrides)
+
     if llm_args is not None and llm_args.trust_remote_code:
         _init_hf_modules()
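
The hunk above re-applies the overrides inside the worker because, per the comment, MPI appears to snapshot the environment at import time, so parent-side changes may never reach the spawned process. A sketch of the same defensive pattern using plain multiprocessing (illustrative analogy, not the repo's MPI path):

import multiprocessing as mp
import os


def worker_entry(env_overrides: dict) -> None:
    # Mirror worker_main: apply the overrides explicitly inside the worker
    # instead of relying on the parent's environment being inherited intact.
    os.environ.update(env_overrides)
    print("worker sees TRTLLM_ENABLE_PDL =", os.environ["TRTLLM_ENABLE_PDL"])


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker_entry, args=({"TRTLLM_ENABLE_PDL": "1"},))
    p.start()
    p.join()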

tensorrt_llm/llmapi/llm.py
Lines changed: 22 additions & 0 deletions

@@ -135,6 +135,9 @@ def __init__(self,
             logger.set_level("info")  # force display the backend

         try:
+            env_overrides = kwargs.get("env_overrides", None)
+            self._process_env_overrides(env_overrides)
+
             backend = kwargs.get('backend', None)
             if backend == "pytorch":
                 logger.info("Using LLM with PyTorch backend")
@@ -587,6 +590,25 @@ def get_kv_cache_events_async(self,
         '''
         return self._executor.aget_kv_events(timeout=timeout)

+    def _process_env_overrides(self,
+                               env_overrides: Optional[dict[str, str]]) -> None:
+        if env_overrides is None:
+            return
+        logger.info("Processing LLM API environment variable overrides")
+        # TODO: If an env var is cached at import-time in code, overriding os.environ will
+        # unfortunately not update wherever the var is used.
+        # This is a known issue and only way to fix it is at every such usage to access it
+        # from os.environ on-demand.
+        for key, value in env_overrides.items():
+            str_value = str(value)
+            if key in os.environ:
+                old_value = os.environ[key]
+                os.environ[key] = str_value
+                logger.info(f"Overriding {key}: '{old_value}' -> '{str_value}'")
+            else:
+                os.environ[key] = str_value
+                logger.info(f"Setting {key}='{str_value}'")
+
     def _prepare_sampling_params(
         self,
         sampling_params: Optional[SamplingParams] = None) -> SamplingParams:
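
Taken together with the BaseLlmArgs field below, env_overrides can be passed straight to the LLM constructor, which routes it through _process_env_overrides before the backend is set up. A hedged usage sketch (the model path is illustrative):

from tensorrt_llm import LLM

llm = LLM(
    model="/path/to/model",  # illustrative path
    env_overrides={"TRTLLM_ENABLE_PDL": "1"},  # applied to os.environ at init
)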

tensorrt_llm/llmapi/llm_args.py
Lines changed: 6 additions & 0 deletions

@@ -1931,6 +1931,12 @@ class BaseLlmArgs(StrictBaseModel):
         status="prototype",
     )

+    env_overrides: Optional[Dict[str, str]] = Field(
+        default=None,
+        description=
+        "[EXPERIMENTAL] Environment variable overrides. NOTE: import-time-cached env vars in the code won’t update unless the code fetches them from os.environ on demand.",
+        status="prototype")
+
     _parallel_config: Optional[_ParallelConfig] = PrivateAttr(default=None)
     _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None)
     _speculative_model: Optional[str] = PrivateAttr(default=None)