Commit 0c27f02

TRTLLM MoE maps to lower tuning buckets when ep>1
Signed-off-by: Anthony Chang <[email protected]>
1 parent: 6732c76

File tree

3 files changed: +229 -86 lines changed


tensorrt_llm/_torch/autotuner.py

Lines changed: 38 additions & 11 deletions
@@ -49,12 +49,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None
 
 
 @dataclass(slots=True, unsafe_hash=True)
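
As a sketch of how the new field can be used: a DynamicTensorSpec whose runtime lookups land on a lower bucket than the one used while tuning, in line with the commit's intent for TRTLLM MoE when ep>1. Only the DynamicTensorSpec field names come from this diff; the bucket values, EP_SIZE, the helper mappers, and the import path are illustrative assumptions.

from tensorrt_llm._torch.autotuner import DynamicTensorSpec

EP_SIZE = 4  # hypothetical expert-parallel size; not part of this commit
NUM_TOKEN_BUCKETS = (1, 8, 64, 512, 4096)


def round_up_to_bucket(x: int) -> int:
    # Tuning-time mapping: cover x with the smallest bucket that is >= x.
    return next((b for b in NUM_TOKEN_BUCKETS if b >= x), NUM_TOKEN_BUCKETS[-1])


def round_down_per_rank(x: int) -> int:
    # Runtime mapping (assumption): with ep > 1 each rank sees roughly
    # x / EP_SIZE tokens, so look up a lower bucket than the tuning one.
    per_rank = max(x // EP_SIZE, 1)
    return max((b for b in NUM_TOKEN_BUCKETS if b <= per_rank),
               default=NUM_TOKEN_BUCKETS[0])


spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=NUM_TOKEN_BUCKETS,
    map_to_tuning_buckets=round_up_to_bucket,
    map_to_runtime_buckets=round_down_per_rank,
)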
@@ -392,22 +395,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.
 
         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
+                If False, use map_to_runtime_buckets for runtime cache lookups.
 
         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -420,6 +428,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -430,6 +439,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                use_tuning_mapping,
             ),
         )
 
@@ -841,7 +851,12 @@ def choose_one(
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+                custom_op,
+                runners,
+                p.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True,
+            )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -928,8 +943,11 @@ def _profile_runners(
                     # Record the failed profiling combinations
                     self.stats.failed_profiling_count[custom_op].add(
                         self.profiling_cache.get_cache_key(
-                            custom_op, runner, profile.get_opt_shapes(),
-                            tuning_config))
+                            custom_op,
+                            runner,
+                            profile.get_opt_shapes(),
+                            tuning_config,
+                            use_tuning_mapping=True))
 
                     # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                     # or some runtime error occurs during profiling.
@@ -942,8 +960,11 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
-                tuning_config)
+                custom_op,
+                runners[best_runner_id],
+                profile.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True)
 
             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1164,13 +1185,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.
 
         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
+                If False, use map_to_runtime_buckets for runtime cache lookups.
 
         Return:
             Tuple: A tuple containing:
@@ -1180,9 +1204,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
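
The mapper-selection branch above can also be read in isolation. The following is a standalone sketch, not code from the repository; BUCKETS, map_up, map_down, and pick_bucket_mapper are assumed names that only mirror the new logic in _find_nearest_profile.

from typing import Callable, Optional

BUCKETS = (1, 8, 64, 512)


def map_up(x: int) -> int:
    # Example tuning mapper: smallest bucket covering x.
    return next((b for b in BUCKETS if b >= x), BUCKETS[-1])


def map_down(x: int) -> int:
    # Example runtime mapper: largest bucket not exceeding x.
    return max((b for b in BUCKETS if b <= x), default=BUCKETS[0])


def pick_bucket_mapper(map_to_tuning_buckets: Callable[[int], int],
                       map_to_runtime_buckets: Optional[Callable[[int], int]],
                       use_tuning_mapping: bool) -> Callable[[int], int]:
    # Mirrors the new branch: runtime lookups prefer map_to_runtime_buckets
    # when it is set; tuning (and specs without a runtime mapper) keep the
    # original map_to_tuning_buckets behaviour.
    if not use_tuning_mapping and map_to_runtime_buckets is not None:
        return map_to_runtime_buckets
    return map_to_tuning_buckets


# With 100 tokens, a tuning profile rounds up to the 512 bucket, while a
# runtime lookup that maps down lands on the 64 bucket.
print(pick_bucket_mapper(map_up, map_down, use_tuning_mapping=True)(100))   # 512
print(pick_bucket_mapper(map_up, map_down, use_tuning_mapping=False)(100))  # 64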
