
Commit 3a38ec7

kundaMwiza authored and pytorchmergebot committed
[inductor] Expand use of generic benchmark function (pytorch#164938)
Use the more generic `Benchmarker.benchmark` function to allow benchmarking on other devices that support the required functionality; for example, prologue and epilogue fusion can now be benchmarked for the Triton CPU backend.

Pull Request resolved: pytorch#164938
Approved by: https://github.com/nmacchioni, https://github.com/eellison
1 parent 77b9399 commit 3a38ec7
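
In short, call sites that previously hard-coded `benchmarker.benchmark_gpu` (or branched to `benchmark_cpu` by hand) now go through one generic entry point. A minimal before/after sketch; the input tensor and kernel stand-in are illustrative, not from the diff:

    import torch
    from torch._inductor.runtime.benchmarking import benchmarker

    x = torch.randn(1024)   # illustrative input
    fn = torch.sin          # illustrative kernel stand-in

    # Before: the caller picked a device-specific entry point.
    # ms = benchmarker.benchmark_gpu(lambda: fn(x), rep=40)

    # After: one generic entry point; `device` selects the implementation,
    # and may be omitted when it can be inferred from tensor arguments.
    ms = benchmarker.benchmark(lambda: fn(x), device="cpu", rep=40)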

File tree

10 files changed (+122, -47 lines)

torch/_inductor/codegen/multi_kernel.py

Lines changed: 6 additions & 1 deletion

@@ -381,7 +381,12 @@ def inner():
            return inner

        return [
-            benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
+            benchmarker.benchmark(
+                wrap_fn(kernel, index),
+                # Currently the kernel type must be a CachingAutotuner
+                device=kernel.device_props.type,
+                rep=40,
+            )
            for index, kernel in enumerate(self.kernels)
        ]

torch/_inductor/codegen/subgraph.py

Lines changed: 5 additions & 5 deletions

@@ -111,11 +111,11 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float:
        bm_func([*sym_inputs, *args])
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args]))
-
-        if self.layout.device.type == "cpu":
-            return benchmarker.benchmark_cpu(lambda: bm_func([*sym_inputs, *args]))
-        else:
-            return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args]))
+        return benchmarker.benchmark(
+            # Shallow clone args since bm_func may clear args
+            lambda: bm_func([*sym_inputs, *args]),
+            device=benchmarker.infer_device(*sym_inputs, *args),
+        )

    def hash_key(self) -> str:
        return "-".join(

torch/_inductor/codegen/triton.py

Lines changed: 15 additions & 9 deletions

@@ -4838,7 +4838,7 @@ def codegen_kernel_benchmark(self, num_gb: Optional[float]) -> IndentedBuffer:

        result.writeline("args = get_args()")
        result.writeline(
-            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+            f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)"  # noqa: B950 line too long
        )
        result.writeline(f"num_gb = {num_gb}")
        result.writeline("gb_per_s = num_gb / (ms / 1e3)")

@@ -5797,18 +5797,21 @@ def load_cache():
            # skip benchmarking the kernel if there are register spills
            ms = float("inf")
        else:
+            device = V.graph.get_current_device_or_throw()
            # We have to clone the inplace updated arguments to avoid earlier calls
            # generating out of range indices for later calls.
-            ms = benchmarker.benchmark_gpu(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0])
+            ms = benchmarker.benchmark(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                device=device,
            )
            # overhead of cloning args gives bias for fusing the kernel
            # in the case of mutating/in-placeable second fusion
            # TODO - would be better as a hook in triton do_bench that reset
            # the input values between benchmarking
            if len(wrapped_jit_function.mutated_arg_names) > 0:
-                ms = ms - benchmarker.benchmark_gpu(
-                    lambda: wrapped_jit_function.clone_args(*args)
+                ms = ms - benchmarker.benchmark(
+                    lambda: wrapped_jit_function.clone_args(*args),
+                    device=str(device),
                )

        log.debug(

@@ -5977,13 +5980,16 @@ def store_cache():
            # skip benchmarking the kernel if there are register spills
            ms = ms_clone = float("inf")
        else:
+            device = V.graph.get_current_device_or_throw()
            # We have to clone the inplace updated arguments to avoid earlier calls
            # generating out of range indices for later calls.
-            ms = benchmarker.benchmark_gpu(
-                lambda: call(wrapped_jit_function.clone_args(*args)[0])
+            ms = benchmarker.benchmark(
+                lambda: call(wrapped_jit_function.clone_args(*args)[0]),
+                device=device,
            )
-            ms_clone = benchmarker.benchmark_gpu(
-                lambda: wrapped_jit_function.clone_args(*args)[0]
+            ms_clone = benchmarker.benchmark(
+                lambda: wrapped_jit_function.clone_args(*args)[0],
+                device=device,
            )

        log.debug(
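
The pattern in the two hunks above is worth noting: a kernel that mutates its inputs must be timed against fresh clones on every call, so the measurement includes the cloning overhead, which the code then benchmarks separately and subtracts. A distilled, standalone sketch of that technique (the in-place op and sizes are illustrative, not the commit's code):

    import torch
    from torch._inductor.runtime.benchmarking import benchmarker

    x = torch.randn(1 << 20)

    def clone_args():
        # stand-in for wrapped_jit_function.clone_args: fresh inputs per call
        return (x.clone(),)

    def call(buf):
        buf.add_(1.0)  # in-place mutation: repeated timing needs fresh clones

    ms_total = benchmarker.benchmark(lambda: call(*clone_args()), device="cpu")
    ms_clone = benchmarker.benchmark(lambda: clone_args(), device="cpu")
    ms = ms_total - ms_clone  # approximate kernel time net of cloning overhead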

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 2 additions & 1 deletion

@@ -896,6 +896,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:
        result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
+        device = V.graph.get_current_device_or_throw()
        index = V.graph.get_current_device_or_throw().index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")

@@ -930,7 +931,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer:

        result.writeline("args = get_args()")
        result.writeline(
-            "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)"
+            f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type}, rep=40)"
        )
        result.writeline(f"num_gb = {num_gb}")
        result.writeline("gb_per_s = num_gb / (ms / 1e3)")

torch/_inductor/ir.py

Lines changed: 3 additions & 1 deletion

@@ -5230,7 +5230,9 @@ def benchmark(self, *args: Any, out: torch.Tensor) -> float:
        }
        if config.profile_bandwidth_with_do_bench_using_profiling:
            return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)  # type: ignore[arg-type]
-        return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
+        return benchmarker.benchmark(
+            algo, args, {"out": out}, device=None, **benchmark_configs
+        )

    def call_name(self) -> str:
        raise NotImplementedError

torch/_inductor/runtime/benchmarking.py

Lines changed: 68 additions & 18 deletions

@@ -4,10 +4,11 @@
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Union
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.config import use_experimental_benchmarker

@@ -92,15 +93,45 @@ def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:


 class Benchmarker:
+    """
+    A device-agnostic benchmarking utility for measuring the runtime of
+    inductor generated callables.
+    """
+
     def __init__(self: Self) -> None:
         pass

+    def infer_device(self, *fn_args: Any, **fn_kwargs: Any) -> torch.device:
+        inferred_device: Optional[torch.device] = None
+        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+            # Some callables take nested structures as arguments so use the
+            # flattened form to find any tensors
+            for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg):
+                if not isinstance(arg_or_kwarg_leaf, torch.Tensor):
+                    continue
+                if inferred_device is None:
+                    inferred_device = arg_or_kwarg_leaf.device
+                elif arg_or_kwarg_leaf.device != inferred_device:
+                    raise ValueError(
+                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                    )
+
+        if inferred_device is None:
+            raise ValueError(
+                "Can't safely infer the device type of `fn` with no device types"
+                " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. "
+                "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`."
+            )
+
+        return inferred_device
+
     @time_and_count
     def benchmark(
         self: Self,
         fn: Callable[..., Any],
-        fn_args: tuple[Any, ...],
-        fn_kwargs: dict[str, Any],
+        fn_args: Optional[tuple[Any, ...]] = None,
+        fn_kwargs: Optional[dict[str, Any]] = None,
+        device: Optional[Union[str, torch.device]] = None,
         **kwargs: Any,
     ) -> float:
         """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the

@@ -109,35 +140,54 @@ def benchmark(
         device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
         `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
         if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
-        types are found.
+        types are found. To bypass device inference, provide the device to the `device`
+        parameter.
+
+        WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly.
+        For example, if `fn` clears a mutable object, subsequent invocations of `fn` during
+        benchmarking will fail. In such cases, `fn` should handle cloning its arguments internally.
+        If device inference is required, `Benchmarker.infer_device` can be used prior to calling
+        this method without any arguments for `fn_args` and `fn_kwargs`.

         Arguments:
         - fn: The function to benchmark.
         - fn_args: The function's arguments.
         - fn_kwargs: The function's kwargs.

         Keyword Arguments:
+        - device: Which device to use for benchmarking. If not provided, the device will be
+          inferred from `fn_args` and `fn_kwargs`.
         - **kwargs: The benchmarking implementation's kwargs.

         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        inferred_device = None
-        # pyrefly: ignore [bad-assignment]
-        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-            if not isinstance(arg_or_kwarg, torch.Tensor):
-                continue
-            if inferred_device is None:
-                inferred_device = arg_or_kwarg.device
-            elif arg_or_kwarg.device != inferred_device:
+        inferred_device: Optional[torch.device] = None
+        if device is not None:
+            inferred_device = (
+                torch.device(device) if isinstance(device, str) else device
+            )
+        else:
+            if fn_args is None and fn_kwargs is None:
                 raise ValueError(
-                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                    "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided."
                 )
-        if inferred_device is None:
-            raise ValueError(
-                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
-            )
-        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+
+            fn_args = fn_args or tuple()
+            fn_kwargs = fn_kwargs or {}
+            inferred_device = self.infer_device(*fn_args, **fn_kwargs)
+
+        assert isinstance(inferred_device, torch.device)
+
+        fn_args = fn_args or tuple()
+        fn_kwargs = fn_kwargs or {}
+
+        # No need to wrap if the callable takes no arguments
+        if len(fn_args) == 0 and len(fn_kwargs) == 0:
+            _callable = fn
+        else:
+            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731

         if inferred_device == torch.device("cpu"):
             return self.benchmark_cpu(_callable, **kwargs)
         # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
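
Taken together, the reworked `Benchmarker.benchmark` supports three calling styles. A minimal sketch under the semantics of the diff above (the tensors and the `add` helper are illustrative, not from the commit):

    import torch
    from torch._inductor.runtime.benchmarking import benchmarker

    x = torch.randn(1024)
    y = torch.randn(1024)

    def add(a, b):
        return a + b

    # 1. Explicit device with a zero-arg callable (no wrapping occurs internally).
    ms = benchmarker.benchmark(lambda: add(x, y), device="cpu")

    # 2. Device inferred from the tensor leaves of fn_args/fn_kwargs; the search
    #    is pytree-flattened, so nested containers like [(x, y)] are covered too.
    ms = benchmarker.benchmark(add, fn_args=(x, y))

    # 3. Standalone inference, useful when `fn` must manage its own arguments
    #    (e.g. it clears or mutates them between calls).
    device = benchmarker.infer_device(x, y)
    ms = benchmarker.benchmark(lambda: add(x, y), device=device)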

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 9 additions & 5 deletions

@@ -917,11 +917,15 @@ def kernel_call():

            return do_bench_using_profiling(kernel_call, warmup=10, rep=40)

-        if self.device_props.type == "cpu":
-            return benchmarker.benchmark_cpu(kernel_call)
-
-        return benchmarker.benchmark_gpu(
-            kernel_call, rep=40, is_vetted_benchmarking=True
+        benchmark_kwargs = (
+            {}
+            if self.device_props.type == "cpu"
+            else {"rep": 40, "is_vetted_benchmarking": True}
+        )
+        return benchmarker.benchmark(
+            fn=kernel_call,
+            device=self.device_props.type,
+            **benchmark_kwargs,  # type: ignore[arg-type]
        )

    def copy_args_to_cpu_if_needed(self, *args, **kwargs):

torch/_inductor/scheduler.py

Lines changed: 4 additions & 4 deletions

@@ -3552,8 +3552,8 @@ def speedup_by_fusion(
        device = node_list_1[0].get_device()
        assert device

-        # don't support benchmark fusion for CPU right now.
-        if device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device.type == "cpu" and config.cpu_backend != "triton":
            return True

        node_list_2 = node2.get_nodes()

@@ -5921,8 +5921,8 @@ def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
        subkernel_nodes = nodes
        device = subkernel_nodes[0].get_device()

-        # don't support benchmark fusion for CPU right now.
-        if device is None or device.type == "cpu":
+        # don't support benchmark fusion for CPU C++ backend right now.
+        if device is None or (device.type == "cpu" and config.cpu_backend != "triton"):
            return True

        from triton.compiler.errors import CompilationError
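
These scheduler gates are what the commit title is about: benchmark-driven fusion is no longer skipped on CPU when the Triton CPU backend is selected. A hedged configuration sketch (both knobs exist in `torch._inductor.config`; the combination and the compiled function are illustrative, and a Triton build with CPU support is assumed to be installed):

    import torch
    import torch._inductor.config as inductor_config

    inductor_config.cpu_backend = "triton"    # route CPU codegen through Triton
    inductor_config.benchmark_fusion = True   # let the scheduler benchmark fusions

    fn = torch.compile(lambda a, b: torch.relu(a @ b) + b)
    out = fn(torch.randn(64, 64), torch.randn(64, 64))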

torch/_inductor/select_algorithm.py

Lines changed: 4 additions & 2 deletions

@@ -2709,8 +2709,10 @@ def __call__(

        # Templates selected with input_gen_fns require specific input data to avoid IMA
        # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection
-        # TODO(jgong5): support multi-template on CPU
-        if input_gen_fns is not None or layout.device.type == "cpu":
+        # TODO(jgong5): support multi-template on CPU C++ backend
+        if input_gen_fns is not None or (
+            layout.device.type == "cpu" and config.cpu_backend != "triton"
+        ):
            return_multi_template = False

        # TODO - assert that we have not mutating kernels here

torch/_inductor/wrapper_benchmark.py

Lines changed: 6 additions & 1 deletion

@@ -93,6 +93,7 @@ def benchmark_all_kernels(
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
+        device_type = triton_kernel.device_props.type
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        num_in_out_ptrs = len(

@@ -137,7 +138,11 @@ def get_info_str(
                f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
            )
        else:
-            ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
+            ms = benchmarker.benchmark(
+                lambda: kernel_mod.call(args),
+                device=device_type,
+                rep=40,
+            )
        assert len(triton_kernel.launchers) == 1, (
            "Autotuner should have selected the best config"
        )
