Skip to content

Commit 4620487

Browse files
authored
[ez][BE] reuse formatting functions from utils in eval (#61)
1 parent 1f35066 commit 4620487

File tree

2 files changed

+4
-52
lines changed

2 files changed

+4
-52
lines changed

BackendBench/eval.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,20 @@
66

77

88
from BackendBench.utils import uses_cuda_stream
9+
from BackendBench.utils import serialize_args
910

1011
logger = logging.getLogger(__name__)
1112

1213
EXC_MSG = """
1314
Exception raised for {op}:
1415
args: {args}
15-
kwargs: {kwargs}
1616
exc: {exc}
1717
"""
1818

1919

20-
def format_tensor(t):
21-
return f"{t.dtype}{list(t.shape)}"
22-
23-
24-
def format_args(args):
25-
return [format_tensor(arg) if isinstance(arg, torch.Tensor) else arg for arg in args]
26-
27-
28-
def format_kwargs(kwargs):
29-
return {k: format_tensor(v) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
30-
31-
3220
def format_exception(e, op, args, kwargs):
3321
op_name = getattr(op, "__name__", str(op))
34-
return EXC_MSG.format(op=op_name, args=format_args(args), kwargs=format_kwargs(kwargs), exc=e)
22+
return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e)
3523

3624

3725
def allclose(a, b):
@@ -60,9 +48,7 @@ def eval_correctness_test(op, impl, test):
6048
def eval_correctness(op, impl, tests):
6149
correct, total = 0, 0
6250
for test in tests:
63-
logging.debug(
64-
f"Testing {op.__name__} with args {format_args(test.args)} and kwargs {format_kwargs(test.kwargs)}"
65-
)
51+
logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
6652
if eval_correctness_test(op, impl, test):
6753
correct += 1
6854
total += 1
@@ -88,7 +74,7 @@ def eval_performance(op, impl, tests):
8874
test_times = []
8975
for test in tests:
9076
logging.debug(
91-
f"Benchmarking {op.__name__} with args {format_args(test.args)} and kwargs {format_kwargs(test.kwargs)}"
77+
f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
9278
)
9379
base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
9480
try:

test/test_eval.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
try:
55
import importlib.util
66
from BackendBench.eval import (
7-
format_tensor,
8-
format_args,
9-
format_kwargs,
107
format_exception,
118
allclose,
129
eval_correctness_test,
@@ -24,37 +21,6 @@
2421

2522

2623
class TestFormatFunctions:
27-
def test_format_tensor(self):
28-
tensor = torch.randn(2, 3, 4, dtype=torch.float32)
29-
formatted = format_tensor(tensor)
30-
assert formatted == "torch.float32[2, 3, 4]"
31-
32-
tensor_int = torch.randint(0, 10, (5, 5), dtype=torch.int64)
33-
formatted_int = format_tensor(tensor_int)
34-
assert formatted_int == "torch.int64[5, 5]"
35-
36-
def test_format_args(self):
37-
tensor1 = torch.randn(2, 3)
38-
tensor2 = torch.randn(3, 4)
39-
scalar = 2.5
40-
41-
args = [tensor1, scalar, tensor2]
42-
formatted = format_args(args)
43-
44-
assert len(formatted) == 3
45-
assert formatted[0] == "torch.float32[2, 3]"
46-
assert formatted[1] == 2.5
47-
assert formatted[2] == "torch.float32[3, 4]"
48-
49-
def test_format_kwargs(self):
50-
tensor = torch.randn(2, 3)
51-
kwargs = {"input": tensor, "dim": 1, "keepdim": True}
52-
53-
formatted = format_kwargs(kwargs)
54-
assert formatted["input"] == "torch.float32[2, 3]"
55-
assert formatted["dim"] == 1
56-
assert formatted["keepdim"] is True
57-
5824
def test_format_exception(self):
5925
op = torch.ops.aten.relu.default
6026
args = [torch.randn(2, 3)]

0 commit comments

Comments
 (0)