Commit 0b07662

TRTLLM MoE maps to lower tuning buckets when ep>1
Signed-off-by: Anthony Chang <[email protected]>
1 parent af2849c commit 0b07662

3 files changed: +231 −86 lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 37 additions & 11 deletions
@@ -28,12 +28,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
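
The new map_to_runtime_buckets field lets a spec apply one bucketing rule while generating tuning profiles and a different rule when looking up the cache at inference time, which is what the commit title refers to: with expert parallelism (ep > 1) each rank sees only a fraction of the tokens, so the runtime mapping can point at a lower bucket. Below is a minimal sketch of how such a spec might be constructed; the import path, the helper names, the bucket range, and the divide-by-ep rule are illustrative assumptions, not the actual TRTLLM MoE wiring.

# Hypothetical example (not the real MoE integration): tune on power-of-two token
# buckets, but at runtime divide by the expert-parallel size before mapping, so
# ep > 1 lands on a lower bucket.
from tensorrt_llm._torch.autotuner import DynamicTensorSpec  # assumed import path


def make_num_tokens_spec(ep_size: int) -> DynamicTensorSpec:
    tuning_buckets = tuple(2**i for i in range(12))  # 1 .. 2048, assumed range

    def round_up_to_bucket(x: int) -> int:
        # Tuning-time mapping: snap the observed size up to the nearest bucket.
        for b in tuning_buckets:
            if b >= x:
                return b
        return tuning_buckets[-1]

    def runtime_bucket(x: int) -> int:
        # Runtime mapping: each rank only sees roughly 1/ep_size of the tokens.
        return round_up_to_bucket(max(1, x // ep_size))

    return DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        gen_tuning_buckets=tuning_buckets,
        map_to_tuning_buckets=round_up_to_bucket,
        map_to_runtime_buckets=runtime_bucket if ep_size > 1 else None,
    )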
@@ -358,22 +361,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
+                If False, use map_to_runtime_buckets for runtime cache lookups.

         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -386,6 +394,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -396,6 +405,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                use_tuning_mapping,
             ),
         )

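With the flag appended to the key tuple, a key built under the tuning mapping and one built under the runtime mapping stay distinct even when every other field matches; the tuning-time call sites further down pass use_tuning_mapping=True, while callers that keep the default produce runtime-mapped keys. A toy illustration of that key-shape change, with placeholder values standing in for the real runner and config objects:

# Illustrative only: mimic the shape of the tuple returned by get_cache_key.
def toy_cache_key(custom_op: str, opt_shapes: tuple, use_tuning_mapping: bool) -> tuple:
    return (
        custom_op,
        opt_shapes,           # already mapped through the chosen bucket function
        use_tuning_mapping,   # new final element of the key
    )


tuning_key = toy_cache_key("trtllm_moe", ((256, 4096),), use_tuning_mapping=True)
runtime_key = toy_cache_key("trtllm_moe", ((256, 4096),), use_tuning_mapping=False)
assert tuning_key != runtime_key  # same shapes, different keys
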
@@ -773,16 +783,23 @@ def choose_one(
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+                custom_op,
+                runners,
+                p.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True)
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
                 if best_runner_id is not None:
                     # At least one valid (runner, tactic) pair is found
                     cache_key = self.profiling_cache.get_cache_key(
-                        custom_op, runners[best_runner_id], p.get_opt_shapes(),
-                        tuning_config)
+                        custom_op,
+                        runners[best_runner_id],
+                        p.get_opt_shapes(),
+                        tuning_config,
+                        use_tuning_mapping=True)

                     self._debug_logger(
                         f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -876,8 +893,11 @@ def _profile_runners(
                 # Record the failed profiling combinations
                 self.stats.failed_profiling_count[custom_op].add(
                     self.profiling_cache.get_cache_key(
-                        custom_op, runner, profile.get_opt_shapes(),
-                        tuning_config))
+                        custom_op,
+                        runner,
+                        profile.get_opt_shapes(),
+                        tuning_config,
+                        use_tuning_mapping=True))

                 # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                 # or some runtime error occurs during profiling.
@@ -1088,13 +1108,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
+                If False, use map_to_runtime_buckets for runtime cache lookups.

         Return:
             Tuple: A tuple containing:
@@ -1104,9 +1127,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
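
The selection logic above falls back to map_to_tuning_buckets whenever map_to_runtime_buckets is unset, so existing specs behave exactly as before. A small worked sketch of the mapper selection, using assumed example mappers (round up to the next power of two for tuning; divide by an ep of 4, then round, for runtime):

from typing import Callable, Optional


# Assumed example mappers; the real ones are supplied per custom op via its spec.
def tuning_mapper(x: int) -> int:
    # Round up to the next power of two (tuning-time bucketing).
    return 1 << (x - 1).bit_length() if x > 1 else 1


def runtime_mapper_ep4(x: int) -> int:
    # Map to a lower bucket for an assumed expert-parallel size of 4.
    return tuning_mapper(max(1, x // 4))


def pick_mapper(use_tuning_mapping: bool,
                map_to_tuning_buckets: Callable[[int], int],
                map_to_runtime_buckets: Optional[Callable[[int], int]]) -> Callable[[int], int]:
    # Mirrors the branch added in _find_nearest_profile.
    if not use_tuning_mapping and map_to_runtime_buckets is not None:
        return map_to_runtime_buckets
    return map_to_tuning_buckets


num_tokens = 300
print(pick_mapper(True, tuning_mapper, runtime_mapper_ep4)(num_tokens))   # 512: tuning-time key
print(pick_mapper(False, tuning_mapper, runtime_mapper_ep4)(num_tokens))  # 128: runtime lookup (300 // 4 = 75 -> 128)
print(pick_mapper(False, tuning_mapper, None)(num_tokens))                # 512: falls back to tuning mapping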
