@@ -46,12 +46,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
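As a rough illustration of how the new field is meant to be used, a spec might keep coarse buckets for tuning while rounding runtime shapes differently. The bucket values and mapping lambdas below are invented for the example, not taken from this PR:

```python
# Hypothetical usage of the new map_to_runtime_buckets field.
# Bucket values and mapping functions are illustrative only.
spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=(1, 64, 256, 1024),
    # During tuning: clamp the dimension at the largest tuning bucket.
    map_to_tuning_buckets=lambda x: min(x, 1024),
    # During inference: round up to the next power of two instead.
    # When left as None, map_to_tuning_buckets is used for runtime lookups too.
    map_to_runtime_buckets=lambda x: 1 << max(0, x - 1).bit_length(),
)
```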
@@ -389,22 +392,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, build the cache key with map_to_tuning_buckets.
+                If False, use map_to_runtime_buckets (or map_to_tuning_buckets if unset).

         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -417,6 +425,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -427,6 +436,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                use_tuning_mapping,
             ),
         )

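Since `use_tuning_mapping` is now folded into the key tuple, keys built with and without it are distinct. A hypothetical call pattern for the two modes of `search_cache` (the cache instance and arguments are placeholders):

```python
# Hypothetical call pattern; `cache`, `runners`, `shapes`, and `cfg` are placeholders.

# Tuning path (see choose_one below): shapes are bucketed with map_to_tuning_buckets.
is_hit, *rest = cache.search_cache("my_op", runners, shapes, cfg,
                                   use_tuning_mapping=True)

# Runtime path: the default (False) selects map_to_runtime_buckets when a spec
# provides one, otherwise falls back to map_to_tuning_buckets.
is_hit, *rest = cache.search_cache("my_op", runners, shapes, cfg)
```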
@@ -817,7 +827,12 @@ def choose_one(
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
+                custom_op,
+                runners,
+                p.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True,
+            )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
@@ -904,8 +919,11 @@ def _profile_runners(
                     # Record the failed profiling combinations
                     self.stats.failed_profiling_count[custom_op].add(
                         self.profiling_cache.get_cache_key(
-                            custom_op, runner, profile.get_opt_shapes(),
-                            tuning_config))
+                            custom_op,
+                            runner,
+                            profile.get_opt_shapes(),
+                            tuning_config,
+                            use_tuning_mapping=True))

                     # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                     # or some runtime error occurs during profiling.
@@ -918,8 +936,11 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
-                tuning_config)
+                custom_op,
+                runners[best_runner_id],
+                profile.get_opt_shapes(),
+                tuning_config,
+                use_tuning_mapping=True)

             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1140,13 +1161,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets, as when storing tuning results.
+                If False, use map_to_runtime_buckets (or map_to_tuning_buckets if unset) for runtime lookups.

         Return:
             Tuple: A tuple containing:
@@ -1156,9 +1180,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
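To make the mapper selection above concrete, here is a minimal standalone sketch of the fallback behavior; the mapping functions are invented for illustration:

```python
# Standalone sketch of the bucket-mapper fallback added in _find_nearest_profile.
from typing import Callable, Optional


def pick_bucket(dim: int,
                map_to_tuning_buckets: Callable,
                map_to_runtime_buckets: Optional[Callable],
                use_tuning_mapping: bool) -> int:
    bucket_mapper = map_to_tuning_buckets
    if not use_tuning_mapping and map_to_runtime_buckets is not None:
        bucket_mapper = map_to_runtime_buckets
    return bucket_mapper(dim)


tuning = lambda x: min(x, 1024)                      # clamp during tuning
runtime = lambda x: 1 << max(0, x - 1).bit_length()  # round up at runtime

print(pick_bucket(3000, tuning, runtime, use_tuning_mapping=True))   # 1024
print(pick_bucket(3000, tuning, runtime, use_tuning_mapping=False))  # 4096
print(pick_bucket(3000, tuning, None, use_tuning_mapping=False))     # 1024 (fallback)
```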