import torch
import triton
+ from torch._inductor.runtime.benchmarking import benchmarker

NS_TO_MS = 1e-6
@@ -125,6 +126,42 @@ def _summarize_statistics(times, quantiles, return_mode):
    return getattr(torch, return_mode)(times).item()


+ def _do_bench_inductor(fn, warmup, rep, grad_to_none=None):
+     """Measure latency using the inductor benchmarker.
+
+     Args:
+         warmup: Target warmup time in milliseconds (matches triton.testing.do_bench).
+         rep: Target total measurement time in milliseconds (matches triton.testing.do_bench).
+         grad_to_none: Tensors whose gradients should be cleared before each measurement.
+
+     Returns:
+         List of measured times in milliseconds.
+     """
+     # First, estimate the runtime with a quick estimation pass.
+     estimate_ms = benchmarker.benchmark_gpu(fn, estimation_iters=5, benchmark_iters=10)
+
+     # Derive the number of iterations from the target rep time, the same way
+     # triton.testing.do_bench calculates its repeat count.
+     if estimate_ms == 0:
+         n_repeat = 1000  # Fallback if the function is too fast to measure.
+     else:
+         n_repeat = max(1, int(rep / estimate_ms))
+
+     # Collect one measurement per iteration, like triton.testing.do_bench
+     # with return_mode="all".
+     times_ms = []
+     for _ in range(n_repeat):
+         # Clear gradients BEFORE timing (like triton.testing.do_bench).
+         if grad_to_none is not None:
+             for x in grad_to_none:
+                 x.grad = None
+
+         # Measure only the function execution time.
+         ms_time = benchmarker.benchmark_gpu(fn)
+         times_ms.append(ms_time)
+
+     return times_ms
+
+
def _do_bench_cpu(
    fn, warmup, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"
):
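Note on the helper above: warmup is accepted for signature parity with triton.testing.do_bench but is never read in the body; benchmarker.benchmark_gpu presumably performs its own warmup internally. As a worked sketch of the repeat-count heuristic (the numbers here are hypothetical, not taken from the patch):

    rep = 100           # target total measurement time in ms (hypothetical)
    estimate_ms = 0.05  # estimated single-call latency in ms (hypothetical)
    n_repeat = max(1, int(rep / estimate_ms))  # -> 2000 measurements collected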
@@ -174,8 +211,13 @@ def do_bench_wrapper(
    device: str = "cuda",
    use_cuda_graphs: bool = False,
    bypass_fail: bool = False,
+     latency_measure_mode: str = "triton_do_bench",
) -> Optional[Latency]:
-     """Wrapper to triton's do_bench to gain latency."""
+     """Wrapper around triton's do_bench to measure latency.
+
+     Args:
+         latency_measure_mode: Either "triton_do_bench" (default) or "inductor_benchmarker".
+     """
    try:
        if device == "cpu":
            return Latency(
@@ -198,15 +240,25 @@ def do_bench_wrapper(
                )
            )
        else:
-             return Latency(
-                 times=triton.testing.do_bench(
-                     fn,
-                     warmup=warmup,
-                     rep=rep,
-                     return_mode="all",
-                     grad_to_none=grad_to_none,
+             if latency_measure_mode == "inductor_benchmarker":
+                 return Latency(
+                     times=_do_bench_inductor(
+                         fn,
+                         warmup=warmup,
+                         rep=rep,
+                         grad_to_none=grad_to_none,
+                     )
+                 )
+             else:  # default to triton do_bench
+                 return Latency(
+                     times=triton.testing.do_bench(
+                         fn,
+                         warmup=warmup,
+                         rep=rep,
+                         return_mode="all",
+                         grad_to_none=grad_to_none,
+                     )
                )
-             )
    except Exception as e:
        if not bypass_fail:
            raise e
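For reference, a minimal usage sketch of the new path (assumptions: a CUDA device is available, and the matmul workload, tensor shapes, and warmup/rep values are illustrative rather than from the patch):

    import torch

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    # warmup/rep follow the triton.testing.do_bench convention (milliseconds).
    times_ms = _do_bench_inductor(lambda: torch.mm(a, b), warmup=25, rep=100)
    print(f"median latency: {torch.tensor(times_ms).median().item():.4f} ms")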