
Commit fb30e65

[breaking] map_to_tuning_buckets maps input during inference only

Signed-off-by: Anthony Chang <[email protected]>

1 parent 89024d0

3 files changed: +86, -98 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 28 additions & 29 deletions
@@ -46,15 +46,12 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
-        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
-            If None, use map_to_tuning_buckets.
+        map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
-    map_to_runtime_buckets: Optional[Callable] = None
 
 
 @dataclass(slots=True, unsafe_hash=True)
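This hunk carries the breaking change: `map_to_runtime_buckets` is removed, and the surviving `map_to_tuning_buckets` now runs only at inference time, while tuning profiles the values from `gen_tuning_buckets` as-is. A self-contained stand-in mirroring the post-change fields; the decorator reuse, the concrete buckets, and the power-of-two mapper are illustrative assumptions, not from this commit:

```python
# Stand-in mirroring the post-change dataclass fields above; the decorator,
# buckets, and power-of-two mapper are illustrative assumptions.
from dataclasses import dataclass
from typing import Callable, Tuple, Union


@dataclass(slots=True, unsafe_hash=True)
class DynamicTensorSpec:
    input_idx: int
    dim_idx: int
    gen_tuning_buckets: Union[Tuple[int], Callable] = ()
    map_to_tuning_buckets: Callable = lambda x: x  # applied at inference only


spec = DynamicTensorSpec(
    input_idx=0,                          # first input tensor
    dim_idx=0,                            # its num-tokens dimension
    gen_tuning_buckets=(1, 2, 4, 8, 16),  # profiled raw during tuning
    map_to_tuning_buckets=lambda x: 1 << (x.bit_length() - 1),
)
```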
@@ -81,7 +78,7 @@ class TuningConfig:
         should be tuned to optimize performance. Each spec defines:
         - Which input tensor dimension is dynamic
         - How to generate tuning values
-        - How to map dimensions to valid values during inference
+        - How to map dimensions to tuning values during inference
 
     Example:
         >>> config = TuningConfig(
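The docstring example is truncated in this view; here is a hedged completion using only names that appear elsewhere in this diff (`dynamic_tensor_specs` is read by `get_cache_key` below). It assumes `TuningConfig` accepts its attributes as keyword arguments, as a dataclass typically does, and omits other arguments such as `constraint_specs`:

```python
# Hedged completion of the truncated docstring example above; assumes
# TuningConfig takes its attributes as keyword arguments.
from tensorrt_llm._torch.autotuner import DynamicTensorSpec, TuningConfig

config = TuningConfig(
    dynamic_tensor_specs=(
        DynamicTensorSpec(
            input_idx=0,
            dim_idx=0,
            gen_tuning_buckets=(1, 2, 4, 8),
            map_to_tuning_buckets=lambda x: 1 << (x.bit_length() - 1),
        ),
    ),
)
```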
@@ -392,27 +389,26 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.
 
         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
-            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Returns:
             A tuple containing:
                 [is_cache_hit, runner_id, tactic, stored_profile]
                 runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key :=
-                    self.get_cache_key(custom_op, r, input_shapes,
-                                       tuning_config,
-                                       use_tuning_mapping)) in self.cache:
+            if (cache_key := self.get_cache_key(
+                    custom_op, r, input_shapes, tuning_config,
+                    apply_map_to_tuning_buckets)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
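The contract in the updated docstring is that tuning stores results keyed by raw bucket values, while a runtime lookup first snaps the live dimension onto a bucket via `map_to_tuning_buckets`, so both phases converge on the same shape. A runnable sketch of that equivalence; the dict-based cache and the key layout are simplifications, not the real `get_cache_key` tuple:

```python
# Simplified model of the lookup contract: tuning stores raw bucket keys,
# runtime maps the live shape onto a bucket before looking up. Key layout is
# illustrative; the real key is built by get_cache_key below.
def map_to_tuning_buckets(x: int) -> int:
    return 1 << (x.bit_length() - 1)  # last power of two <= x


cache = {}
for bucket in (1, 2, 4, 8, 16, 32, 64):  # tuning: buckets profiled as-is
    cache[("my_op", bucket)] = f"tactic_for_{bucket}"

runtime_m = 50  # live num-tokens at inference
assert cache[("my_op", map_to_tuning_buckets(runtime_m))] == "tactic_for_32"
```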
@@ -425,7 +421,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         return (
             custom_op,
@@ -436,7 +432,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
-                use_tuning_mapping,
+                apply_map_to_tuning_buckets,
             ),
         )
 
@@ -789,7 +785,11 @@ def choose_one(
 
         input_shapes = tuple(self._get_input_sizes(inputs))
         is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-            custom_op, runners, input_shapes, tuning_config)
+            custom_op,
+            runners,
+            input_shapes,
+            tuning_config,
+            apply_map_to_tuning_buckets=True)
 
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -831,7 +831,7 @@ def choose_one(
                     runners,
                     p.get_opt_shapes(),
                     tuning_config,
-                    use_tuning_mapping=True,
+                    apply_map_to_tuning_buckets=False,
                 )
                 if not is_cache_hit:
                     # Initialize runner and tactic as None in case of no valid tactic or runners are found
@@ -923,7 +923,7 @@ def _profile_runners(
                     runner,
                     profile.get_opt_shapes(),
                     tuning_config,
-                    use_tuning_mapping=True))
+                    apply_map_to_tuning_buckets=False))
 
             # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
             # or some runtime error occurs during profiling.
@@ -940,7 +940,7 @@ def _profile_runners(
                 runners[best_runner_id],
                 profile.get_opt_shapes(),
                 tuning_config,
-                use_tuning_mapping=True)
+                apply_map_to_tuning_buckets=False)
 
             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1117,8 +1117,7 @@ def _optimization_profiles(
             # Add the current input value as one of the opt values
             opt_shapes = set(opt_shapes)
             opt_shapes.add(
-                spec.map_to_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+                base_profile.shapes[spec.input_idx][spec.dim_idx].val)
             opt_shapes = sorted(list(opt_shapes))
             opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), )
             opt_shapes_max = {
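Since the mapper no longer runs during tuning, `_optimization_profiles` now adds the live input value to the candidate set unmapped. A short runnable sketch of the assembly in this hunk, with simplified names standing in for the spec and profile objects:

```python
# Sketch of the opt-shape assembly after this change: the current value joins
# the generated buckets raw, with no spec.map_to_tuning_buckets(...) applied.
gen_buckets = (1, 2, 4, 8, 16)
current_val = 6                   # live dimension value when tuning starts
opt_shapes = set(gen_buckets)
opt_shapes.add(current_val)       # added as-is
opt_shapes = sorted(opt_shapes)   # [1, 2, 4, 6, 8, 16]
opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'),)
```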
@@ -1161,16 +1160,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.
 
         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
-            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Return:
             Tuple: A tuple containing:
@@ -1180,12 +1179,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-
-            bucket_mapper = spec.map_to_tuning_buckets
-            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
-                bucket_mapper = spec.map_to_runtime_buckets
-            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
-                base_profile[spec.input_idx][spec.dim_idx])
+            # During runtime: apply map_to_tuning_buckets to map input to bucket
+            # During tuning: no mapper, use raw bucket value
+            if apply_map_to_tuning_buckets:
+                base_profile[spec.input_idx][
+                    spec.dim_idx] = spec.map_to_tuning_buckets(
+                        base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
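The branch above is the heart of the change: one boolean now decides whether the mapper runs at all. A minimal self-contained illustration of the toggle; the helper names are mine, not the module's:

```python
# Minimal illustration of the toggle in _find_nearest_profile: True maps the
# live dimension onto a bucket, False passes the already-bucketed value through.
def nearest_dim(val, mapper, apply_map_to_tuning_buckets: bool):
    return mapper(val) if apply_map_to_tuning_buckets else val


power_of_2 = lambda x: 1 << (x.bit_length() - 1)
assert nearest_dim(50, power_of_2, True) == 32   # runtime cache lookup
assert nearest_dim(32, power_of_2, False) == 32  # tuning-time storage, raw
```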

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 36 additions & 48 deletions
@@ -277,16 +277,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
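With `map_to_runtime_buckets` gone, the single `round_rule` must encode the runtime rounding itself, including the `ep_size` division that the deleted `lambda x: round_rule(x, ep_size)` used to apply, and it now closes over `ep_size` from the enclosing scope instead of taking it as a parameter. This same change repeats in the five hunks below. A runnable check of the arithmetic; `ep_size=4` and `MAX_PROFILE_BUCKET=4096` are illustrative values, not from the commit:

```python
# Worked check of round_rule as rewritten above; ep_size and
# MAX_PROFILE_BUCKET values are illustrative assumptions.
ep_size = 4
MAX_PROFILE_BUCKET = 4096


def last_positive_power_of_2(x: int) -> int:
    return 1 << (x.bit_length() - 1)


def round_rule(x: int) -> int:  # closes over ep_size, as in the diff
    value = last_positive_power_of_2(x) // ep_size
    return min(max(1, value), MAX_PROFILE_BUCKET)


assert round_rule(300) == 64  # 256 // 4, within [1, MAX_PROFILE_BUCKET]
assert round_rule(2) == 1     # 2 // 4 == 0, clamped up to 1
```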
@@ -623,16 +621,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -918,16 +914,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1218,16 +1212,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1496,16 +1488,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1759,16 +1749,14 @@ def get_dynamic_tensor_specs(cls,
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
 
-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 