Commit 9e7182b

[TRTLLM-9615][feat] Implement a distributed tuning system (NVIDIA#9621)
Four distinct strategies are implemented to accommodate different distributed tuning scenarios: BROADCAST, INDEPENDENT, MERGE, and PARALLEL.

* Distributed tuning is disabled by default, with the INDEPENDENT strategy as the fallback. This conservative approach prevents unexpected behavior in standard use cases.
* Only operations with significant tuning time overhead have been assigned the PARALLEL strategy, which partitions the tactic set across tensor parallelism (TP) ranks so that different ranks tune different tactics concurrently. This targeted approach balances performance gains with stability.
* Operations with nested tuning structures, such as NVFP4GemmUnifiedRunner, currently support only the INDEPENDENT strategy. This restriction exists because the synchronization mechanism is optimized only for leaf operations and does not yet handle nested hierarchies.

Signed-off-by: Yukun He <[email protected]>
1 parent ef4ea95 commit 9e7182b
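
As a quick orientation, the snippet below sketches how an op would opt into one of these strategies through the TuningConfig field added in this commit. It is a minimal illustration, assuming only the DistributedTuningStrategy, TuningConfig, and autotune definitions from tensorrt_llm/_torch/autotuner.py shown in the diff below; the op configs themselves are hypothetical.

from tensorrt_llm._torch.autotuner import (DistributedTuningStrategy,
                                            TuningConfig, autotune)

# Hypothetical op with an expensive tactic sweep: opt into PARALLEL so each
# TP rank profiles only a round-robin slice of the tactic list and the
# per-rank best results are merged into every rank's cache afterwards.
expensive_op_config = TuningConfig(
    distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL)

# Ops with nested tuning structures (e.g. NVFP4GemmUnifiedRunner) keep the
# default, which tunes each rank independently.
nested_op_config = TuningConfig()  # INDEPENDENT by default

# Tuning itself is still driven by the autotune() context manager.
with autotune(cache_path=None):
    pass  # run a warmup forward pass here so choose_one() profiles each op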

File tree

7 files changed: +364 -79 lines changed

examples/layer_wise_benchmarks/run.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ def comma_separated_floats(s):
 if autotune_flag:
     if args.enable_autotuner:
         cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
-        with autotune(cache_path=cache_path, rank=rank):
+        with autotune(cache_path=cache_path):
             run_pack()
 if args.run_type == "GEN":
     logger.info("Layer-wise benchmarks: Prefill KV cache")

tensorrt_llm/_torch/autotuner.py

Lines changed: 201 additions & 49 deletions
@@ -1,6 +1,7 @@
 import ast
 import contextlib
 import copy
+import enum
 import inspect
 import itertools
 import json
@@ -16,8 +17,25 @@
 from cuda.bindings import driver
 
 import tensorrt_llm
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+
+class DistributedTuningStrategy(enum.Enum):
+    """
+    Strategy for distributed tuning.
+    Args:
+        BROADCAST: One rank (rank 0) tunes and broadcasts results to others
+        INDEPENDENT: Each rank tunes independently (default for non-comm ops)
+        MERGE: All ranks participate in tuning and merge the results
+        PARALLEL: All ranks participate in tuning, each with a partial set of tactics
+    """
+    BROADCAST = "broadcast"
+    INDEPENDENT = "independent"
+    MERGE = "merge"
+    PARALLEL = "parallel"
 
 
 @dataclass(slots=True, unsafe_hash=True)
@@ -99,13 +117,15 @@ class TuningConfig:
             This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache.
             Notice that not all tuning processes can benefit from this feature.
         use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
+        distributed_tuning_strategy (DistributedTuningStrategy): Strategy for distributed tuning.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
     use_cold_l2_cache: bool = False
     use_cuda_graph: bool = True
+    distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.INDEPENDENT
 
 
 @dataclass(unsafe_hash=True)
@@ -229,7 +249,16 @@ def unique_id(self):
 
 
 @contextlib.contextmanager
-def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
+def autotune(tune_mode: bool = True, cache_path: str = None):
+    """Context manager for autotuning with distributed support.
+
+    Args:
+        tune_mode: Whether to enable tuning mode
+        cache_path: Path to save/load cache files
+    """
+    autotuner = AutoTuner.get()
+    rank = autotuner.mapping.rank
+
     # if cache_path is provided, use the rank-specific file
     tune_required = tune_mode
     if cache_path is not None:
@@ -242,25 +271,27 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
         if file_exists:
             logger.info(
                 f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
-            AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank)
+            autotuner.profiling_cache.load_cache(cache_path_no_ext_rank)
 
     # record the old tuning mode
-    old_mode = AutoTuner.get().is_tuning_mode
-    AutoTuner.get().is_tuning_mode = tune_required
+    old_mode = autotuner.is_tuning_mode
+    autotuner.is_tuning_mode = tune_required
     autotune_enabled = tune_required and not old_mode
+
     if autotune_enabled:
         logger.info("[Autotuner] Autotuning process starts ...")
+
     try:
         yield
     finally:
-        AutoTuner.get().is_tuning_mode = old_mode
+        autotuner.is_tuning_mode = old_mode
         if autotune_enabled:
             logger.info("[Autotuner] Autotuning process ends")
 
         # save cache
         if cache_path is not None:
             logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
-            AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank)
+            autotuner.profiling_cache.save_cache(cache_path_no_ext_rank)
 
 
 @dataclass
@@ -399,6 +430,9 @@ def get_cache_key(
             ),
         )
 
+    def merge_cache_data(self, cache_data: Dict[str, Any]):
+        self.cache.update(cache_data)
+
     def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
         return {k: v for k, v in self.cache.items() if k[0] == custom_op}
 
@@ -561,6 +595,11 @@ class AutoTuner:
     _instance = None
 
     def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
+        # Increase log level for AutoTuner associated logger
+        self._log_level_to_info = os.getenv(
+            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
+        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+
         self.repeat = repeat
         self.warmup = warmup
         self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -575,17 +614,19 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
         # Last captured choose_one() contexts
         self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
 
-        # Increase log level for AutoTuner associated logger
-        self._log_level_to_info = os.getenv(
-            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
-        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+        # Distributed tuning state
+        self._dist: Optional[Distributed] = None
+        self.mapping: Mapping = Mapping()
 
     @classmethod
     def get(cls):
         if cls._instance is None:
             cls._instance = AutoTuner()
         return cls._instance
 
+    def set_mapping(self, mapping: Mapping = None):
+        self.mapping = mapping
+
     class TacticsCapture:
         """Object returned by capture() that can be iterated to get all tactic combinations.
768809
self.stats.tuned_op_profiled_configs[custom_op] = 0
769810
if custom_op not in self.stats.failed_profiling_count:
770811
self.stats.failed_profiling_count[custom_op] = set()
771-
new_tuning_failure_occured = False
772-
773-
for p in profiles:
774-
tensors = self._prepare_input_tensors(p, inputs)
775-
is_cache_hit, *_ = self.profiling_cache.search_cache(
776-
custom_op, runners, p.get_opt_shapes(), tuning_config)
777-
if not is_cache_hit:
778-
# Initialize runner and tactic as None in case of no valid tactic or runners are found
779-
best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
780-
custom_op, runners, tensors, p, tuning_config, **kwargs)
781-
if best_runner_id is not None:
782-
# At least one valid (runner, tactic) pair is found
783-
cache_key = self.profiling_cache.get_cache_key(
784-
custom_op, runners[best_runner_id], p.get_opt_shapes(),
785-
tuning_config)
786-
787-
self._debug_logger(
788-
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
789-
)
790-
# inspect call stack
791-
self.profiling_cache[cache_key] = (best_runner_id,
792-
best_tactic, min_time)
793-
794-
self.stats.tuned_op_profiled_configs[custom_op] += 1
795-
else:
796-
logger.warning_once(
797-
f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
798-
f"At least one valid (runner, tactic) pair is required. "
799-
f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
800-
f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
801-
key=(custom_op, "warning_autotuning_no_valid_tactic"),
802-
)
803-
new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured
812+
new_tuning_failure_occurred = False
813+
814+
# Synchronize ranks before profiling
815+
if self._should_current_rank_tune(
816+
tuning_config.distributed_tuning_strategy):
817+
for p in profiles:
818+
tensors = self._prepare_input_tensors(p, inputs)
819+
is_cache_hit, *_ = self.profiling_cache.search_cache(
820+
custom_op, runners, p.get_opt_shapes(), tuning_config)
821+
if not is_cache_hit:
822+
# Initialize runner and tactic as None in case of no valid tactic or runners are found
823+
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
824+
custom_op, runners, tensors, p, tuning_config, **kwargs)
825+
new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred
826+
827+
self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
828+
custom_op)
804829

805830
# If failed profiling tactics occurs, log the error.
806-
if new_tuning_failure_occured:
831+
if new_tuning_failure_occurred:
807832
logger.warning_once(
808833
f"[Autotuner] New tuning error occurs:"
809834
f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. "
@@ -834,7 +859,7 @@ def _profile_runners(
         **kwargs,
     ) -> float:
         min_time = float('inf')
-        has_tuning_failure_occured = False
+        has_tuning_failure_occurred = False
         best_runner_id, best_tactic = None, None
         # If the inputs_pre_hook is provided, it will be called before profiling.
         if tuning_config.inputs_pre_hook is not None:
@@ -845,8 +870,11 @@
                 p.name
                 for p in inspect.signature(runner.forward).parameters.values()
             }
-            valid_tactics = runner.get_valid_tactics(input_tensors, profile,
-                                                     **kwargs)
+            all_valid_tactics = runner.get_valid_tactics(
+                input_tensors, profile, **kwargs)
+
+            valid_tactics = self._maybe_parallelize_tactics(
+                all_valid_tactics, tuning_config.distributed_tuning_strategy)
             if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
                 runner(
                     input_tensors,
@@ -882,12 +910,36 @@
                     # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                     # or some runtime error occurs during profiling.
                     time_measured = float('inf')
-                    has_tuning_failure_occured = True
+                    has_tuning_failure_occurred = True
                 if time_measured < min_time:
                     min_time = time_measured
                     best_runner_id, best_tactic = runner_id, tac
 
-        return best_runner_id, best_tactic, min_time, has_tuning_failure_occured
+        if best_runner_id is not None:
+            # At least one valid (runner, tactic) pair is found
+            cache_key = self.profiling_cache.get_cache_key(
+                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
+                tuning_config)
+
+            self._debug_logger(
+                f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
+            )
+            # inspect call stack
+            # TODO: use named tuple to make it more readable
+            self.profiling_cache[cache_key] = (best_runner_id, best_tactic,
+                                               min_time)
+
+            self.stats.tuned_op_profiled_configs[custom_op] += 1
+        else:
+            logger.warning_once(
+                f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={profile.get_opt_shapes()}. "
+                f"At least one valid (runner, tactic) pair is required. "
+                f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
+                f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
+                key=(custom_op, "warning_autotuning_no_valid_tactic"),
+            )
+
+        return best_runner_id, best_tactic, min_time, has_tuning_failure_occurred
 
     def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:
 
@@ -1358,3 +1410,103 @@ def _cudaGetErrorEnum(self, error) -> str:
             return nvrtc.nvrtcGetErrorString(error)[1]
         else:
             raise RuntimeError("Unknown error type: {}".format(error))
+
+    def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
+        """Setup distributed communication state for autotuning."""
+        self.mapping = mapping
+        self._dist = dist
+        self._debug_logger(
+            f"[AutoTuner] Whether using distributed tuning: {self._is_distributed()}"
+        )
+
+    def _is_distributed(self) -> bool:
+        """Check if we're in a distributed environment."""
+        return self.mapping is not None and self.mapping.tp_size > 1 and self._dist is not None
+
+    def _maybe_parallelize_tactics(
+            self, all_valid_tactics: List[Any],
+            strategy: DistributedTuningStrategy) -> List[Any]:
+        """Parallelize tactics across all TP ranks if strategy is PARALLEL."""
+        if strategy == DistributedTuningStrategy.PARALLEL:
+            # only distribute across TP ranks
+            # each TP rank will only tune the tactics that are assigned to it
+            tp_size = self.mapping.tp_size
+            tp_rank = self.mapping.tp_rank
+            valid_tactics = []
+            for idx, tactic in enumerate(all_valid_tactics):
+                if idx % tp_size == tp_rank:
+                    valid_tactics.append(tactic)
+            return valid_tactics
+        else:
+            return all_valid_tactics
+
+    def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy,
+                               custom_op: str):
+        """Synchronize cache data across all ranks."""
+        if not self._is_distributed():
+            logger.warning(
+                f"[AutoTuner] Not in distributed environment, skipping synchronization"
+            )
+            return
+
+        if strategy == DistributedTuningStrategy.BROADCAST:
+            self._broadcast_cache_data(custom_op)
+        elif strategy == DistributedTuningStrategy.INDEPENDENT:
+            return
+        elif strategy == DistributedTuningStrategy.MERGE:
+            self._merge_cache_data(custom_op)
+        elif strategy == DistributedTuningStrategy.PARALLEL:
+            self._merge_cache_data(custom_op)
+        else:
+            logger.error(
+                f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
+            )
+            return
+
+    def _merge_cache_data(self, custom_op: str):
+        cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
+        merged_cache_data = dict()
+        all_cache_data = self._dist.tp_allgather(obj=cache_data)
+
+        for data in all_cache_data:
+            for key, value in data.items():
+                current_time = merged_cache_data.get(key, [
+                    float('inf'),
+                ])[-1]
+                if value[-1] < current_time:
+                    merged_cache_data[key] = value
+
+        self.profiling_cache.merge_cache_data(merged_cache_data)
+
+    def _broadcast_cache_data(
+        self,
+        custom_op: str,
+    ) -> None:
+        """Broadcast tactics from root rank to all other ranks."""
+        cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
+        root = 0
+        cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
+
+        self.profiling_cache.merge_cache_data(cache_data)
+
+    def _should_current_rank_tune(self,
+                                  strategy: DistributedTuningStrategy) -> bool:
+        """Determine if this rank should perform tuning based on strategy."""
+        if not self._is_distributed():
+            return True
+
+        if strategy == DistributedTuningStrategy.BROADCAST:
+            # Only rank 0 tunes
+            return self.mapping.rank == 0
+        elif strategy in {
+                DistributedTuningStrategy.INDEPENDENT,
+                DistributedTuningStrategy.MERGE,
+                DistributedTuningStrategy.PARALLEL,
+        }:
+            # All ranks tune independently
+            return True
+        else:
+            logger.error(
+                f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
+            )
+            return True
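
For intuition about the PARALLEL and MERGE paths above, the standalone sketch below mirrors the round-robin tactic split of _maybe_parallelize_tactics and the min-time merge of _merge_cache_data. The helper names and the toy data are hypothetical; only the logic follows the diff.

from typing import Any, Dict, List, Tuple


def partition_tactics(all_tactics: List[Any], tp_rank: int,
                      tp_size: int) -> List[Any]:
    # Round-robin split: rank r keeps tactics r, r + tp_size, r + 2 * tp_size, ...
    return [t for i, t in enumerate(all_tactics) if i % tp_size == tp_rank]


def merge_rank_caches(
        per_rank_caches: List[Dict[Tuple, Tuple]]) -> Dict[Tuple, Tuple]:
    # For each cache key keep the entry with the smallest measured time,
    # i.e. the last element of the cached (runner_id, tactic, min_time) tuple.
    merged: Dict[Tuple, Tuple] = {}
    for cache in per_rank_caches:
        for key, value in cache.items():
            if key not in merged or value[-1] < merged[key][-1]:
                merged[key] = value
    return merged


# Toy example: 7 tactics split across a TP group of 4 ranks.
tactics = list(range(7))
print([partition_tactics(tactics, r, 4) for r in range(4)])
# -> [[0, 4], [1, 5], [2, 6], [3]]

# Two ranks report different best (runner_id, tactic, time) entries for the
# same cache key; the merge keeps the faster one on every rank.
rank0 = {("op", "shape"): (0, "tactic_4", 1.7)}
rank1 = {("op", "shape"): (1, "tactic_5", 1.2)}
print(merge_rank_caches([rank0, rank1]))  # keeps (1, "tactic_5", 1.2)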

0 commit comments