
Commit 072f236

[None][fix] Fully resolve the tactic recovery issues in AutoTuner serialized cache (#9835)
Restrict tactic types to those compatible with AutoTuner cache serialization and deserialization.

Signed-off-by: Yukun He <[email protected]>
1 parent df1adfb commit 072f236
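The fix boils down to a single invariant: a tactic may only enter the serialized cache if it survives a `repr()` → `ast.literal_eval()` round trip. Below is a minimal standalone sketch of that invariant; the helper name `is_cache_serializable` is illustrative and not part of this commit.

```python
import ast


def is_cache_serializable(tactic) -> bool:
    """Return True if `tactic` survives the repr -> ast.literal_eval round trip
    that the serialized AutoTuner cache relies on."""
    try:
        return tactic == ast.literal_eval(repr(tactic))
    except Exception:
        # repr() of an arbitrary object (e.g. '<object object at 0x...>')
        # cannot be parsed back, so such a tactic cannot be recovered.
        return False


# Literals and containers of literals round-trip, so they are cache-safe.
assert is_cache_serializable({"tile": (128, 128), "cluster": [1, 1, 1]})
# An arbitrary Python object is not.
assert not is_cache_serializable(object())
```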


2 files changed: +70 −26 lines changed


tensorrt_llm/_torch/autotuner.py

Lines changed: 22 additions & 10 deletions
@@ -169,6 +169,10 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         means. User can choose to implement their own types of tactic for flexibility, such as using a dict-typed
         to represent a collection of named configs.
 
+        The type of the tactic is arbitrary. But serialization/deserialization of the cache requires that the type is compatible with json.dumps/json.loads.
+        To evaluate if a type of tactic is compatible with current workflow, try the following code:
+            * assert YOUR_TACTIC_OBJECT == eval(repr(YOUR_TACTIC_OBJECT))
+
         tactic==-1 has special meaning, means the fallback kernel which should be able to implement any shapes
         This fallback tactic is needed for 2 reasons:
             * when the autotuner cannot find a valid tactic in it's cache.
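The `eval(repr(...))` guidance added to the docstring can be tried directly on candidate tactic values. The values below are illustrative only, not tied to any real runner; the example uses `ast.literal_eval`, the stricter form that the cache code below actually relies on.

```python
import ast

# Compatible tactic types: ints, tuples, lists, and dicts of literals.
for tactic in (-1, (128, 128), [1, 1, 1],
               {"tile": (256, 256), "cluster": [2, 2, 1]}):
    assert tactic == ast.literal_eval(repr(tactic))


# Incompatible: anything whose repr() is not a Python literal expression.
class OpaqueTactic:
    pass


try:
    ast.literal_eval(repr(OpaqueTactic()))
except (ValueError, SyntaxError):
    print("opaque objects cannot be recovered from the serialized cache")
```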
@@ -475,14 +479,22 @@ def _serialize_cache_to_json(self) -> Dict[str, Any]:
         }
 
         for key, value in self.cache.items():
-            # Convert tuple key to string for JSON compatibility
+            # Convert any simple object to string for JSON compatibility
            key_str = str(key)
-
             runner_id, tactic, min_time = value
+            tactic_str = repr(tactic)
+            try:
+                assert tactic == ast.literal_eval(
+                    tactic_str
+                ), f"Tactic is not compatible with json.dumps/json.loads"
+            except Exception as e:
+                logger.warning_once(
+                    f"[AutoTuner] Could not serialize tactic: {tactic_str} for cache key {key_str} due to {e}. Deserialization may fail.",
+                    key=tactic_str)
 
             serializable_cache["cache_data"][key_str] = {
                 "runner_id": runner_id,
-                "tactic": tactic,
+                "tactic": tactic_str,
                 "min_time": min_time,
             }
 
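With this change the cache file stores each tactic as its `repr()` string. A hypothetical serialized entry could look as follows; the cache key, runner id, and timing are made-up values for illustration.

```python
import json

cache_data = {
    "('gemm', (32, 64))": {  # str(key) of a made-up cache key
        "runner_id": 0,
        "tactic": "{'tuple_tile_size': (128, 128), 'list_cluster_size': [1, 1, 1]}",
        "min_time": 0.42,
    }
}
# The whole structure is now plain strings and numbers, so json.dumps works.
print(json.dumps({"cache_data": cache_data}, indent=2))
```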
@@ -511,22 +523,22 @@ def _deserialize_cache_from_json(
         cache = {}
         cache_data = serializable_cache["cache_data"]
 
-        def lists_to_tuples(obj):
-            if isinstance(obj, list):
-                return tuple(lists_to_tuples(x) for x in obj)
-            return obj
-
         for key_str, value in cache_data.items():
             # Reconstruct the tuple key safely
             try:
-                key = ast.literal_eval(key_str)  # Safer than eval()
+                key = ast.literal_eval(key_str)
             except (ValueError, SyntaxError):
                 logger.warning(
                     f"[AutoTuner] Could not reconstruct cache key: {key_str}")
                 continue
+            try:
+                tactic = ast.literal_eval(value["tactic"])
+            except (ValueError, TypeError):
+                logger.warning_once(
+                    f"[AutoTuner] Could not deserialize tactic: {value['tactic']} for cache key {key_str}",
+                    key=value["tactic"])
 
             runner_id = value["runner_id"]
-            tactic = lists_to_tuples(value["tactic"])
             min_time = value["min_time"]
 
             cache[key] = (runner_id, tactic, min_time)
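Dropping `lists_to_tuples` in favor of `ast.literal_eval` is what preserves the tactic's original container types. A small sketch of the behavioral difference, using a made-up tactic value:

```python
import ast
import json

tactic = {"tuple_tile_size": (256, 256), "list_cluster_size": [2, 2, 1]}

# Previously the raw tactic went through JSON, so tuples came back as lists
# and had to be coerced, erasing the original list/tuple distinction.
old_roundtrip = json.loads(json.dumps(tactic))
assert isinstance(old_roundtrip["tuple_tile_size"], list)  # type information lost

# Now the repr string passes through JSON untouched and literal_eval
# rebuilds the tactic with its original types.
new_roundtrip = ast.literal_eval(json.loads(json.dumps(repr(tactic))))
assert new_roundtrip == tactic
assert isinstance(new_roundtrip["tuple_tile_size"], tuple)
assert isinstance(new_roundtrip["list_cluster_size"], list)
```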

tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 48 additions & 16 deletions
@@ -1,6 +1,7 @@
+import itertools
 import os
 import tempfile
-from typing import Dict, List
+from typing import Any, List
 
 import torch
 
@@ -327,48 +328,63 @@ def test_multiple_dynamic_shapes_cache():
 
 
 class GemmRunnerComplexTuningConfigs(TunableRunner):
+
+    # test serialization of different types of tactics
     valid_tactic_ids = [-1, 0, 1]
+    valid_tile_sizes = [(128, 128), (256, 256)]
+    valid_cluster_sizes = [[1, 1, 1], [2, 2, 1]]
+
     tune_max_num_tokens = 32
 
     def get_valid_tactics(
         self,
         inputs: List[FakeTensor],
         profile: OptimizationProfile,
         **kwargs,
-    ) -> List[Dict[str, int]]:
+    ) -> List[Any]:
         # During the tuning process, we verify if the tuning config behaves as expected
-
         assert inputs[0].shape[0] <= self.tune_max_num_tokens, \
             f"Input shape {inputs[0].shape[0]} is larger than the max num tokens {self.tune_max_num_tokens}"
 
         assert inputs[0][-1, 0] == inputs[0].shape[0], \
             f"Input shape {inputs[0].shape[0]} is not set through the pre_hook correctly"
 
-        # The simulated delay is not deterministic, so we need to return specific tactics here
         return [{
-            "block_size": block_size,
-            "tactic_id": tactic_id
-        } for tactic_id in self.valid_tactic_ids for block_size in [128, 256]]
+            "int_tactic_id": tactic_id,
+            "tuple_tile_size": tile_size,
+            "list_cluster_size": cluster_size,
+        } for tactic_id, tile_size, cluster_size in itertools.product(
+            self.valid_tactic_ids,
+            self.valid_tile_sizes,
+            self.valid_cluster_sizes,
+        )]
 
     def forward(
         self,
         /,
         inputs: List[torch.Tensor],
         *,
-        tactic: dict = {},
+        tactic: Any = -1,
     ) -> torch.Tensor:
         # Notice that in fallback case tactic is -1
         if tactic == -1:
             # assign default configs for fallback case
-            block_size, tactic_id = 128, -1
+            tactic_id, tile_size, cluster_size = -1, (128, 256), [1, 1, 1]
         else:
-            block_size, tactic_id = tactic["block_size"], tactic["tactic_id"]
-        assert tactic_id in self.valid_tactic_ids
+            tactic_id, tile_size, cluster_size = tactic[
+                "int_tactic_id"], tactic["tuple_tile_size"], tactic[
+                    "list_cluster_size"]
+
+        assert isinstance(tactic_id, int) and tactic_id in self.valid_tactic_ids
+        assert isinstance(tile_size, tuple) and len(tile_size) == 2 \
+            and tile_size in self.valid_tile_sizes
+        assert isinstance(cluster_size, list) and len(cluster_size) == 3 \
+            and cluster_size in self.valid_cluster_sizes
         return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs)
 
     @staticmethod
     def inputs_pre_hook(inputs: List[torch.Tensor]):
-        # always set the first element to bo iota in x
+        # always set the first element to be the number of tokens in x
         x, w = inputs
         x_hooked = torch.zeros_like(x)
         x_hooked[-1, 0] = x.shape[0]
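For reference, the `itertools.product` call above enumerates every combination of the three class-level lists, so the test runner advertises 3 × 2 × 2 = 12 dict-typed tactics mixing int, tuple, and list fields. A standalone sketch using the same constants:

```python
import itertools

valid_tactic_ids = [-1, 0, 1]
valid_tile_sizes = [(128, 128), (256, 256)]
valid_cluster_sizes = [[1, 1, 1], [2, 2, 1]]

tactics = [{
    "int_tactic_id": tactic_id,
    "tuple_tile_size": tile_size,
    "list_cluster_size": cluster_size,
} for tactic_id, tile_size, cluster_size in itertools.product(
    valid_tactic_ids, valid_tile_sizes, valid_cluster_sizes)]

assert len(tactics) == 12
```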
@@ -389,13 +405,29 @@ def test_autotuner_tuning_configs():
         # Test if the number of tuning tokens is clipped to 32
         tune_max_num_tokens=GemmRunnerComplexTuningConfigs.tune_max_num_tokens,
         inputs_pre_hook=GemmRunnerComplexTuningConfigs.inputs_pre_hook,
+        use_cold_l2_cache=True,
+        use_cuda_graph=False,
     )
-    with autotune():
+    temp_dir = tempfile.TemporaryDirectory()
+    with autotune(cache_path=os.path.join(
+            temp_dir.name, "test_autotuner_tactic_configs.json")):
         tuner = AutoTuner.get()
-        runner, tactic = tuner.choose_one("test_autotuner_tactic_configs",
-                                          runners, tuning_config, [x, w])
+        runner, best_tactic = tuner.choose_one("test_autotuner_tactic_configs",
+                                               runners, tuning_config, [x, w])
+
+    runner_0([x, w], tactic=best_tactic)
+
+    # Test if the tactic can be loaded from cache correctly
+    AutoTuner.get().profiling_cache.clear()
+    AutoTuner.get().profiling_cache.load_cache(
+        os.path.join(temp_dir.name, "test_autotuner_tactic_configs.rank0.json"))
+
+    # No further tuning should be performed.
+    runner, deserialized_tactic = tuner.choose_one(
+        "test_autotuner_tactic_configs", runners, tuning_config, [x, w])
+    assert best_tactic == deserialized_tactic, "Tactic should be the same after deserialization"
 
-    runner_0.forward(inputs=[x, w], tactic=tactic)
+    runner_0([x, w], tactic=deserialized_tactic)
 
 
 def test_kernel_testing_single_context():
