Commit 13892fa: "Scorecard" benchmarks for pointwise op authoring (#193)
1 parent 93c8575

1 file changed: benchmarks/pointwise_scorecard.py (+223, -0 lines)
import time
import torch
import inspect
import itertools

from functorch import pointwise_operator

torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)

def rand(*shape):
    # Shift values into [1, 17) so log, div, and pow stay well-behaved.
    return torch.rand(*shape).mul(16).add(1)


# ------------------------------------------------------------------------------
# Shape test cases
# ------------------------------------------------------------------------------
def scalar():
    return (rand(1), rand(1))

def small():
    return (rand(32), rand(32))

def small_2d():
    return (rand(1, 32), rand(1, 32))

def small_broadcast():
    return (rand(4, 32), rand(32))

def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))

def medium_sliced():
    return (rand(32, 12, 64, 64)[..., ::2],
            rand(32, 12, 64, 64)[..., ::2])

def medium_transpose():
    return (rand(32, 12, 64, 64).transpose(-1, -2),
            rand(32, 12, 64, 64).transpose(-1, -2))

def medium_channels_last():
    return (rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
            rand(32, 3, 224, 224).to(memory_format=torch.channels_last))

def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))

def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last),
            rand(3, 1, 1))

def large():
    return (rand(8192, 8192), rand(8192, 8192))

def large_transpose():
    return (rand(8192, 8192).transpose(0, 1),
            rand(8192, 8192).transpose(0, 1))

def pathological_broadcast():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))

# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
def add(a, b):
    return a + b

def sub(a, b):
    return a - b

def mul(a, b):
    return a * b

def div(a, b):
    return a / b

def relu(a):
    return a.relu()

def sigmoid(a):
    return a.sigmoid()

def tanh(a):
    return a.tanh()

def log(a):
    return a.log()

def exp(a):
    return a.exp()

def pow(a):
    return a ** 2

def fma(a, b):
    return a * b + b

def hardswish(a):
    return a * (a + 3).clamp(0, 6) / 6

def native_hardswish(a):
    return torch._C._nn.hardswish(a)

def softplus(a):
    return (a * 1.0).exp().log1p() / 1.0

def mish(a):
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()

shapes = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    pathological_broadcast,
]

# (softplus and mish are defined above but not included in the sweep.)
operators = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    pow,
    fma,
    hardswish,
    native_hardswish,
]

# Uncomment to narrow the sweep while debugging:
# shapes = [large_transpose]
# operators = [add]
# shapes = [scalar]
# operators = [add]

# Correctness check: record the (operator, shape) pairs that pointwise_operator
# cannot handle so the timing loop below can skip them.
nope = set()
for shape, operator in itertools.product(shapes, operators):
    nargs = len(inspect.signature(operator).parameters)
    args = shape()[:nargs]
    # print(f"{operator.__name__} {shape.__name__}")

    try:
        if shape == medium_transpose:
            raise RuntimeError("pointwise_operator hangs on medium_transpose")
        pw_op = pointwise_operator(operator)
        torch.testing.assert_allclose(operator(*args), pw_op(*args))
    except Exception:
        print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
        nope.add((operator, shape))

    ts_op = torch.jit.script(operator)
    torch.testing.assert_allclose(operator(*args), ts_op(*args))

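# Aside (not in the original commit): pointwise_operator wraps a plain Python
# function, so it should also work as a decorator; a usage sketch, assuming the
# functorch API behaves as it does in the loop above (add_pw is a hypothetical
# name):
#
#     @pointwise_operator
#     def add_pw(a, b):
#         return a + b
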
def time_cpu(fn, args, iters):
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s

def time_cuda(fn, args, iters):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3  # elapsed_time is in ms; convert to seconds

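# Aside (not in the original commit): the shape helpers above allocate CPU
# tensors, so the CUDA path in benchmark() below is only exercised if inputs
# are created on the GPU instead; a hypothetical helper sketch:
#
#     def rand_cuda(*shape):
#         return torch.rand(*shape, device="cuda").mul(16).add(1)
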
def benchmark_with_timer(fn, args, timer):
    timer(fn, args, 3)  # warmup; also triggers JIT compilation/fusion
    calibration = timer(fn, args, 1)
    iters = int(1.0 / calibration)  # aim for roughly one second of measurement
    return timer(fn, args, iters) / iters

def benchmark(fn, args):
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)

def micros(s):
    return f"{s * 1e6:.1f}"

results = []
for shape, operator in itertools.product(shapes, operators):
    nargs = len(inspect.signature(operator).parameters)
    args = shape()[:nargs]

    result = benchmark(operator, args)
    print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
    try:
        if shape == medium_transpose:
            raise RuntimeError("pointwise_operator hangs on medium_transpose")
        if (operator, shape) in nope:
            raise RuntimeError(f"pointwise_operator failed correctness check on {operator.__name__}, {shape.__name__}")
        pw_op = pointwise_operator(operator)
        result = benchmark(pw_op, args)
        print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
    except Exception:
        # print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
        # nope.add((operator, shape))
        print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))

    ts_op = torch.jit.script(operator)
    result = benchmark(ts_op, args)
    print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))

# cpu
# parallel cpu
# cuda

# casts

# inplace?
