diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 33ef41af8a9..41caf648adf 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -49,7 +49,7 @@ class DynamicTensorSpec:
         input_idx: The index of the input tensor.
         dim_idx: The index of the dimension to tune.
         gen_tuning_buckets: A tuple of values to try or a function generating values.
-        map_to_tuning_buckets: A function to map dimensions to valid values during inference.
+        map_to_tuning_buckets: A function to map dimensions to tuning buckets during inference.
     """
     input_idx: int
     dim_idx: int
@@ -81,7 +81,7 @@ class TuningConfig:
             should be tuned to optimize performance. Each spec defines:
             - Which input tensor dimension is dynamic
             - How to generate tuning values
-            - How to map dimensions to valid values during inference
+            - How to map dimensions to tuning values during inference
 
     Example:
         >>> config = TuningConfig(
@@ -392,6 +392,7 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.
 
@@ -399,6 +400,8 @@ def search_cache(
             custom_op (str): The name of the custom operation to be tuned
             runners (List[TunableRunner]): List of candidate implementations to profile
             profile (OptimizationProfile): Optimization profile
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Returns:
             A tuple containing:
@@ -406,8 +409,9 @@ def search_cache(
                 runner_id is the index in the current runners list
         """
         for idx, r in enumerate(runners):
-            if (cache_key := self.get_cache_key(custom_op, r, input_shapes,
-                                                tuning_config)) in self.cache:
+            if (cache_key := self.get_cache_key(
+                    custom_op, r, input_shapes, tuning_config,
+                    apply_map_to_tuning_buckets)) in self.cache:
                 # Return the current index in runners list, not the cached runner_id
                 cached_runner_id, tactic, min_time = self.cache[cache_key]
                 return True, idx, tactic, min_time
@@ -420,6 +424,7 @@ def get_cache_key(
         runner: TunableRunner,
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         return (
             custom_op,
@@ -430,6 +435,7 @@ def get_cache_key(
                 tuning_config.dynamic_tensor_specs,
                 tuning_config.constraint_specs,
                 tuning_config.tune_max_num_tokens,
+                apply_map_to_tuning_buckets,
             ),
         )
 
@@ -795,7 +801,11 @@ def choose_one(
         input_shapes = tuple(self._get_input_sizes(inputs))
 
         is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-            custom_op, runners, input_shapes, tuning_config)
+            custom_op,
+            runners,
+            input_shapes,
+            tuning_config,
+            apply_map_to_tuning_buckets=True)
 
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -841,7 +851,12 @@ def choose_one(
             for p in profiles:
                 tensors = self._prepare_input_tensors(p, inputs)
                 is_cache_hit, *_ = self.profiling_cache.search_cache(
-                    custom_op, runners, p.get_opt_shapes(), tuning_config)
+                    custom_op,
+                    runners,
+                    p.get_opt_shapes(),
+                    tuning_config,
+                    apply_map_to_tuning_buckets=False,
+                )
                 if not is_cache_hit:
                     # Initialize runner and tactic as None in case of no valid tactic or runners are found
                     best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
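# A simplified, hypothetical illustration (not the actual AutoTuner API) of the
# two cache-key modes introduced above: during tuning the raw bucket value is
# stored, while at runtime the user-supplied map_to_tuning_buckets is applied
# before the lookup.
from typing import Callable

def cache_key(num_tokens: int, map_to_tuning_buckets: Callable[[int], int],
              apply_map_to_tuning_buckets: bool) -> int:
    # apply_map_to_tuning_buckets=False -> tuning-time storage on the raw bucket
    # apply_map_to_tuning_buckets=True  -> runtime lookup on the mapped value
    if apply_map_to_tuning_buckets:
        return map_to_tuning_buckets(num_tokens)
    return num_tokens

mapper = lambda x: x // 4  # e.g. only a quarter of the buffer is real work
assert cache_key(32, mapper, apply_map_to_tuning_buckets=False) == 32  # stored while tuning
assert cache_key(32, mapper, apply_map_to_tuning_buckets=True) == 8    # looked up at runtime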
@@ -928,8 +943,11 @@ def _profile_runners(
                     # Record the failed profiling combinations
                     self.stats.failed_profiling_count[custom_op].add(
                         self.profiling_cache.get_cache_key(
-                            custom_op, runner, profile.get_opt_shapes(),
-                            tuning_config))
+                            custom_op,
+                            runner,
+                            profile.get_opt_shapes(),
+                            tuning_config,
+                            apply_map_to_tuning_buckets=False))
 
                     # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                     # or some runtime error occurs during profiling.
@@ -942,8 +960,11 @@ def _profile_runners(
         if best_runner_id is not None:
             # At least one valid (runner, tactic) pair is found
             cache_key = self.profiling_cache.get_cache_key(
-                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
-                tuning_config)
+                custom_op,
+                runners[best_runner_id],
+                profile.get_opt_shapes(),
+                tuning_config,
+                apply_map_to_tuning_buckets=False)
 
             self._debug_logger(
                 f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
@@ -1119,9 +1140,15 @@ def _optimization_profiles(
                     opt_shapes = spec.gen_tuning_buckets
                 # Add the current input value as one of the opt values
                 opt_shapes = set(opt_shapes)
-                opt_shapes.add(
-                    spec.map_to_tuning_buckets(
-                        base_profile.shapes[spec.input_idx][spec.dim_idx].val))
+                if tuning_config.tune_max_num_tokens is not None:
+                    opt_shapes.add(
+                        min(
+                            tuning_config.tune_max_num_tokens,
+                            base_profile.shapes[spec.input_idx][spec.dim_idx].val,
+                        ))
+                else:
+                    opt_shapes.add(
+                        base_profile.shapes[spec.input_idx][spec.dim_idx].val)
                 opt_shapes = sorted(list(opt_shapes))
                 opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), )
                 opt_shapes_max = {
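# A minimal standalone sketch (hypothetical helper, not the AutoTuner code path)
# of how the opt-shape set above is now built: the current dimension value is
# added unmapped, capped by tune_max_num_tokens only when it is set.
def collect_opt_shapes(buckets, current_val, tune_max_num_tokens=None):
    opt_shapes = set(buckets)
    if tune_max_num_tokens is not None:
        opt_shapes.add(min(tune_max_num_tokens, current_val))
    else:
        opt_shapes.add(current_val)
    return sorted(opt_shapes)

assert collect_opt_shapes((1, 2, 4, 8), 6) == [1, 2, 4, 6, 8]
assert collect_opt_shapes((1, 2, 4, 8), 6, tune_max_num_tokens=4) == [1, 2, 4, 8]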
@@ -1164,6 +1191,7 @@ def _find_nearest_profile(
         dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...],
         constraint_specs: Tuple[ConstraintSpec, ...],
         tune_max_num_tokens: int = None,
+        apply_map_to_tuning_buckets: bool = True,
     ) -> Tuple:
         """Find the nearest optimization profile for given inputs
         User can define their own nearest profile generation method to reduce the host overhead.
@@ -1171,6 +1199,8 @@ def _find_nearest_profile(
         Args:
             shapes: Tuple of input tensor shapes
             tuning_config: Tuning configuration
+            apply_map_to_tuning_buckets: If True, apply map_to_tuning_buckets for runtime cache lookups.
+                If False, use raw bucket values for tuning cache storage.
 
         Return:
             Tuple: A tuple containing:
@@ -1180,9 +1210,12 @@ def _find_nearest_profile(
         base_profile = list(list(shape) for shape in shapes)
 
         for spec in dynamic_tensor_specs:
-            base_profile[spec.input_idx][
-                spec.dim_idx] = spec.map_to_tuning_buckets(
-                    base_profile[spec.input_idx][spec.dim_idx])
+            # During runtime: apply map_to_tuning_buckets to map input to bucket
+            # During tuning: no mapper, use raw bucket value
+            if apply_map_to_tuning_buckets:
+                base_profile[spec.input_idx][
+                    spec.dim_idx] = spec.map_to_tuning_buckets(
+                        base_profile[spec.input_idx][spec.dim_idx])
 
             if tune_max_num_tokens is not None:
                 base_profile[spec.input_idx][spec.dim_idx] = min(
diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
index a8236d88fcf..7802ba4243a 100644
--- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -207,8 +207,8 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int],
         self.routing_method_type = routing_method_type
         self.do_finalize = do_finalize
 
-        FP4BlockScaleMoERunner.tuning_config = FP4BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = FP4BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
         # The unique_id is used by the autotuner to get the cache key, so we hash on members
         # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing type does not matter
@@ -269,17 +269,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
         MAX_PROFILE_BUCKET = 4096
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
-        round_rule = lambda x: min(last_positive_power_of_2(x),
-                                   MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
+
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -340,9 +345,9 @@ def _constrain_fp4_linear_layout(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -399,7 +404,7 @@ def fp4_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=FP4BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
         top_k=top_k,
         num_experts=num_experts,
         n_group=n_group,
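# Standalone sketch of the ep_size-aware round_rule used by the MoE runners in
# this file, assuming last_positive_power_of_2 rounds down to the nearest power
# of two (the real helper lives in tensorrt_llm).
MAX_PROFILE_BUCKET = 4096

def last_positive_power_of_2(x: int) -> int:
    value = 1
    while value * 2 <= x:
        value *= 2
    return value

def round_rule(x: int, ep_size: int) -> int:
    # Divide by ep_size: with expert parallelism each rank only sees roughly
    # 1/ep_size of the tokens packed into the worst-case sized input buffer.
    value = last_positive_power_of_2(x) // ep_size
    return min(max(1, value), MAX_PROFILE_BUCKET)

assert round_rule(4096, ep_size=1) == 4096
assert round_rule(4096, ep_size=8) == 512
assert round_rule(3, ep_size=8) == 1  # clamped to at least one token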
@@ -550,8 +555,8 @@ def __init__(
         self.routed_scaling_factor = routed_scaling_factor
         self.routing_method_type = routing_method_type
 
-        FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = FP8BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
         # The unique_id is used by the autotuner to get the cache key, so we hash on members
         # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
@@ -608,18 +613,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
-        MAX_PROFILE_BUCKET = 4096
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
-        round_rule = lambda x: min(last_positive_power_of_2(x),
-                                   MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
+
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -662,9 +671,9 @@ def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -696,19 +705,17 @@ def fp8_block_scale_moe_runner(
         topk_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
 
     tuner = AutoTuner.get()
-    kernel_runners = [
-        FP8BlockScaleMoERunner(
-            num_experts,
-            top_k,
-            n_group,
-            topk_group,
-            intermediate_size,
-            local_expert_offset,
-            local_num_experts,
-            routed_scaling_factor,
-            routing_method_type,
-        )
-    ]
+    kernel_runner = FP8BlockScaleMoERunner(
+        num_experts,
+        top_k,
+        n_group,
+        topk_group,
+        intermediate_size,
+        local_expert_offset,
+        local_num_experts,
+        routed_scaling_factor,
+        routing_method_type,
+    )
 
     # Prepare dummy topk tensors and hook for AutoTuner profiling
     routing_logits_for_tuner, topk_weights_for_tuner, topk_ids_for_tuner, tuning_config_with_hook = \
@@ -718,7 +725,7 @@ def fp8_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=FP8BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
         top_k=top_k,
         num_experts=num_experts,
         n_group=n_group,
@@ -742,7 +749,7 @@ def fp8_block_scale_moe_runner(
 
     kernel_runner, best_tactic = tuner.choose_one(
         "trtllm::fp8_block_scale_moe_runner",
-        kernel_runners,
+        [kernel_runner],
        tuning_config_with_hook,
        input_tensors_for_tuner,
    )
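# Hedged sketch of the caching behaviour the runner __init__ changes in this
# file rely on: get_tuning_config is an lru_cache'd classmethod keyed on
# ep_size = num_experts // local_num_experts, so runner instances with the same
# EP layout share a single TuningConfig object. _FakeRunner is a stand-in, not
# the real MoE runner class.
from functools import lru_cache

class _FakeRunner:

    @classmethod
    @lru_cache(maxsize=None)
    def get_tuning_config(cls, ep_size: int) -> tuple:
        return ("tuning-config-for-ep-size", ep_size)  # stand-in for TuningConfig

    def __init__(self, num_experts: int, local_num_experts: int):
        self.tuning_config = self.get_tuning_config(num_experts //
                                                    local_num_experts)

a = _FakeRunner(num_experts=256, local_num_experts=32)   # ep_size = 8
b = _FakeRunner(num_experts=256, local_num_experts=32)   # ep_size = 8
c = _FakeRunner(num_experts=256, local_num_experts=256)  # ep_size = 1
assert a.tuning_config is b.tuning_config
assert a.tuning_config is not c.tuning_config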
@@ -827,8 +834,8 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int],
         self.routing_method_type = routing_method_type
         self.act_type = act_type
 
-        MxE4m3MxE2m1BlockScaleMoERunner.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
         # The unique_id is used by the autotuner to get the cache key, so we hash on members
         # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
@@ -899,15 +906,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
+        MAX_PROFILE_BUCKET = 4096
 
-        m_values = get_last_power_of_2_num_tokens_buckets(4096)
-        round_rule = lambda x: min(last_positive_power_of_2(x), 4096)
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -961,9 +975,9 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -1028,7 +1042,7 @@ def mxe4m3_mxe2m1_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=MxE4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
         top_k=top_k,
         num_experts=num_experts,
         n_group=n_group,
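# Hedged sketch of how the m_values buckets and round_rule fit together,
# assuming get_last_power_of_2_num_tokens_buckets(n) yields the powers of two
# from n down to 1 (the real helpers live in tensorrt_llm). The mapped value
# always lands on one of the tuning buckets, so runtime lookups hit the cache.
def get_last_power_of_2_num_tokens_buckets(max_num_tokens: int):
    buckets = []
    m = max_num_tokens
    while m >= 1:
        buckets.append(m)
        m //= 2
    return tuple(buckets)

def last_positive_power_of_2(x: int) -> int:
    value = 1
    while value * 2 <= x:
        value *= 2
    return value

MAX_PROFILE_BUCKET = 4096
m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
ep_size = 4
for num_tokens in (1, 5, 300, 4096, 100000):
    mapped = min(max(1, last_positive_power_of_2(num_tokens) // ep_size),
                 MAX_PROFILE_BUCKET)
    assert mapped in m_values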
@@ -1118,8 +1132,8 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int],
         self.routing_method_type = routing_method_type
         self.act_type = act_type
 
-        E4m3MxE2m1BlockScaleMoERunner.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
         # The unique_id is used by the autotuner to get the cache key, so we hash on members
         # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
@@ -1190,15 +1204,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
+        MAX_PROFILE_BUCKET = 4096
 
-        m_values = get_last_power_of_2_num_tokens_buckets(4096)
-        round_rule = lambda x: min(last_positive_power_of_2(x), 4096)
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1232,9 +1253,9 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -1300,7 +1321,7 @@ def e4m3_mxe2m1_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=E4m3MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
         top_k=top_k,
         num_experts=num_experts,
         n_group=n_group,
@@ -1389,8 +1410,8 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int],
         self.routing_method_type = routing_method_type
         self.act_type = act_type
 
-        Bf16MxE2m1BlockScaleMoERunner.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
         # The unique_id is used by the autotuner to get the cache key, so we hash on members
         # that influence tactic validity here. e.g. we are tuning FC1 and FC2 so the routing
@@ -1459,15 +1480,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
+        MAX_PROFILE_BUCKET = 4096
 
-        m_values = get_last_power_of_2_num_tokens_buckets(4096)
-        round_rule = lambda x: min(last_positive_power_of_2(x), 4096)
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1501,9 +1529,9 @@ def _constrain_routing_logits(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -1566,7 +1594,7 @@ def bf16_mxe2m1_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=Bf16MxE2m1BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
        top_k=top_k,
        num_experts=num_experts,
        n_group=n_group,
@@ -1650,8 +1678,8 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int],
         self.do_finalize = do_finalize
         self.act_type = act_type
 
-        FP8FP4BlockScaleMoERunner.tuning_config = FP8FP4BlockScaleMoERunner.get_tuning_config(
-        )
+        self.tuning_config = FP8FP4BlockScaleMoERunner.get_tuning_config(
+            self.num_experts // self.local_num_experts)
 
     def unique_id(self):
         return (
@@ -1713,17 +1741,22 @@ def get_valid_tactics(self, inputs: List[torch.Tensor],
         return tactics
 
     @classmethod
-    def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(cls,
+                                 ep_size: int) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
         MAX_PROFILE_BUCKET = 4096
 
         m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
-        round_rule = lambda x: min(last_positive_power_of_2(x),
-                                   MAX_PROFILE_BUCKET)
 
-        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
-                                   round_rule), )
+        def round_rule(x: int) -> int:
+            value = last_positive_power_of_2(x) // ep_size
+            return min(max(1, value), MAX_PROFILE_BUCKET)
+
+        specs = (DynamicTensorSpec(HIDDEN_STATES_IDX,
+                                   TUNED_DIM,
+                                   m_values,
+                                   map_to_tuning_buckets=round_rule), )
 
         return specs
 
@@ -1759,9 +1792,9 @@ def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
 
     @classmethod
     @lru_cache(maxsize=None)
-    def get_tuning_config(cls) -> TuningConfig:
+    def get_tuning_config(cls, ep_size: int) -> TuningConfig:
 
-        dynamic_tensor_specs = cls.get_dynamic_tensor_specs()
+        dynamic_tensor_specs = cls.get_dynamic_tensor_specs(ep_size)
         constraint_specs = cls.get_constraint_specs()
 
         tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
@@ -1820,7 +1853,7 @@ def fp8_fp4_block_scale_moe_runner(
         topk_ids=topk_ids,
         hidden_states=hidden_states,
         routing_logits=routing_logits,
-        base_tuning_config=FP8FP4BlockScaleMoERunner.get_tuning_config(),
+        base_tuning_config=kernel_runner.tuning_config,
        top_k=top_k,
        num_experts=num_experts,
        n_group=n_group,
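# Sketch of the arithmetic exercised by the new test below: with sparsity=0.25
# only a quarter of the runtime buffer is real work, so the runtime mapper
# shrinks the buffer size to the matching tuning bucket before the cache lookup.
sparsity = 0.25
map_to_tuning_buckets = lambda buffer_tokens: int(buffer_tokens * sparsity)

for buffer_tokens, expected_bucket in [(4, 1), (8, 2), (16, 4), (32, 8),
                                       (64, 16), (128, 32)]:
    assert map_to_tuning_buckets(buffer_tokens) == expected_bucket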
diff --git a/tests/unittest/_torch/misc/test_autotuner.py b/tests/unittest/_torch/misc/test_autotuner.py
index a6116d544f2..fcbf6550bdf 100644
--- a/tests/unittest/_torch/misc/test_autotuner.py
+++ b/tests/unittest/_torch/misc/test_autotuner.py
@@ -38,18 +38,21 @@ def test_multi_dynamic_dims():
     tuner = autotuner.AutoTuner()
 
     x = torch.rand([5, 1024])
-    w = torch.rand([7, 19])
+    w = torch.rand([7, 9])
     dynamic_tensor_specs = (
         DynamicTensorSpec(0, 0, [1, 3, 5]),
         DynamicTensorSpec(0, 1, [16, 24, 1024]),
-        DynamicTensorSpec(1, 1, [3, 7, 9], lambda x: x // 2),
+        # map_to_tuning_buckets is only applied at runtime, not during tuning
+        DynamicTensorSpec(1,
+                          1, [3, 7, 9],
+                          map_to_tuning_buckets=lambda x: x // 2),
     )
     profiles = tuner._optimization_profiles(
         tuning_config=TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs),
         inputs=[x, w])
     # choice(0, 0) * choice(0, 1) * choice(1, 1)
-    # 3 * 3 * 3 = 27, because 19 is mapped to 9 and already inside the bucket
+    # 3 * 3 * 3 = 27, input value 9 is already inside the bucket
     assert len(profiles) == 27
     sample_0 = OptimizationProfile(shapes=[[
         DynamicDim(min=1, opt=1, max=3),
@@ -171,6 +174,75 @@ def test_autotuner_cache_basic():
             m //= 2
 
 
+def test_bucket_mapping():
+    """Test that map_to_tuning_buckets correctly maps runtime sizes to tuning buckets.
+
+    This test demonstrates the single mapper approach:
+    - During tuning: NO mapper is applied, raw bucket values are used as cache keys
+    - During runtime: map_to_tuning_buckets is applied to map buffer size to actual work size
+
+    With sparsity=0.25, the buffer contains 25% actual work:
+    - Tuning stores buckets: 1, 2, 4, 8, 16, 32 as raw cache keys
+    - Runtime buffer 4 -> maps to bucket int(4 * 0.25) = 1
+    - Runtime buffer 16 -> maps to bucket int(16 * 0.25) = 4
+
+    In MoE EP, the input buffer is allocated for worst-case but sparsely filled.
+    Using map_to_tuning_buckets allows us to map buffer size to actual work size at runtime.
+    """
+    w = torch.randn(64, 128)
+    tuner = AutoTuner.get()
+    tuner.clear_cache()
+
+    # Sparsity indicates the fraction of buffer containing valid work
+    sparsity = 0.25
+
+    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
+        input_idx=0,
+        dim_idx=0,
+        gen_tuning_buckets=get_power_of_2_num_tokens_buckets(M),
+        map_to_tuning_buckets=lambda x: int(x * sparsity)), ), )
+
+    with autotune():
+        tuner.choose_one("test_bucket_mapping", [GemmRunner()], tuning_config,
+                         [torch.randn(1, 64), w])
+
+    # Verify cache entries use raw tuning bucket values
+    cache_entries = tuner.profiling_cache.get_specific_custom_op(
+        "test_bucket_mapping")
+
+    # Each tuning bucket should produce exactly one cache entry keyed on its raw value
+    assert len(cache_entries) == len(tuning_config.dynamic_tensor_specs[0].gen_tuning_buckets), \
+        f"Expected {len(tuning_config.dynamic_tensor_specs[0].gen_tuning_buckets)} cache entries, got {len(cache_entries)}"
+
+    # Test runtime mapping: buffer size is mapped via map_to_tuning_buckets
+    # to find the correct tuning bucket based on actual work size
+    test_cases = [
+        # size 4 -> valid work size (4*0.25)=1, tactic 0 since 1 <= M//2
+        (4, 1, 0),
+        # size 8 -> valid work size (8*0.25)=2, tactic 0 since 2 <= M//2
+        (8, 2, 0),
+        # size 16 -> valid work size (16*0.25)=4, tactic 0 since 4 <= M//2
+        (16, 4, 0),
+        # size 32 -> valid work size (32*0.25)=8, tactic 0 since 8 <= M//2
+        (32, 8, 0),
+        # size 64 -> valid work size (64*0.25)=16, tactic 0 since 16 <= M//2
+        (64, 16, 0),
+        # size 128 -> valid work size (128*0.25)=32, tactic 1 since 32 > M//2
+        (128, 32, 1),
+        # size 256 -> valid work size (256*0.25)=64, tactic -1 since 64 > M
+        (256, 64, -1),
+    ]
+
+    for buffer_size, valid_size, expected_tactic in test_cases:
+        # Verify cache lookup succeeds with the mapped bucket
+        x = torch.randn(buffer_size, 64)
+        runner, tactic = tuner.choose_one("test_bucket_mapping", [GemmRunner()],
+                                          tuning_config, [x, w])
+        assert (
+            tactic == expected_tactic
+        ), f"buffer size={buffer_size} -> valid work size={valid_size}, expected tactic {expected_tactic} but got {tactic}"
+
+
 def test_autotuner_try_block():
 
     class PartialCrashedRunner(TunableRunner):