
Commit 901933d

use single bucket mapper instead; map down tuning bucket according to expected fill rate
Signed-off-by: Anthony Chang <[email protected]>
1 parent 4bb0114 commit 901933d
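In short, the commit collapses the previous pair of bucket mappers into one: instead of storing tuning results under an unscaled power-of-two bucket and only applying the expected fill rate at runtime lookup, the single remaining mapper divides by ep_size up front, and the profiled range grows to 4096 * ep_size so the worst case is still covered. A minimal sketch of the old and new rules, assuming last_positive_power_of_2(x) is the largest power of two not exceeding x (the helper below is a stand-in, not the repository's implementation):

```python
def last_positive_power_of_2(x: int) -> int:
    # Stand-in for the helper referenced in the diffs below:
    # largest power of two <= x (returns 1 for x <= 1).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p


def old_rule(x: int, divisor: int, max_bucket: int = 4096) -> int:
    # Pre-commit behaviour: snap to a power of two first, then divide;
    # divisor=1 was used for tuning, divisor=ep_size for runtime lookups.
    return min(last_positive_power_of_2(x) // divisor, max_bucket)


def new_rule(x: int, ep_size: int) -> int:
    # Post-commit behaviour: divide by the expected fill rate 1/ep_size first,
    # then snap; one mapper shared by tuning and runtime lookups, with the
    # profiled range extended to 4096 * ep_size.
    return min(last_positive_power_of_2(x // ep_size), 4096 * ep_size)
```

For example, with ep_size=4 an inflated hidden-states buffer of 4096 rows maps to bucket new_rule(4096, 4) == 1024.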

3 files changed, +59 -174 lines


tensorrt_llm/_torch/autotuner.py

Lines changed: 11 additions & 38 deletions
@@ -46,15 +46,12 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
-        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
-            If None, use map_to_tuning_buckets.
+        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
-    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
@@ -392,27 +389,22 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
-            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
-                If False, use map_to_runtime_buckets for runtime cache lookups.

         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key :=
-                    self.get_cache_key(custom_op, r, input_shapes,
-                                       tuning_config,
-                                       use_tuning_mapping)) in self.cache:
+            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
+                                                tuning_config)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -425,7 +417,6 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -436,7 +427,6 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
-                use_tuning_mapping,
             ),
         )

@@ -827,12 +817,7 @@ def choose_one(
             for p in profiles:
                 tensors = self._prepare_input_tensors(p, inputs)
                 is_cache_hit, *_ = self.profiling_cache.search_cache(
-                    custom_op,
-                    runners,
-                    p.get_opt_shapes(),
-                    tuning_config,
-                    use_tuning_mapping=True,
-                )
+                    custom_op, runners, p.get_opt_shapes(), tuning_config)
                 if not is_cache_hit:
                     # Initialize runner and tactic as None in case of no valid tactic or runners are found
                     best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -919,11 +904,8 @@ def _profile_runners(
                 # Record the failed profiling combinations
                 self.stats.failed_profiling_count[custom_op].add(
                     self.profiling_cache.get_cache_key(
-                        custom_op,
-                        runner,
-                        profile.get_opt_shapes(),
-                        tuning_config,
-                        use_tuning_mapping=True))
+                        custom_op, runner, profile.get_opt_shapes(),
+                        tuning_config))

                 # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                 # or some runtime error occurs during profiling.
@@ -936,11 +918,8 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op,
-                runners[best_runner_id],
-                profile.get_opt_shapes(),
-                tuning_config,
-                use_tuning_mapping=True)
+                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
+                tuning_config)

             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1161,16 +1140,13 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
-        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
-            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
-                If False, use map_to_runtime_buckets for runtime cache lookups.

         Return:
             Tuple: A tuple containing:
@@ -1180,12 +1156,9 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-
-            bucket_mapper = spec.map_to_tuning_buckets
-            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
-                bucket_mapper = spec.map_to_runtime_buckets
-            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
-                base_profile[spec.input_idx][spec.dim_idx])
+            base_profile[spec.input_idx][
+                spec.dim_idx] = spec.map_to_tuning_buckets(
+                    base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(

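With map_to_runtime_buckets gone, a DynamicTensorSpec carries a single mapping function that the autotuner applies both when building tuning cache keys and when snapping runtime shapes to the nearest profile in _find_nearest_profile. A hedged usage sketch follows: the field names match the diff above, but the import path, bucket list, and mapper are illustrative assumptions rather than code from the repository.

```python
from tensorrt_llm._torch.autotuner import DynamicTensorSpec, TuningConfig

# Illustrative power-of-two buckets up to 4096 for the tuned dimension.
pow2_buckets = tuple(1 << i for i in range(13))

spec = DynamicTensorSpec(
    input_idx=2,  # e.g. the hidden_states tensor
    dim_idx=0,    # the number-of-tokens dimension
    gen_tuning_buckets=pow2_buckets,
    # Single mapper: snap a runtime size down to the nearest profiled
    # power-of-two bucket, capped at the largest tuned value.
    map_to_tuning_buckets=lambda x: min(1 << (max(x, 1).bit_length() - 1), 4096),
)

tuning_config = TuningConfig(dynamic_tensor_specs=(spec, ))
```

In this sketch the mapper's outputs are a subset of pow2_buckets, so a runtime shape always lands on a bucket that was actually profiled.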
tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py

Lines changed: 48 additions & 60 deletions
@@ -273,20 +273,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

@@ -619,20 +617,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

@@ -914,20 +910,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

@@ -1214,20 +1208,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

@@ -1492,20 +1484,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

@@ -1755,20 +1745,18 @@ def get_dynamic_tensor_specs(cls,
                                  ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096

+        # Extend max profiled bucket by ep_size
+        MAX_PROFILE_BUCKET = 4096 * ep_size
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)

-        def round_rule(x: int, ep_size_: int) -> int:
-            return min(
-                last_positive_power_of_2(x) // ep_size_, MAX_PROFILE_BUCKET)
+        # 1/ep_size is the expected token fill rate
+        # Fill rate maps buffer size into expected token count that represents actual works
+        round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
+                                   MAX_PROFILE_BUCKET)

-        specs = (DynamicTensorSpec(
-            HIDDEN_STATES_IDX,
-            TUNED_DIM,
-            m_values,
-            map_to_tuning_buckets=lambda x: round_rule(x, 1),
-            map_to_runtime_buckets=lambda x: round_rule(x, ep_size)), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
+                                   round_rule), )

         return specs

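Each of the six get_dynamic_tensor_specs variants above receives the same two-part change: MAX_PROFILE_BUCKET is scaled by ep_size, and round_rule divides by ep_size before snapping to a power-of-two bucket. The rationale in the comments is that under MoE expert parallelism the hidden-states buffer is sized for the worst case (inflated by ep_size), while on average only about 1/ep_size of it holds real tokens. A few hedged arithmetic checks of the new rule, again using a stand-in for last_positive_power_of_2:

```python
ep_size = 4
MAX_PROFILE_BUCKET = 4096 * ep_size  # 16384, as in the diff above


def last_positive_power_of_2(x: int) -> int:
    # Stand-in: largest power of two <= x (returns 1 for x <= 1).
    return 1 << (max(x, 1).bit_length() - 1)


round_rule = lambda x: min(last_positive_power_of_2(x // ep_size),
                           MAX_PROFILE_BUCKET)

assert round_rule(4096) == 1024       # 4096 // 4 = 1024, already a power of two
assert round_rule(6000) == 1024       # 6000 // 4 = 1500, snapped down to 1024
assert round_rule(200_000) == 16384   # 50000 -> 32768, capped at MAX_PROFILE_BUCKET
```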
tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 0 additions & 76 deletions
@@ -171,82 +171,6 @@ def test_autotuner_cache_basic():
         m //= 2


-def test_runtime_bucket_mapping():
-    """Test that map_to_runtime_buckets correctly maps runtime sizes to tuning buckets.
-
-    This test demonstrates the distinction between map_to_tuning_buckets and map_to_runtime_buckets:
-    - map_to_tuning_buckets: used during tuning to store cache keys with raw bucket values
-    - map_to_runtime_buckets: used during runtime to map input sizes to tuning buckets
-
-    With inflate_factor=4:
-    - Tuning stores buckets: 1, 2, 4, 8, 16, 32
-    - Runtime input 4 -> maps to bucket 1 via round_rule(4) = 4 // 4 = 1
-    - Runtime input 16 -> maps to bucket 4 via round_rule(16) = 16 // 4 = 4
-
-    In MoE EP, the input buffer size is inflated by factor of the EP size to expect the worse case.
-    Using map_to_runtime_buckets allows us to adjust the expected token count, instead of maximum
-    possible token count.
-    """
-    w = torch.randn(64, 128)
-    tuner = AutoTuner.get()
-    tuner.clear_cache()
-
-    # The factor indicating the input shape is inflated by X
-    def bucket_mapper(x: int, inflate_factor: int) -> int:
-        return x // inflate_factor
-
-    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
-        input_idx=0,
-        dim_idx=0,
-        gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
-        map_to_tuning_buckets=lambda x: bucket_mapper(x, inflate_factor=1),
-        map_to_runtime_buckets=lambda x: bucket_mapper(x, inflate_factor=4)),
-    ), )
-
-    runners = [GemmRunner()]
-
-    # Tune with M=32, which generates buckets 1, 2, 4, 8, 16, 32
-    with autotune():
-        tuner.choose_one("test_runtime_bucket_mapping", runners, tuning_config,
-                         [torch.randn(M, 64), w])
-
-    # Verify cache entries use raw tuning bucket values, not deflated values
-    cache_entries = tuner.profiling_cache.get_specific_custom_op(
-        "test_runtime_bucket_mapping")
-
-    # Extract the first dimension of the first input shape from each cache key
-    assert len(cache_entries) == 6, \
-        f"Expected 6 cache entries (buckets 1, 2, 4, 8, 16, 32), got {len(cache_entries)}"
-
-    # Test runtime mapping: input size should be mapped via map_to_runtime_buckets
-    # to find the correct tuning bucket
-    test_cases = [
-        # size 4 maps to bucket 4//4 = 1, tactic 0 (1 <= M // 2)
-        (4, 1, 0),
-        # size 8 maps to bucket 8//4 = 2, tactic 0 (2 <= M // 2)
-        (8, 2, 0),
-        # size 16 maps to bucket 16//4 = 4, tactic 0 (4 <= M // 2)
-        (16, 4, 0),
-        # size 32 maps to bucket 32//4 = 8, tactic 0 (8 <= M // 2)
-        (32, 8, 0),
-        # size 64 maps to bucket 64//4 = 16, tactic 0 (16 <= M // 2)
-        (64, 16, 0),
-        # size 128 maps to bucket 128//4 = 32, tactic 1 (32 > M // 2)
-        (128, 32, 1),
-        # size 256 maps to bucket 256//4 = 64, tactic -1 (64 > M)
-        (256, 64, -1),
-    ]
-
-    for input_size, expected_bucket, expected_tactic in test_cases:
-        # Verify cache lookup succeeds with the mapped bucket
-        x = torch.randn(input_size, 64)
-        runner, tactic = tuner.choose_one("test_runtime_bucket_mapping",
-                                          runners, tuning_config, [x, w])
-        assert (
-            tactic == expected_tactic
-        ), f"Cache mismatch for input_size={input_size}, expected to map to bucket {expected_tactic} but got {tactic}"
-
-
 def test_autotuner_try_block():

     class PartialCrashedRunner(TunableRunner):
