44from functools import cached_property , wraps
55from itertools import chain
66from statistics import median
7- from typing import Any , Callable
7+ from typing import Any , Callable , Optional , Union
88from typing_extensions import Concatenate , ParamSpec , Self , TypeVar
99
1010import torch
11+ import torch .utils ._pytree as pytree
1112from torch ._dynamo .utils import counters , dynamo_timed
1213from torch ._inductor .config import use_experimental_benchmarker
1314
@@ -92,15 +93,45 @@ def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
9293
9394
9495class Benchmarker :
96+ """
97+ A device-agnostic benchmarking utility for measuring the runtime of
98+ inductor generated callables.
99+ """
100+
def __init__(self: Self) -> None:
    """Construct the benchmarker; this initializer performs no setup work."""
    return None
97103
104+ def infer_device (self , * fn_args : Any , ** fn_kwargs : Any ) -> torch .device :
105+ inferred_device : Optional [torch .device ] = None
106+ for arg_or_kwarg in chain (fn_args , fn_kwargs .values ()):
107+ # Some callables take nested structures as arguments so use the
108+ # flattened form to find any tensors
109+ for arg_or_kwarg_leaf in pytree .tree_leaves (arg_or_kwarg ):
110+ if not isinstance (arg_or_kwarg_leaf , torch .Tensor ):
111+ continue
112+ if inferred_device is None :
113+ inferred_device = arg_or_kwarg_leaf .device
114+ elif arg_or_kwarg_leaf .device != inferred_device :
115+ raise ValueError (
116+ "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
117+ )
118+
119+ if inferred_device is None :
120+ raise ValueError (
121+ "Can't safely infer the device type of `fn` with no device types"
122+ " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. "
123+ "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`."
124+ )
125+
126+ return inferred_device
127+
98128 @time_and_count
99129 def benchmark (
100130 self : Self ,
101131 fn : Callable [..., Any ],
102- fn_args : tuple [Any , ...],
103- fn_kwargs : dict [str , Any ],
132+ fn_args : Optional [tuple [Any , ...]] = None ,
133+ fn_kwargs : Optional [dict [str , Any ]] = None ,
134+ device : Optional [Union [str , torch .device ]] = None ,
104135 ** kwargs : Any ,
105136 ) -> float :
106137 """Benchmark `fn(*fn_args, **fn_kwargs)` and return the runtime, in milliseconds (the
@@ -109,35 +140,54 @@ def benchmark(
109140 device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
110141 `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
111142 if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
112- types are found.
143+ types are found. To bypass device inference, provide the device to the `device`
144+ parameter.
145+
146+ WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly.
147+ For example, if `fn` clears a mutable object, subsequent invocations of `fn` during
148+ benchmarking will fail. In such cases, `fn` should handle cloning its arguments internally.
149+ If device inference is required, `Benchmarker.infer_device` can be used prior to calling
150+ this method without any arguments for `fn_args` and `fn_kwargs`.
113151
114152 Arguments:
115153 - fn: The function to benchmark.
116154 - fn_args: The function's arguments.
117155 - fn_kwargs: The function's kwargs.
118156
119157 Keyword Arguments:
158+ - device: Which device to use for benchmarking. If not provided, the benchmarker attempts
159+ to infer the device from `fn_args` and `fn_kwargs`.
120160 - **kwargs: The benchmarking implementation's kwargs.
121161
122162 Returns:
123163 - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
124164 """
125- inferred_device = None
126- # pyrefly: ignore [bad-assignment]
127- for arg_or_kwarg in chain (fn_args , fn_kwargs .values ()):
128- if not isinstance (arg_or_kwarg , torch .Tensor ):
129- continue
130- if inferred_device is None :
131- inferred_device = arg_or_kwarg .device
132- elif arg_or_kwarg .device != inferred_device :
165+ inferred_device : Optional [torch .device ] = None
166+ if device is not None :
167+ inferred_device = (
168+ torch .device (device ) if isinstance (device , str ) else device
169+ )
170+ else :
171+ if fn_args is None and fn_kwargs is None :
133172 raise ValueError (
134- "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`! "
173+ "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided. "
135174 )
136- if inferred_device is None :
137- raise ValueError (
138- "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
139- )
140- _callable = lambda : fn (* fn_args , ** fn_kwargs ) # noqa: E731
175+
176+ fn_args = fn_args or tuple ()
177+ fn_kwargs = fn_kwargs or {}
178+ inferred_device = self .infer_device (* fn_args , ** fn_kwargs )
179+
180+ assert isinstance (inferred_device , torch .device )
181+
182+ fn_args = fn_args or tuple ()
183+ fn_kwargs = fn_kwargs or {}
184+
185+ # No need to wrap if the callable takes no arguments
186+ if len (fn_args ) == 0 and len (fn_kwargs ) == 0 :
187+ _callable = fn
188+ else :
189+ _callable = lambda : fn (* fn_args , ** fn_kwargs ) # noqa: E731
190+
141191 if inferred_device == torch .device ("cpu" ):
142192 return self .benchmark_cpu (_callable , ** kwargs )
143193 # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
0 commit comments