Commit d60d2c8

TRTLLM MoE maps to lower tuning buckets when ep>1
Signed-off-by: Anthony Chang <[email protected]>
1 parent c1cfb61 commit d60d2c8

File tree

3 files changed: +232 -86 lines changed


tensorrt_llm/_torch/autotuner.py

Lines changed: 38 additions & 11 deletions
@@ -46,12 +46,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None
 
 
 @dataclass(slots=True, unsafe_hash=True)
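
For reference, a spec that opts into the new field could be constructed as in the sketch below. Only DynamicTensorSpec itself comes from this file; the bucket values, the next_positive_power_of_2 helper, and ep_size are illustrative assumptions, since the MoE-side wiring lives in the other two files of this commit and is not shown in this diff.

# Hypothetical usage sketch; only DynamicTensorSpec is taken from this diff.
from tensorrt_llm._torch.autotuner import DynamicTensorSpec

def next_positive_power_of_2(x: int) -> int:
    # Round up to the nearest power of two (assumed tuning-bucket shape).
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

ep_size = 4  # expert-parallel world size (assumed)

num_tokens_spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=(1, 2, 4, 8, 16, 32, 64, 128),
    # Applied while tuning and when storing cache entries.
    map_to_tuning_buckets=next_positive_power_of_2,
    # Applied on runtime cache lookups: with ep > 1 each rank handles fewer
    # tokens per expert, so map to a proportionally lower bucket.
    map_to_runtime_buckets=lambda x: next_positive_power_of_2(max(1, x // ep_size)),
)

Leaving map_to_runtime_buckets as None keeps the previous behaviour: map_to_tuning_buckets is used for both tuning and runtime lookups, as the docstring above states.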
@@ -389,22 +392,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.
 
         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
+                If False, use map_to_runtime_buckets for runtime cache lookups.
 
         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -417,6 +425,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -427,6 +436,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                use_tuning_mapping,
             ),
         )
 
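
Together with the search_cache change above, the flag decides which bucket mapping feeds the cache key. The toy simulation below sketches the intended store/lookup behaviour; the mapper functions, bucket values, and op name are assumptions for illustration, not taken from this commit.

# Toy stand-in for the profiling cache: keys are (op, bucketed num_tokens).
# The tuning mapping is applied when storing; the runtime mapping when looking up.
def map_to_tuning_buckets(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()  # round up to a power of two

def map_to_runtime_buckets(x: int, ep_size: int = 4) -> int:
    return map_to_tuning_buckets(max(1, x // ep_size))  # ep > 1 -> lower bucket

def toy_cache_key(op: str, num_tokens: int, use_tuning_mapping: bool) -> tuple:
    mapper = map_to_tuning_buckets if use_tuning_mapping else map_to_runtime_buckets
    return (op, mapper(num_tokens))

cache = {}

# Tuning: the best tactic for the 32-token bucket is stored via the tuning mapping.
cache[toy_cache_key("trtllm_moe", 32, use_tuning_mapping=True)] = "best_tactic"

# Runtime: 100 live tokens with ep=4 map down to the 32-token bucket and hit the entry.
assert toy_cache_key("trtllm_moe", 100, use_tuning_mapping=False) in cache

A runtime lookup therefore hits a tuned entry only when the runtime mapping sends the live shape to a bucket that was actually tuned and stored.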
@@ -817,7 +827,12 @@ def choose_one(
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+                custom_op,
+                runners,
+                p.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True,
+            )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -904,8 +919,11 @@ def _profile_runners(
                 # Record the failed profiling combinations
                 self.stats.failed_profiling_count[custom_op].add(
                     self.profiling_cache.get_cache_key(
-                        custom_op, runner, profile.get_opt_shapes(),
-                        tuning_config))
+                        custom_op,
+                        runner,
+                        profile.get_opt_shapes(),
+                        tuning_config,
+                        use_tuning_mapping=True))
 
                 # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                 # or some runtime error occurs during profiling.
@@ -918,8 +936,11 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
-                tuning_config)
+                custom_op,
+                runners[best_runner_id],
+                profile.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True)
 
             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1140,13 +1161,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.
 
         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
+                If False, use map_to_runtime_buckets for runtime cache lookups.
 
         Return:
             Tuple: A tuple containing:
@@ -1156,9 +1180,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
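
Tracing the selection added above with a stub spec makes the effect concrete. Everything in this snippet is an assumption for illustration (the stub class, the power-of-two mapping, ep_size=8, and the shapes); only the three bucket_mapper lines mirror the logic committed here.

from dataclasses import dataclass
from typing import Callable, Optional

def pow2_ceil(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

@dataclass
class StubSpec:
    # Minimal stand-in for DynamicTensorSpec, with only the fields used below.
    input_idx: int = 0
    dim_idx: int = 0
    map_to_tuning_buckets: Callable = pow2_ceil
    map_to_runtime_buckets: Optional[Callable] = None

ep_size = 8
spec = StubSpec(map_to_runtime_buckets=lambda x: pow2_ceil(max(1, x // ep_size)))
base_profile = [[6000, 4096]]  # one input tensor: (num_tokens, hidden_size)

for use_tuning_mapping in (True, False):
    # Same selection as the diff hunk above.
    bucket_mapper = spec.map_to_tuning_buckets
    if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
        bucket_mapper = spec.map_to_runtime_buckets
    print(use_tuning_mapping,
          bucket_mapper(base_profile[spec.input_idx][spec.dim_idx]))
# True 8192   -> bucket used while tuning and when storing cache entries
# False 1024  -> bucket used for the runtime lookup when ep > 1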
