65 changes: 49 additions & 16 deletions tensorrt_llm/_torch/autotuner.py
@@ -49,7 +49,7 @@ class DynamicTensorSpec:
input_idx: The index of the input tensor.
dim_idx: The index of the dimension to tune.
gen_tuning_buckets: A tuple of values to try or a function generating values.
map_to_tuning_buckets: A function to map dimensions to valid values during inference.
map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
"""
input_idx: int
dim_idx: int
@@ -81,7 +81,7 @@ class TuningConfig:
should be tuned to optimize performance. Each spec defines:
- Which input tensor dimension is dynamic
- How to generate tuning values
- How to map dimensions to valid values during inference
- How to map dimensions to tuning values during inference

Example:
>>> config = TuningConfig(
@@ -392,22 +392,26 @@ def search_cache(
runners: List[TunableRunner],
input_shapes: Tuple[torch.Size],
tuning_config: TuningConfig,
apply_map_to_tuning_buckets: bool = True,
) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
"""Search for cached profiling results matching the current configuration.

Args:
custom_op (str): The name of the custom operation to be tuned
runners (List[TunableRunner]): List of candidate implementations to profile
profile (OptimizationProfile): Optimization profile
apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
If False, use raw bucket values for tuning cache storage.

Returns:
A tuple containing:
[is_cache_hit, runner_id, tactic, stored_profile]
runner_id is the index in the current runners list
"""
for idx, r in enumerate(runners):
if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
tuning_config)) in self.cache:
if (cache_key := self.get_cache_key(
custom_op, r, input_shapes, tuning_config,
apply_map_to_tuning_buckets)) in self.cache:
# Return the current index in runners list, not the cached runner_id
cached_runner_id, tactic, min_time = self.cache[cache_key]
return True, idx, tactic, min_time
@@ -420,6 +424,7 @@ def get_cache_key(
runner: TunableRunner,
input_shapes: Tuple[torch.Size],
tuning_config: TuningConfig,
apply_map_to_tuning_buckets: bool = True,
) -> Tuple:
return (
custom_op,
Expand All @@ -430,6 +435,7 @@ def get_cache_key(
tuning_config.dynamic_tensor_specs,
tuning_config.constraint_specs,
tuning_config.tune_max_num_tokens,
apply_map_to_tuning_buckets,
),
)

@@ -795,7 +801,11 @@ def choose_one(

input_shapes = tuple(self._get_input_sizes(inputs))
is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)
custom_op,
runners,
input_shapes,
tuning_config,
apply_map_to_tuning_buckets=True)

# Early return if it's not tuning, use cache found one or fallback one
if not self.is_tuning_mode:
@@ -841,7 +851,12 @@ def choose_one(
for p in profiles:
tensors = self._prepare_input_tensors(p, inputs)
is_cache_hit, *_ = self.profiling_cache.search_cache(
custom_op, runners, p.get_opt_shapes(), tuning_config)
custom_op,
runners,
p.get_opt_shapes(),
tuning_config,
apply_map_to_tuning_buckets=False,
Collaborator
Hi @rosenrodt, for TRTLLM Gen MoE, why should we disable apply_map_to_tuning_buckets when doing autotuning? Does this affect other operators?

Collaborator
After discussion with @hyukn, I understand the case for TRTLLM Gen now. The problem is that:

  • When doing autotuning, the per-rank workload is "full", assuming all tokens activate experts on the local rank.
  • When doing inference, the per-rank workload is approximated by full_workload/ep_size.

Normally, we close this gap with inputs_hook, which modifies the inputs during autotuning. Specific to your case, you can modify the inputs so that the workload is divided by ep_size.

You may refer to the CuteDSL implementation:

```python
def generate_num_tokens_per_expert(self, num_tokens: int) -> List[int]:
    average_num_tokens_per_expert = num_tokens * self.top_k / self.num_experts
    balance = 0
    num_tokens_per_expert = []
    for i in range(self.num_local_experts):
        balance += average_num_tokens_per_expert
        if balance <= 1e-3:
            continue
        curr_num_tokens = int(balance) + 1
        num_tokens_per_expert.append(curr_num_tokens)
        balance -= curr_num_tokens
    return num_tokens_per_expert
```

Currently, this PR introduces an inconsistency between the autotuning and inference shapes, which is a bit concerning.

Collaborator
@hyukn Dec 24, 2025
@syuoni provides a better option here. Thanks a lot for the suggestion!

The current process of assembling the autotuner cache key is:

  • Generate tuning_buckets as a list of profiles.
  • Generate dummy input tensors according to the shape information in the profiles.
  • Apply input_pre_hook to the input tensors, and use the post-hook inputs as the inputs for runner.forward.
  • Generate the cache key with the map_to_tuning_bucket method.

By defining input_pre_hook, we always generate tensors whose shapes correspond to the correct workloads for runner.forward, while the shapes stored in the cache remain the original bucket shapes (before input_pre_hook). This means we can keep map_to_tuning_bucket as a simple bucket-mapping method instead of dividing by ep_size to adjust the workload model (see the sketch below).

This actually extends the usage of input_pre_hook (originally I only wanted to use it to manipulate the tensor data), but it truly works. I think we should also revise the docstring to clarify this usage.
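To make this concrete, here is a minimal sketch of the idea. It assumes a pre-hook that receives the dummy input tensors generated from a bucket shape and returns the tensors actually fed to runner.forward; the hook name, the tensor layout, and the EP_SIZE constant are assumptions for illustration, not the actual TRTLLM autotuner API:

```python
from typing import List

import torch

EP_SIZE = 8  # assumed expert-parallel size, for illustration only


def moe_inputs_pre_hook(inputs: List[torch.Tensor]) -> List[torch.Tensor]:
    """Sketch of a pre-hook that models the per-rank MoE workload.

    The dummy tensors are generated from the original bucket shape; this hook
    shrinks the token dimension to roughly full_workload / ep_size before
    runner.forward is called, while the cache key keeps the original bucket
    shape. map_to_tuning_buckets can then stay a plain bucket-rounding
    function.
    """
    hidden_states = inputs[0]  # assume tokens are on dim 0 of the first input
    per_rank_tokens = max(1, hidden_states.shape[0] // EP_SIZE)
    adjusted = list(inputs)
    adjusted[0] = hidden_states[:per_rank_tokens]
    return adjusted
```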

Collaborator Author

@syuoni I fully agree with the input_pre_hook() approach used in CuteDSL, but I think it is not directly applicable to TRTLLM MoE.

Let's look at the two main changes in this PR:

First, map_to_tuning_bucket() should not be applied during tuning, and this PR addresses that by applying it only during inference. Do you agree that we should keep this change, @syuoni?

  • The autotuner tunes the buckets coming solely from gen_tuning_buckets, without involving map_to_tuning_bucket(). The map_to_tuning_bucket() then maps those buckets to cache keys, which is not the intended behavior, as discussed with @hyukn.
  • This change should not affect the existing ops.

Second (this is the controversial part): TRTLLM MoE repurposes map_to_tuning_bucket() to account for workload sparsity, which is either convenient or confusing depending on how you look at it (a small illustrative sketch follows the list below). Long story short: routing information is not exposed in the TRTLLM MoE interface, so fully adopting CuteDSL's approach would require substantial rewrites. I would suggest we defer your suggested approach to a later PR if it turns out to be necessary. @syuoni @hyukn let me know what you think :D

  • The approach in CuteDSL MoE is sensible. CuteDSL MoE as a module appears to accept routing information at a more granular level (tile_idx_to_group_idx, tile_idx_to_mn_limit, etc.).
  • TRTLLM MoE accepts either routing_logits or topk_id/top_weights and computes the routing information internally.
  • TRTLLM MoE cannot adopt the same approach without substantial rewrites of how it interacts with the caller.
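
For reference, here is a small self-contained sketch of the two mapping styles discussed above. The bucket values and the EP_SIZE constant are made up for illustration and this is not the actual TRTLLM MoE code; it only shows why a mapper that folds in the 1/ep_size workload approximation belongs to inference-time lookup, which is what apply_map_to_tuning_buckets controls:

```python
import bisect

TUNING_BUCKETS = (64, 128, 256, 512, 1024)  # illustrative gen_tuning_buckets
EP_SIZE = 8                                  # illustrative expert-parallel size


def map_to_bucket(num_tokens: int) -> int:
    """Plain mapping: round a runtime token count down to the nearest
    profiled bucket (clamped to the smallest bucket)."""
    idx = bisect.bisect_right(TUNING_BUCKETS, num_tokens) - 1
    return TUNING_BUCKETS[max(idx, 0)]


def map_to_bucket_with_ep(num_tokens: int) -> int:
    """Repurposed mapping: approximate the per-rank workload as
    num_tokens / EP_SIZE before bucketing. Applying this while storing
    tuning results would shift the cache keys away from the profiled
    buckets, hence apply_map_to_tuning_buckets=False during tuning."""
    return map_to_bucket(max(1, num_tokens // EP_SIZE))


# Tuning stores results under the raw buckets (64, 128, ...); at inference a
# batch of 1024 tokens with EP_SIZE=8 looks up the bucket for 128 tokens.
assert map_to_bucket_with_ep(1024) == 128
```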

)
if not is_cache_hit:
# Initialize runner and tactic as None in case of no valid tactic or runners are found
best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -928,8 +943,11 @@ def _profile_runners(
# Record the failed profiling combinations
self.stats.failed_profiling_count[custom_op].add(
self.profiling_cache.get_cache_key(
custom_op, runner, profile.get_opt_shapes(),
tuning_config))
custom_op,
runner,
profile.get_opt_shapes(),
tuning_config,
apply_map_to_tuning_buckets=False))

# Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
# or some runtime error occurs during profiling.
@@ -942,8 +960,11 @@
if best_runner_id is not None:
# At least one valid (runner, tactic) pair is found
cache_key = self.profiling_cache.get_cache_key(
custom_op, runners[best_runner_id], profile.get_opt_shapes(),
tuning_config)
custom_op,
runners[best_runner_id],
profile.get_opt_shapes(),
tuning_config,
apply_map_to_tuning_buckets=False)

self._debug_logger(
f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1119,9 +1140,15 @@ def _optimization_profiles(
opt_shapes = spec.gen_tuning_buckets
# Add the current input value as one of the opt values
opt_shapes = set(opt_shapes)
opt_shapes.add(
spec.map_to_tuning_buckets(
base_profile.shapes[spec.input_idx][spec.dim_idx].val))
if tuning_config.tune_max_num_tokens is not None:
opt_shapes.add(
min(
tuning_config.tune_max_num_tokens,
base_profile.shapes[spec.input_idx][spec.dim_idx].val,
))
else:
opt_shapes.add(
base_profile.shapes[spec.input_idx][spec.dim_idx].val)
opt_shapes = sorted(list(opt_shapes))
opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), )
opt_shapes_max = {
@@ -1164,13 +1191,16 @@ def _find_nearest_profile(
dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
constraint_specs: Tuple[ConstraintSpec, ...],
tune_max_num_tokens: int = None,
apply_map_to_tuning_buckets: bool = True,
) -> Tuple:
"""Find the nearest optimization profile for given inputs
User can define their own nearest profile generation method to reduce the host overhead.

Args:
shapes: Tuple of input tensor shapes
tuning_config: Tuning configuration
apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
If False, use raw bucket values for tuning cache storage.

Return:
Tuple: A tuple containing:
@@ -1180,9 +1210,12 @@
base_profile = list(list(shape) for shape in shapes)

for spec in dynamic_tensor_specs:
base_profile[spec.input_idx][
spec.dim_idx] = spec.map_to_tuning_buckets(
base_profile[spec.input_idx][spec.dim_idx])
# During runtime: apply map_to_tuning_buckets to map input to bucket
# During tuning: no mapper, use raw bucket value
if apply_map_to_tuning_buckets:
base_profile[spec.input_idx][
spec.dim_idx] = spec.map_to_tuning_buckets(
base_profile[spec.input_idx][spec.dim_idx])

if tune_max_num_tokens is not None:
base_profile[spec.input_idx][spec.dim_idx] = min(