@@ -49,12 +49,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
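For clarity, here is a hypothetical spec using both mappers; the bucket values and lambdas below are illustrative only and not taken from this change. Leaving `map_to_runtime_buckets` as `None` preserves the old behavior of reusing the tuning mapper at inference time.

```python
# Illustrative only: buckets and lambdas are placeholders, not from this PR.
spec = DynamicTensorSpec(
    input_idx=0,          # first input tensor
    dim_idx=0,            # tune its leading (e.g. num_tokens) dimension
    gen_tuning_buckets=(8, 16, 32, 64, 128),
    # Used while profiling: round the dynamic dim up to a power of two.
    map_to_tuning_buckets=lambda x: 1 << max(0, x - 1).bit_length(),
    # Used for runtime cache lookups; None falls back to the tuning mapper.
    map_to_runtime_buckets=lambda x: min(128, 1 << max(0, x - 1).bit_length()),
)
```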
@@ -392,22 +395,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, use map_to_tuning_buckets to build the cache key.
+                If False, use map_to_runtime_buckets for runtime cache lookups.

         Returns:
             A tuple containing:
                 [is_cache_hit, runner_id, tactic, stored_profile]
                 runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
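A sketch of how the flag separates the two lookup paths, assuming a `ProfilingCache` instance named `cache` and in-scope `custom_op`, `runners`, `input_shapes`, and `tuning_config` (the runtime call site is inferred from the docstring, not shown in this diff):

```python
# Tuning path: keys are built with each spec's map_to_tuning_buckets.
hit, runner_id, tactic, min_time = cache.search_cache(
    custom_op, runners, input_shapes, tuning_config, use_tuning_mapping=True)

# Runtime path (default flag): map_to_runtime_buckets is preferred when set.
hit, runner_id, tactic, min_time = cache.search_cache(
    custom_op, runners, input_shapes, tuning_config)
```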
@@ -420,6 +428,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -430,6 +439,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                use_tuning_mapping,
             ),
         )

@@ -841,7 +851,12 @@ def choose_one(
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+                custom_op,
+                runners,
+                p.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True,
+            )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -928,8 +943,11 @@ def _profile_runners(
                 # Record the failed profiling combinations
                 self.stats.failed_profiling_count[custom_op].add(
                     self.profiling_cache.get_cache_key(
-                        custom_op, runner, profile.get_opt_shapes(),
-                        tuning_config))
+                        custom_op,
+                        runner,
+                        profile.get_opt_shapes(),
+                        tuning_config,
+                        use_tuning_mapping=True))

                 # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                 # or some runtime error occurs during profiling.
@@ -942,8 +960,11 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
-                tuning_config)
+                custom_op,
+                runners[best_runner_id],
+                profile.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True)

             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1164,13 +1185,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets (as when storing tuning results).
+                If False, use map_to_runtime_buckets for runtime cache lookups.

         Return:
             Tuple: A tuple containing:
@@ -1180,9 +1204,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
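The mapper-selection logic above reduces to a small helper. This condensed sketch (the helper name is hypothetical, not part of the change) shows the fallback behavior:

```python
def _select_bucket_mapper(spec: DynamicTensorSpec, use_tuning_mapping: bool):
    # Hypothetical helper, not in this PR: tuning always uses
    # map_to_tuning_buckets; runtime prefers map_to_runtime_buckets
    # and falls back to the tuning mapper when it is None.
    if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
        return spec.map_to_runtime_buckets
    return spec.map_to_tuning_buckets
```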