
Commit 4121b67

[breaking] map_to_tuning_buckets maps input during inference only

Signed-off-by: Anthony Chang <[email protected]>
1 parent 0c27f02 commit 4121b67

3 files changed: +86 -98 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 28 additions & 29 deletions
@@ -49,15 +49,12 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
-        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
-            If None, use map_to_tuning_buckets.
+        map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
-    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
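
After this change, DynamicTensorSpec carries a single mapper: gen_tuning_buckets supplies the bucket values profiled during tuning, and map_to_tuning_buckets is applied only at inference time to snap a live dimension onto one of those buckets. A minimal sketch of the new usage, assuming the field names above (the power-of-two mapper and bucket list are illustrative choices, not part of this diff):

# Illustrative usage of the single-mapper API (toy bucket choices).
from tensorrt_llm._torch.autotuner import DynamicTensorSpec, TuningConfig

def floor_power_of_2(x: int) -> int:
    # Round a live dimension down to the nearest power of two, floored at 1.
    return 1 if x <= 1 else 1 << (x.bit_length() - 1)

spec = DynamicTensorSpec(
    input_idx=0,  # which input tensor is dynamic
    dim_idx=0,  # which dimension of that tensor to tune
    gen_tuning_buckets=(1, 2, 4, 8, 16, 32, 64),  # profiled as-is during tuning
    map_to_tuning_buckets=floor_power_of_2,  # applied only at inference lookups
)
config = TuningConfig(dynamic_tensor_specs=(spec, ))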
@@ -84,7 +81,7 @@ class TuningConfig:
             should be tuned to optimize performance. Each spec defines:
             - Which input tensor dimension is dynamic
             - How to generate tuning values
-            - How to map dimensions to valid values during inference
+            - How to map dimensions to tuning values during inference

     Example:
         >>> config = TuningConfig(
@@ -395,27 +392,26 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
-            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.

         Returns:
             A tuple containing:
                 [is_cache_hit, runner_id, tactic, stored_profile]
                 runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key :=
-                    self.get_cache_key(custom_op, r, input_shapes,
-                                       tuning_config,
-                                       use_tuning_mapping)) in self.cache:
+            if (cache_key := self.get_cache_key(
+                    custom_op, r, input_shapes, tuning_config,
+                    apply_map_to_tuning_buckets)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
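
The renamed flag encodes the new division of labor: tuning stores cache entries keyed by the raw bucket values it profiled (apply_map_to_tuning_buckets=False), while inference maps the live shape onto a bucket before looking up (apply_map_to_tuning_buckets=True), so both phases land on the same key. A self-contained toy of that invariant (not the real get_cache_key):

# Toy model of the cache-key invariant introduced by this commit.
def map_to_tuning_buckets(x: int) -> int:
    return 1 if x <= 1 else 1 << (x.bit_length() - 1)  # floor to power of two

def key_for(m: int, apply_map_to_tuning_buckets: bool) -> int:
    return map_to_tuning_buckets(m) if apply_map_to_tuning_buckets else m

stored = key_for(64, apply_map_to_tuning_buckets=False)     # tuning: raw bucket
looked_up = key_for(100, apply_map_to_tuning_buckets=True)  # inference: 100 -> 64
assert stored == looked_up == 64  # the runtime lookup hits the tuned entry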
@@ -428,7 +424,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         return (
             custom_op,
@@ -439,7 +435,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
-                use_tuning_mapping,
+                apply_map_to_tuning_buckets,
             ),
         )

@@ -805,7 +801,11 @@ def choose_one(

         input_shapes = tuple(self._get_input_sizes(inputs))
         is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-            custom_op, runners, input_shapes, tuning_config)
+            custom_op,
+            runners,
+            input_shapes,
+            tuning_config,
+            apply_map_to_tuning_buckets=True)

         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -855,7 +855,7 @@ def choose_one(
                 runners,
                 p.get_opt_shapes(),
                 tuning_config,
-                use_tuning_mapping=True,
+                apply_map_to_tuning_buckets=False,
             )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
@@ -947,7 +947,7 @@ def _profile_runners(
                         runner,
                         profile.get_opt_shapes(),
                         tuning_config,
-                        use_tuning_mapping=True))
+                        apply_map_to_tuning_buckets=False))

                 # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                 # or some runtime error occurs during profiling.
@@ -964,7 +964,7 @@ def _profile_runners(
                 runners[best_runner_id],
                 profile.get_opt_shapes(),
                 tuning_config,
-                use_tuning_mapping=True)
+                apply_map_to_tuning_buckets=False)

             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1141,8 +1141,7 @@ def _optimization_profiles(
             # Add the current input value as one of the opt values
             opt_shapes = set(opt_shapes)
             opt_shapes.add(
-                spec.map_to_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+                base_profile.shapes[spec.input_idx][spec.dim_idx].val)
             opt_shapes = sorted(list(opt_shapes))
             opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), )
             opt_shapes_max = {
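
Dropping the mapper in the hunk above follows from the new contract: during tuning the opt values already come from gen_tuning_buckets, so the current warm-up value is inserted raw instead of being folded into an existing bucket. A toy illustration with assumed values:

# Toy sketch of opt-shape collection after this commit (assumed values).
gen_buckets = {1, 2, 4, 8, 16}  # values produced by gen_tuning_buckets
current_m = 6  # dimension value observed when tuning is triggered
opt_shapes = sorted(gen_buckets | {current_m})
print(opt_shapes)  # [1, 2, 4, 6, 8, 16] -- 6 stays raw, no longer mapped to 4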
@@ -1185,16 +1184,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
-            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.

         Return:
             Tuple: A tuple containing:
@@ -1204,12 +1203,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-
-            bucket_mapper = spec.map_to_tuning_buckets
-            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
-                bucket_mapper = spec.map_to_runtime_buckets
-            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
-                base_profile[spec.input_idx][spec.dim_idx])
+            # During runtime: apply map_to_tuning_buckets to map input to bucket
+            # During tuning: no mapper, use raw bucket value
+            if apply_map_to_tuning_buckets:
+                base_profile[spec.input_idx][
+                    spec.dim_idx] = spec.map_to_tuning_buckets(
+                        base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(

tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 36 additions & 48 deletions
@@ -277,16 +277,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs
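
With the runtime mapper gone, the expert-parallel division moves into the single remaining round_rule: m_values supplies raw per-rank buckets for tuning, and the division by ep_size now happens only when a live token count is mapped at inference. A standalone worked example (the MAX_PROFILE_BUCKET and ep_size values, and the floor-to-power-of-two behavior of last_positive_power_of_2, are assumptions for illustration):

# Standalone toy check of round_rule semantics (values assumed).
MAX_PROFILE_BUCKET = 4096
ep_size = 4

def last_positive_power_of_2(x: int) -> int:
    # Assumed behavior: largest power of two <= x, floored at 1.
    return 1 if x <= 1 else 1 << (x.bit_length() - 1)

def round_rule(x: int) -> int:
    value = last_positive_power_of_2(x) // ep_size
    return min(max(1, value), MAX_PROFILE_BUCKET)

print(round_rule(3000))  # 2048 // 4 = 512: a live M of 3000 maps to bucket 512
print(round_rule(2))     # 2 // 4 = 0, clamped up to the minimum bucket of 1

The same two-line change is repeated verbatim at each of the remaining get_dynamic_tensor_specs sites below.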

@@ -623,16 +621,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs

@@ -918,16 +914,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs

@@ -1218,16 +1212,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs

@@ -1496,16 +1488,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs

@@ -1759,16 +1749,14 @@ def get_dynamic_tensor_specs(cls,

         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            value = last_positive_power_of_2(x) // ep_size_
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
             return min(max(1, value), MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )

         return specs
