diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 2ee8d29ccca..ace02766bc6 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1657,7 +1657,7 @@ def _( class AllReduceRunner(TunableRunner): tuning_config = TuningConfig( dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets(8192), + 0, 0, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2), ), constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ), distributed_tuning_strategy=DistributedTuningStrategy.MERGE,