@@ -28,12 +28,15 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to valid values during tuning.
+        map_to_runtime_buckets: A function to map dimensions to valid values during inference.
+            If None, use map_to_tuning_buckets.
     """
     input_idx: int
     dim_idx: int
     gen_tuning_buckets: Union[Tuple[int], Callable] = ()
     map_to_tuning_buckets: Callable = lambda x: x
+    map_to_runtime_buckets: Optional[Callable] = None


 @dataclass(slots=True, unsafe_hash=True)
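For context on how the new field composes with the existing one, here is a hypothetical spec that rounds up to power-of-two buckets while tuning but snaps down at runtime. This is a usage sketch only; the bucket values and both lambdas are invented for illustration, not taken from this change:

```python
# Hypothetical usage sketch (values and mappers invented for illustration).
spec = DynamicTensorSpec(
    input_idx=0,
    dim_idx=0,
    gen_tuning_buckets=(1, 2, 4, 8, 16),
    # While tuning: round a dimension up to the next power of two.
    map_to_tuning_buckets=lambda x: 1 << max(0, (x - 1).bit_length()),
    # At inference: round down, so lookups never exceed a profiled shape.
    map_to_runtime_buckets=lambda x: 1 << (x.bit_length() - 1),
)
```

If `map_to_runtime_buckets` is left as `None`, runtime lookups fall back to `map_to_tuning_buckets`, preserving the previous behavior.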
@@ -358,22 +361,27 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.

         Args:
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            use_tuning_mapping: If True, build the cache key with map_to_tuning_buckets.
+                If False, use map_to_runtime_buckets (or map_to_tuning_buckets if unset).

         Returns:
             A tuple containing:
             [is_cache_hit, runner_id, tactic, stored_profile]
             runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key :=
+                    self.get_cache_key(custom_op, r, input_shapes,
+                                       tuning_config,
+                                       use_tuning_mapping)) in self.cache:
                 # Return the current index in the runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
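A minimal sketch of the two intended call sites, assuming a profiling-cache object `cache` and the surrounding arguments (`custom_op`, `runners`, `shapes`, `tuning_config`) are in scope; only `use_tuning_mapping` differs between them:

```python
# Tuning path: search (and later store) entries under the tuning-bucket mapping.
is_cache_hit, *rest = cache.search_cache(
    custom_op, runners, shapes, tuning_config, use_tuning_mapping=True)

# Runtime path: the default False selects map_to_runtime_buckets, falling
# back to map_to_tuning_buckets for specs that leave it as None.
is_cache_hit, *rest = cache.search_cache(
    custom_op, runners, shapes, tuning_config)
```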
@@ -386,6 +394,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         return (
             custom_op,
@@ -396,6 +405,7 @@ def get_cache_key(
             tuning_config.dynamic_tensor_specs,
             tuning_config.constraint_specs,
             tuning_config.tune_max_num_tokens,
+            use_tuning_mapping,
         ),
     )

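Note that the flag itself becomes an element of the key tuple, so tuning-time and runtime entries are keyed separately even when both mappers bucket a shape to the same value. A quick illustration, with `op`, `runner`, `shapes`, and `cfg` assumed to be in scope:

```python
key_tuning = cache.get_cache_key(op, runner, shapes, cfg, use_tuning_mapping=True)
key_runtime = cache.get_cache_key(op, runner, shapes, cfg)  # defaults to False
# The keys differ at least in the embedded flag element.
assert key_tuning != key_runtime
```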
@@ -773,16 +783,23 @@ def choose_one(
             for p in profiles:
                 tensors = self._prepare_input_tensors(p, inputs)
                 is_cache_hit, *_ = self.profiling_cache.search_cache(
-                    custom_op, runners, p.get_opt_shapes(), tuning_config)
+                    custom_op,
+                    runners,
+                    p.get_opt_shapes(),
+                    tuning_config,
+                    use_tuning_mapping=True)
                 if not is_cache_hit:
                     # Initialize runner and tactic as None in case no valid tactic or runner is found
                     best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                         custom_op, runners, tensors, p, tuning_config, **kwargs)
                     if best_runner_id is not None:
                         # At least one valid (runner, tactic) pair is found
                         cache_key = self.profiling_cache.get_cache_key(
-                            custom_op, runners[best_runner_id], p.get_opt_shapes(),
-                            tuning_config)
+                            custom_op,
+                            runners[best_runner_id],
+                            p.get_opt_shapes(),
+                            tuning_config,
+                            use_tuning_mapping=True)

                         self._debug_logger(
                             f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -876,8 +893,11 @@ def _profile_runners(
                 # Record the failed profiling combinations
                 self.stats.failed_profiling_count[custom_op].add(
                     self.profiling_cache.get_cache_key(
-                        custom_op, runner, profile.get_opt_shapes(),
-                        tuning_config))
+                        custom_op,
+                        runner,
+                        profile.get_opt_shapes(),
+                        tuning_config,
+                        use_tuning_mapping=True))

                 # Set time_measured to inf to mark the failure of this tactic. This can happen when `get_valid_tactics` mistakenly returns invalid tactics
                 # or some runtime error occurs during profiling.
@@ -1088,13 +1108,16 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        use_tuning_mapping: bool = False,
     ) -> Tuple:
         """Find the nearest optimization profile for the given inputs.
         Users can define their own nearest-profile generation method to reduce host overhead.

         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            use_tuning_mapping: If True, use map_to_tuning_buckets when storing tuning cache entries.
+                If False, use map_to_runtime_buckets for runtime cache lookups.

         Returns:
             Tuple: A tuple containing:
@@ -1104,9 +1127,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)

         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+
+            bucket_mapper = spec.map_to_tuning_buckets
+            if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
+                bucket_mapper = spec.map_to_runtime_buckets
+            base_profile[spec.input_idx][spec.dim_idx] = bucket_mapper(
+                base_profile[spec.input_idx][spec.dim_idx])

             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
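The selection rule above can be read in isolation: tuning-time calls (`use_tuning_mapping=True`) always use `map_to_tuning_buckets`, while runtime calls prefer `map_to_runtime_buckets` and fall back when it is unset. A self-contained sketch of that rule, with an invented helper name and made-up mappers:

```python
# Standalone restatement of the mapper-selection rule (helper name invented).
def pick_bucket_mapper(spec: DynamicTensorSpec, use_tuning_mapping: bool):
    if not use_tuning_mapping and spec.map_to_runtime_buckets is not None:
        return spec.map_to_runtime_buckets
    return spec.map_to_tuning_buckets

spec = DynamicTensorSpec(0, 0, (8, 16), lambda x: max(8, x), None)
# With no runtime mapper defined, both modes resolve to the tuning mapper.
assert pick_bucket_mapper(spec, use_tuning_mapping=False)(4) == 8
```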