@@ -46,15 +46,12 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
-        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
-            If None, use map_to_tuning_buckets.
+        map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
-    map_to_runtime_buckets: Optional[Callable] = None
 
 
 @dataclass(slots=True, unsafe_hash=True)
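
For context on the collapsed spec: after this change a single mapper serves both tuning and inference. A minimal sketch of constructing such a spec, where the power-of-two helper is an illustrative assumption rather than part of this patch:

```python
# Sketch only: assumes the post-patch DynamicTensorSpec with a single mapper.
def round_up_pow2(x: int) -> int:
    # Illustrative mapper: round a runtime dimension up to the next power of two.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

spec = DynamicTensorSpec(
    input_idx=0,                                   # first input tensor
    dim_idx=0,                                     # its leading dimension
    gen_tuning_buckets=(1, 2, 4, 8, 16, 32, 64),   # shapes profiled during tuning
    map_to_tuning_buckets=round_up_pow2,           # runtime dim -> tuning bucket
)
```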
@@ -81,7 +78,7 @@ class TuningConfig:
         should be tuned to optimize performance. Each spec defines:
         - Which input tensor dimension is dynamic
         - How to generate tuning values
-        - How to map dimensions to valid values during inference
+        - How to map dimensions to tuning values during inference
 
     Example:
         >>> config = TuningConfig(
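
The doctest above is cut off by the hunk boundary; as a hedged illustration (bucket values chosen arbitrarily), a complete config using the single mapper might read:

```python
# Illustrative only: one dynamic dimension whose runtime values are snapped
# onto the same buckets that were profiled at tuning time.
config = TuningConfig(
    dynamic_tensor_specs=(
        DynamicTensorSpec(
            input_idx=0,
            dim_idx=0,
            gen_tuning_buckets=(8, 16, 32),
            # Map a live dim to the smallest bucket that covers it.
            map_to_tuning_buckets=lambda x: min(
                (b for b in (8, 16, 32) if b >= x), default=32),
        ),
    ),
)
```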
@@ -392,27 +389,26 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple[bool, int, int, float]:
         """Search for cached profiling results matching the current configuration.
 
         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             input_shapes (Tuple[torch.Size]): Shapes of the input tensors
-            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Returns:
             A tuple containing:
                 (is_cache_hit, runner_id, tactic, min_time)
                 runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key :=
-                    self.get_cache_key(custom_op, r, input_shapes,
-                                       tuning_config,
-                                       use_tuning_mapping)) in self.cache:
+            if (cache_key := self.get_cache_key(
+                    custom_op, r, input_shapes, tuning_config,
+                    apply_map_to_tuning_buckets)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
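
One subtlety the inline comment points at: the cached runner_id may refer to a position in an earlier runners list, so a hit reports the index in the list passed to this call. A toy model of that behavior (cache layout simplified):

```python
# Toy model: cache values are (cached_runner_id, tactic, min_time).
cache = {"key_b": (5, "tactic_2", 0.08)}    # id 5 came from an older runners list

runners = ["runner_a", "runner_b"]          # in this call, runner_b is index 1
key_of = {"runner_a": "key_a", "runner_b": "key_b"}

for idx, r in enumerate(runners):
    if (cache_key := key_of[r]) in cache:
        cached_runner_id, tactic, min_time = cache[cache_key]
        print(True, idx, tactic, min_time)  # reports idx == 1, not the stale id 5
        break
```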
@@ -425,7 +421,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         return (
             custom_op,
@@ -436,7 +432,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
-                use_tuning_mapping,
+                apply_map_to_tuning_buckets,
             ),
         )
 
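Note that the flag sits inside a nested call closed by the `),` above rather than standing as a key component of its own, so a runtime shape that maps onto a tuned bucket reproduces the tuning-time key. A simplified model of that equivalence (key layout heavily reduced; the rounding rule is an assumption):

```python
# Reduced key model: (custom_op, runner_name, nearest_profile).
def nearest_profile(shape, apply_map_to_tuning_buckets):
    dim = shape[0]
    if apply_map_to_tuning_buckets:
        # Assumed mapper: round dim 0 up to the next power of two.
        dim = 1 if dim <= 1 else 1 << (dim - 1).bit_length()
    return (dim,) + tuple(shape[1:])

tuning_key = ("op", "runner", nearest_profile((64, 128), False))   # stored
runtime_key = ("op", "runner", nearest_profile((50, 128), True))   # looked up
assert tuning_key == runtime_key    # 50 rounds to bucket 64 -> cache hit
```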
@@ -789,7 +785,11 @@ def choose_one(
 
         input_shapes = tuple(self._get_input_sizes(inputs))
         is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-            custom_op, runners, input_shapes, tuning_config)
+            custom_op,
+            runners,
+            input_shapes,
+            tuning_config,
+            apply_map_to_tuning_buckets=True)
 
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -831,7 +831,7 @@ def choose_one(
                 runners,
                 p.get_opt_shapes(),
                 tuning_config,
-                use_tuning_mapping=True,
+                apply_map_to_tuning_buckets=False,
             )
             if not is_cache_hit:
                 # Initialize runner and tactic as None in case no valid tactic or runner is found
@@ -923,7 +923,7 @@ def _profile_runners(
                     runner,
                     profile.get_opt_shapes(),
                     tuning_config,
-                    use_tuning_mapping=True))
+                    apply_map_to_tuning_buckets=False))
 
                 # Set time_measured to inf to signal that the tactic failed. This can happen when `get_valid_tactics` mistakenly returns invalid tactics
                 # or some runtime error occurs during profiling.
@@ -940,7 +940,7 @@ def _profile_runners(
             runners[best_runner_id],
             profile.get_opt_shapes(),
             tuning_config,
-            use_tuning_mapping=True)
+            apply_map_to_tuning_buckets=False)
 
         self._debug_logger(
             f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1117,8 +1117,7 @@ def _optimization_profiles(
             # Add the current input value as one of the opt values
             opt_shapes = set(opt_shapes)
             opt_shapes.add(
-                spec.map_to_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+                base_profile.shapes[spec.input_idx][spec.dim_idx].val)
             opt_shapes = sorted(list(opt_shapes))
             opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'),)
             opt_shapes_max = {
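
With the mapper dropped from this site, the live value joins the opt-shape set unmapped. A small sketch of the resulting ranges, reusing the `(float('inf'),)` sentinel from the context lines (bucket values assumed):

```python
# Sketch: assumed tuning buckets plus the raw current value 37.
opt_shapes = {1, 8, 64}
opt_shapes.add(37)                      # raw value, no map_to_tuning_buckets
opt_shapes = sorted(opt_shapes)         # [1, 8, 37, 64]

# Each opt value covers a half-open range up to the next value.
opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'),)
print(list(zip(opt_shapes, opt_shapes_max)))
# [(1, 8), (8, 37), (37, 64), (64, inf)]
```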
@@ -1161,16 +1160,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs.
         Users can define their own nearest-profile generation method to reduce host overhead.
 
         Args:
             shapes: Tuple of input tensor shapes
             dynamic_tensor_specs: Specs describing the dynamic tensor dimensions
-            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Return:
             Tuple: A tuple containing:
@@ -1180,12 +1179,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-
-            bucket_mapper = spec.map_to_tuning_buckets
-            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
-                bucket_mapper = spec.map_to_runtime_buckets
-            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
-                base_profile[spec.input_idx][spec.dim_idx])
+            # During runtime: apply map_to_tuning_buckets to map input to bucket
+            # During tuning: no mapper, use raw bucket value
+            if apply_map_to_tuning_buckets:
+                base_profile[spec.input_idx][
+                    spec.dim_idx] = spec.map_to_tuning_buckets(
+                        base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
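
Taken together with the `min(...)` clamp that continues past the hunk boundary, the mapping step now reduces to the following standalone approximation (the power-of-two mapper is an assumption; the clamp mirrors `tune_max_num_tokens`):

```python
from typing import Optional

# Approximation of the post-patch mapping step in _find_nearest_profile.
def map_dim(dim: int, apply_map_to_tuning_buckets: bool,
            tune_max_num_tokens: Optional[int] = None) -> int:
    if apply_map_to_tuning_buckets:
        # Runtime lookup: snap the live dimension onto a tuning bucket.
        dim = 1 if dim <= 1 else 1 << (dim - 1).bit_length()
    # Tuning storage: dim is already a raw bucket value and passes through.
    if tune_max_num_tokens is not None:
        dim = min(dim, tune_max_num_tokens)
    return dim

assert map_dim(37, apply_map_to_tuning_buckets=True) == 64     # runtime
assert map_dim(64, apply_map_to_tuning_buckets=False) == 64    # tuning
assert map_dim(300, True, tune_max_num_tokens=256) == 256      # clamped
```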