5
5
6
6
import torch
import triton

from torch._inductor.runtime.benchmarking import benchmarker
8
9
9
10
NS_TO_MS = 1e-6
10
11
@@ -125,6 +126,41 @@ def _summarize_statistics(times, quantiles, return_mode):
125
126
return getattr (torch , return_mode )(times ).item ()
126
127
127
128
129
def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
    """Measure latency of ``fn`` using the inductor benchmarker.

    Args:
        fn: Zero-argument callable to benchmark.
        warmup: Target warmup time in milliseconds (kept for signature
            compatibility with ``triton.testing.do_bench``).
            NOTE(review): this value is currently unused — the inductor
            benchmarker appears to handle warmup internally; confirm
            before relying on a custom warmup budget here.
        rep: Target total measurement time in milliseconds (matches
            ``triton.testing.do_bench``).
        grad_to_none: Optional iterable of tensors whose ``.grad`` is
            cleared before each measurement, so backward-pass benchmarks
            do not accumulate gradients across repeats.

    Returns:
        List of measured times in milliseconds, one entry per repeat
        (analogous to ``triton.testing.do_bench(..., return_mode="all")``).
    """
    # Cheap pilot measurement used only to size the repeat loop below.
    estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)

    # Derive the repeat count from the target total measurement time,
    # mirroring how triton.testing.do_bench sizes its measurement loop.
    if estimate_ms == 0:
        n_repeat = 1000  # Function too fast to estimate; fall back to a default.
    else:
        n_repeat = max(1, int(rep / estimate_ms))

    # Collect one timing per repeat, like do_bench with return_mode="all".
    times_ms = []
    for _ in range(n_repeat):
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None

        # Measure only the function execution time (gradient clearing above
        # is deliberately excluded from the timed region).
        ms_time = benchmarker.benchmark_gpu(fn)
        times_ms.append(ms_time)

    return times_ms
162
+
163
+
128
164
def _do_bench_cpu (
129
165
fn , warmup , rep = 20 , grad_to_none = None , quantiles = None , return_mode = "mean"
130
166
):
@@ -174,8 +210,13 @@ def do_bench_wrapper(
174
210
device : str = "cuda" ,
175
211
use_cuda_graphs : bool = False ,
176
212
bypass_fail : bool = False ,
213
+ latency_measure_mode : str = "triton_do_bench" ,
177
214
) -> Optional [Latency ]:
178
- """Wrapper to triton's do_bench to gain latency."""
215
+ """Wrapper to triton's do_bench to gain latency.
216
+
217
+ Args:
218
+ latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker"
219
+ """
179
220
try :
180
221
if device == "cpu" :
181
222
return Latency (
@@ -198,15 +239,25 @@ def do_bench_wrapper(
198
239
)
199
240
)
200
241
else :
201
- return Latency (
202
- times = triton .testing .do_bench (
203
- fn ,
204
- warmup = warmup ,
205
- rep = rep ,
206
- return_mode = "all" ,
207
- grad_to_none = grad_to_none ,
242
+ if latency_measure_mode == "inductor_benchmarker" :
243
+ return Latency (
244
+ times = _do_bench_inductor (
245
+ fn ,
246
+ warmup = warmup ,
247
+ rep = rep ,
248
+ grad_to_none = grad_to_none ,
249
+ )
250
+ )
251
+ else : # default to triton do_bench
252
+ return Latency (
253
+ times = triton .testing .do_bench (
254
+ fn ,
255
+ warmup = warmup ,
256
+ rep = rep ,
257
+ return_mode = "all" ,
258
+ grad_to_none = grad_to_none ,
259
+ )
208
260
)
209
- )
210
261
except Exception as e :
211
262
if not bypass_fail :
212
263
raise e
0 commit comments