@@ -49,15 +49,12 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
-        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
-            If None, use map_to_tuning_buckets.
+        map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
-    map_to_runtime_buckets: Optional[Callable] = None
 
 
 @dataclass(slots=True, unsafe_hash=True)
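With map_to_runtime_buckets gone, a spec carries a single mapper that is applied at inference time. For orientation, a minimal sketch of constructing one (the bucket values and the rounding lambda are illustrative, not taken from this diff):

```python
# Illustrative spec: tune dim 0 of input 0 over power-of-two buckets and
# round live shapes down to the nearest bucket when looking up the cache.
spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=(8, 16, 32, 64, 128),
    map_to_tuning_buckets=lambda x: max(8, 1 << (x.bit_length() - 1)),
)
```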
@@ -84,7 +81,7 @@ class TuningConfig:
             should be tuned to optimize performance. Each spec defines:
             - Which input tensor dimension is dynamic
             - How to generate tuning values
-            - How to map dimensions to valid values during inference
+            - How to map dimensions to tuning values during inference
 
     Example:
         >>> config = TuningConfig(
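The docstring example is cut off by the hunk boundary; a hedged completion, assuming dynamic_tensor_specs is the keyword being filled in (it is the only TuningConfig field this diff confirms):

```python
# Hypothetical completion of the docstring example (keyword names beyond
# dynamic_tensor_specs are assumptions, not shown in this diff).
config = TuningConfig(
    dynamic_tensor_specs=(
        DynamicTensorSpec(
            input_idx=0,
            dim_idx=0,
            gen_tuning_buckets=(8, 64, 512),
            map_to_tuning_buckets=lambda x: 512 if x >= 512 else (64 if x >= 64 else 8),
        ),
    ),
)
```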
@@ -395,27 +392,26 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple[bool, int, int, float]:
         """Search for cached profiling results matching the current configuration.
 
         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             input_shapes (Tuple[torch.Size]): Shapes of the current input tensors
-            use_tuning_mapping: If True, use map_to_tuning_buckets for cache key.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime
+                cache lookups. If False, use the raw bucket values for tuning cache storage.
 
         Returns:
             A tuple containing:
             (is_cache_hit, runner_id, tactic, min_time)
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key :=
-                    self.get_cache_key(custom_op, r, input_shapes,
-                                       tuning_config,
-                                       use_tuning_mapping)) in self.cache:
+            if (cache_key := self.get_cache_key(
+                    custom_op, r, input_shapes, tuning_config,
+                    apply_map_to_tuning_buckets)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
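The renamed flag makes the cache's two phases explicit: tuning stores entries under raw bucket values, and runtime lookups map the live shape onto a bucket first. A self-contained toy model of that contract (dict layout and mapper are illustrative, not the class's actual key format):

```python
# Toy model of the two-phase cache; not the class's actual key layout.
cache = {}
map_to_tuning_buckets = lambda x: 64 if x >= 64 else 8  # assumed mapper

# Tuning phase (apply_map_to_tuning_buckets=False): entries are stored
# under the raw bucket values generated for profiling.
for bucket in (8, 64):
    cache[("my_op", bucket)] = ("runner_0", "tactic_3")

# Runtime phase (apply_map_to_tuning_buckets=True): the live dimension is
# mapped onto a bucket first, so a shape of 100 hits the entry stored at 64.
live_dim = 100
assert ("my_op", map_to_tuning_buckets(live_dim)) in cache
```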
@@ -428,7 +424,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         return (
             custom_op,
@@ -439,7 +435,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
-                use_tuning_mapping,
+                apply_map_to_tuning_buckets,
             ),
         )
 
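Because the flag itself is folded into the key tuple, keys built with apply_map_to_tuning_buckets=True and =False never collide. A minimal illustration of the structural tuple-key behavior get_cache_key relies on (all values made up):

```python
# Tuples of hashable fields act as structural dict keys (values made up).
key_runtime = ("my_op", (64,), True)   # apply_map_to_tuning_buckets=True
key_tuning = ("my_op", (64,), False)   # apply_map_to_tuning_buckets=False
table = {key_tuning: ("runner_0", 3)}
# The flag is part of the key, so the two kinds of entries stay distinct.
assert key_tuning in table and key_runtime not in table
```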
@@ -805,7 +801,11 @@ def choose_one(
 
         input_shapes = tuple(self._get_input_sizes(inputs))
         is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-            custom_op, runners, input_shapes, tuning_config)
+            custom_op,
+            runners,
+            input_shapes,
+            tuning_config,
+            apply_map_to_tuning_buckets=True)
 
         # Early return when not in tuning mode: use the cached result or the fallback one
         if not self.is_tuning_mode:
@@ -855,7 +855,7 @@ def choose_one(
                     runners,
                     p.get_opt_shapes(),
                     tuning_config,
-                    use_tuning_mapping=True,
+                    apply_map_to_tuning_buckets=False,
                 )
                 if not is_cache_hit:
                     # Initialize runner and tactic as None in case no valid tactic or runner is found
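In tuning mode, each generated profile is checked against the cache before profiling, so repeated tuning passes skip work already done. A toy version of that skip logic (all names illustrative):

```python
# Toy version of the tuning loop's cache check (names are illustrative).
profiled = {}  # (op, bucket) -> (runner_id, tactic)

def profile_bucket(op, bucket):
    return (0, f"tactic_for_{bucket}")  # stand-in for real profiling

def tune(op, buckets):
    for bucket in buckets:
        if (op, bucket) in profiled:  # apply_map_to_tuning_buckets=False lookup
            continue                  # cache hit: already profiled, skip
        profiled[(op, bucket)] = profile_bucket(op, bucket)

tune("my_op", (8, 64, 512))
tune("my_op", (8, 64, 512))  # second pass is a no-op thanks to the cache
```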
@@ -947,7 +947,7 @@ def _profile_runners(
                         runner,
                         profile.get_opt_shapes(),
                         tuning_config,
-                        use_tuning_mapping=True))
+                        apply_map_to_tuning_buckets=False))
 
             # Set time_measured to inf to signal that the tactic failed. This can happen when `get_valid_tactics` mistakenly returns invalid tactics
             # or some runtime error occurs during profiling.
@@ -964,7 +964,7 @@ def _profile_runners(
                 runners[best_runner_id],
                 profile.get_opt_shapes(),
                 tuning_config,
-                use_tuning_mapping=True)
+                apply_map_to_tuning_buckets=False)
 
         self._debug_logger(
             f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1141,8 +1141,7 @@ def _optimization_profiles(
             # Add the current input value as one of the opt values
             opt_shapes = set(opt_shapes)
             opt_shapes.add(
-                spec.map_to_tuning_buckets(
-                    base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+                base_profile.shapes[spec.input_idx][spec.dim_idx].val)
             opt_shapes = sorted(list(opt_shapes))
             opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'),)
             opt_shapes_max = {
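With the mapper call removed, the live value joins opt_shapes unmapped. A worked trace of the assembly above, with illustrative buckets:

```python
# Worked trace of the opt-shape assembly (bucket values illustrative).
opt_shapes = set((8, 64, 512))  # from gen_tuning_buckets
current_val = 100               # live value of the dynamic dimension
opt_shapes.add(current_val)     # added raw; no map_to_tuning_buckets call
opt_shapes = sorted(list(opt_shapes))
opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'),)
print(opt_shapes)      # [8, 64, 100, 512]
print(opt_shapes_max)  # (64, 100, 512, inf)
```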
@@ -1185,16 +1184,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
-        use_tuning_mapping: bool = False,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         """Find the nearest optimization profile for the given inputs.
         Users can define their own nearest-profile generation method to reduce the host overhead.
 
         Args:
             shapes: Tuple of input tensor shapes
             dynamic_tensor_specs: Specs for the dynamic tensor dimensions
-            use_tuning_mapping: If True, use map_to_tuning_buckets to store tuning cache.
-                If False, use map_to_runtime_buckets for runtime cache lookups.
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime
+                cache lookups. If False, use the raw bucket values for tuning cache storage.
 
         Return:
             Tuple: A tuple containing:
@@ -1204,12 +1203,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-
-            bucket_mapper = spec.map_to_tuning_buckets
-            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
-                bucket_mapper = spec.map_to_runtime_buckets
-            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
-                base_profile[spec.input_idx][spec.dim_idx])
+            # During runtime: apply map_to_tuning_buckets to map the input onto a bucket.
+            # During tuning: the profile already holds bucket values, so use them as-is.
+            if apply_map_to_tuning_buckets:
+                base_profile[spec.input_idx][
+                    spec.dim_idx] = spec.map_to_tuning_buckets(
+                        base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
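A concrete trace of the branch above, with an assumed power-of-two mapper:

```python
# Trace of the runtime branch in _find_nearest_profile (mapper and shapes
# are illustrative assumptions).
map_to_tuning_buckets = lambda x: max(8, 1 << (x.bit_length() - 1))

shapes = ((100, 4096),)              # one input; the dynamic dim is index 0
base_profile = list(list(s) for s in shapes)

apply_map_to_tuning_buckets = True   # runtime lookup path
if apply_map_to_tuning_buckets:
    base_profile[0][0] = map_to_tuning_buckets(base_profile[0][0])

print(base_profile)  # [[64, 4096]]: 100 snapped down to bucket 64
# With apply_map_to_tuning_buckets=False (tuning-time path), the profile's
# raw bucket value would pass through unchanged.
```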