@@ -6,9 +6,9 @@
 import inspect
 from typing import Dict
 
-from ..testing import do_bench, do_bench_cudagraph
 from .jit import KernelInterface
 from .errors import OutOfResources
+from .driver import driver
 
 
 class Autotuner(KernelInterface):
@@ -24,9 +24,10 @@ def __init__(
         pre_hook=None,
         post_hook=None,
         prune_configs_by: Dict = None,
-        warmup=25,
-        rep=100,
+        warmup=None,
+        rep=None,
         use_cuda_graph=False,
+        do_bench=None,
     ):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
@@ -88,10 +89,36 @@ def _post_hook(args, exception):
         self.base_fn = fn
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
-        self.num_warmups = warmup
-        self.num_reps = rep
-        import torch
-        self.use_cuda_graph = use_cuda_graph and torch.cuda.is_available()
+
+        # If we got explicitly called via the old interface, raise a warning
+        # and proceed with the old behavior.
+        if warmup is not None or rep is not None or use_cuda_graph:
+            import warnings
+            warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
+                           "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning,
+                          stacklevel=1)
+            if use_cuda_graph:
+                from ..testing import do_bench_cudagraph
+                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                    kernel_call,
+                    rep=rep if rep is not None else 100,
+                    quantiles=quantiles,
+                )
+                return
+
+            import triton.testing
+            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+                kernel_call,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
+                quantiles=quantiles,
+            )
+            return
+
+        if do_bench is None:
+            self.do_bench = driver.active.get_benchmarker()
+        else:
+            self.do_bench = do_bench
 
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
@@ -125,9 +152,7 @@ def kernel_call():
             self.post_hook(args, exception=None)
 
         try:
-            if self.use_cuda_graph:
-                return do_bench_cudagraph(kernel_call, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
-            return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
         except (OutOfResources, CompileTimeAssertionFailure):
             return [float("inf"), float("inf"), float("inf")]
 
@@ -257,7 +282,7 @@ def __str__(self):
 
 
 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=25, rep=100, use_cuda_graph=False):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
 
@@ -305,10 +330,12 @@ def kernel(x_ptr, x_size, **META):
         'args': a list of arguments passed to the kernel.
         'exception': the exception raised by the kernel in case of a compilation or runtime error.
     :type post_hook: lambda args, exception
-    :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25.
+    :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
     :type warmup: int
-    :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100.
+    :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
     :type rep: int
+    :param do_bench: a benchmark function to measure the time of each run.
+    :type do_bench: lambda fn, quantiles
     """
 
     def decorator(fn):
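For reference, a minimal usage sketch of the new `do_bench` parameter. The kernel and the `my_bench` wrapper below are hypothetical names used only for illustration; the sketch assumes just what the diff shows, namely that the benchmarker is invoked as `do_bench(kernel_call, quantiles=...)` and that omitting it falls back to `driver.active.get_benchmarker()`.

import triton
import triton.language as tl
from triton.testing import do_bench as triton_do_bench


# Hypothetical custom benchmarker: it must accept the kernel thunk plus a
# `quantiles` keyword, matching the self.do_bench(kernel_call, quantiles=...)
# call in Autotuner._bench above.
def my_bench(kernel_call, quantiles):
    # Use a shorter measurement window than the old defaults (warmup=25, rep=100).
    return triton_do_bench(kernel_call, warmup=5, rep=20, quantiles=quantiles)


@triton.autotune(
    configs=[triton.Config({"BLOCK": 64}), triton.Config({"BLOCK": 128})],
    key=["n_elements"],
    do_bench=my_bench,  # replaces the deprecated warmup/rep/use_cuda_graph knobs
)
@triton.jit
def copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)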