Skip to content

Commit f18528e

Browse files
authored
Reuse assert_close for benchmarks from triton.testing (#3161)
Part of #3160 Signed-off-by: Anatoly Myachev <[email protected]>
1 parent ec71843 commit f18528e

File tree

2 files changed

+11
-38
lines changed

2 files changed

+11
-38
lines changed
Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
import os
22

3-
from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD # type: ignore # noqa: F401
3+
from triton.testing import assert_close
4+
5+
from .benchmark_testing import do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
46

57
if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
68
os.environ["INJECT_PYTORCH"] = "True"
9+
10+
__all__ = [
11+
"assert_close",
12+
"do_bench",
13+
"perf_report",
14+
"Benchmark",
15+
"BENCHMARKING_METHOD",
16+
]

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -160,43 +160,6 @@ def extract_kernels(funcs):
160160
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
161161

162162

163-
def assert_close(x, y, atol=None, rtol=None, err_msg=""):
164-
import numpy as np
165-
import torch
166-
167-
# canonicalize arguments to be tensors
168-
if not isinstance(x, torch.Tensor):
169-
x = torch.tensor(x)
170-
if not isinstance(y, torch.Tensor):
171-
y = torch.tensor(y)
172-
# absolute tolerance
173-
if atol is None:
174-
atol = 1e-2
175-
atol = atol(x.dtype) if callable(atol) else atol
176-
# relative tolerance hook
177-
if rtol is None:
178-
rtol = 0.
179-
rtol = rtol(x.dtype) if callable(rtol) else rtol
180-
# we use numpy instead of pytorch
181-
# as it seems more memory efficient
182-
# pytorch tends to oom on large tensors
183-
if isinstance(x, torch.Tensor):
184-
if x.dtype == torch.bfloat16:
185-
x = x.float()
186-
x = x.cpu().detach().numpy()
187-
if isinstance(y, torch.Tensor):
188-
if y.dtype == torch.bfloat16:
189-
y = y.float()
190-
y = y.cpu().detach().numpy()
191-
# we handle size==1 case separately as we can
192-
# provide better error message there
193-
if x.size > 1 or y.size > 1:
194-
np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
195-
return
196-
if not np.allclose(x, y, atol=atol, rtol=rtol):
197-
raise AssertionError(f"{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})")
198-
199-
200163
def perf_report(benchmarks):
201164
"""
202165
Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.

0 commit comments

Comments
 (0)