Skip to content

Commit 33539d0

Browse files
authored
Clean up perf scorecard and add barplot generation script (#212)
1 parent 9ea21a2 commit 33539d0

File tree

2 files changed

+73
-48
lines changed

2 files changed

+73
-48
lines changed

benchmarks/pointwise_scorecard.py

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import time
23
import torch
34
import inspect
@@ -38,6 +39,12 @@ def medium_transpose():
3839
return (rand(32, 12, 64, 64).transpose(-1, -2),
3940
rand(32, 12, 64, 64).transpose(-1, -2))
4041

42+
def medium2():
    """Build the two-input argument tuple for the 32x3x224x224 case.

    Each element is produced by its own call to the file's ``rand`` helper
    (defined outside this block) with dimensions ``(32, 3, 224, 224)``.
    """
    dims = (32, 3, 224, 224)
    first = rand(*dims)
    second = rand(*dims)
    return (first, second)
44+
45+
def medium3d():
    """Build the two-input argument tuple for the 3-D 16x32x64 case.

    The file's ``rand`` helper (defined outside this block) is invoked once
    per element with dimensions ``(16, 32, 64)``.
    """
    dims = (16, 32, 64)
    return tuple(rand(*dims) for _ in range(2))
47+
4148
def medium_channels_last():
    """Build two 32x3x224x224 inputs converted to channels-last memory format.

    Each element is created by a separate ``rand`` call (helper defined
    outside this block) and then passed through
    ``.to(memory_format=torch.channels_last)``.
    """
    dims = (32, 3, 224, 224)
    lhs = rand(*dims).to(memory_format=torch.channels_last)
    rhs = rand(*dims).to(memory_format=torch.channels_last)
    return (lhs, rhs)
@@ -56,6 +63,10 @@ def large_transpose():
5663
return (rand(8192, 8192).transpose(0, 1),
5764
rand(8192, 8192).transpose(0, 1))
5865

66+
def large_channels_last():
    """Build two 32x32x256x256 inputs converted to channels-last memory format.

    Each element is created by a separate ``rand`` call (helper defined
    outside this block) and then passed through
    ``.to(memory_format=torch.channels_last)``.
    """
    dims = (32, 32, 256, 256)
    layout = torch.channels_last
    return tuple(rand(*dims).to(memory_format=layout) for _ in range(2))
69+
5970
def pathological_broadcast():
    """Build two inputs whose dimensions differ in every axis but the last.

    The first element uses dimensions ``(1, 32, 32, 2)`` and the second
    ``(1024, 1, 1, 2)``; both come from the file's ``rand`` helper (defined
    outside this block).
    """
    lhs = rand(1, 32, 32, 2)
    rhs = rand(1024, 1, 1, 2)
    return (lhs, rhs)
6172

@@ -89,14 +100,14 @@ def log(a):
89100
def exp(a):
    # Benchmarked operator: delegates to the input's own `exp` method.
    return a.exp()
91102

92-
def pow(a):
103+
def square(a):
    # Benchmarked operator: raises the input to the second power with `**`.
    return a ** 2
94105

95106
def fma(a, b):
    # Benchmarked two-input operator: a multiply followed by an add, written
    # with plain operators. Note that `b` is both the multiplier and the addend.
    return a * b + b
97108

98109
def hardswish(a):
    # Hand-written hardswish: a * clamp(a + 3, 0, 6) / 6.
    # Every constant is written as a float literal, and this is the only
    # return statement so that it is the expression actually executed.
    return a * (a + 3.0).clamp(0.0, 6.0) / 6.0
100111

101112
def native_hardswish(a):
    # Benchmarked operator: calls torch's built-in hardswish through the
    # underscore-prefixed `torch._C._nn` namespace instead of spelling out
    # the arithmetic by hand.
    return torch._C._nn.hardswish(a)
@@ -107,19 +118,55 @@ def softplus(a):
107118
def mish(a):
    # Computes a * tanh(log1p(exp(a * 1.0)) / 1.0).
    # NOTE(review): the explicit `* 1.0` and `/ 1.0` presumably stand in for a
    # softplus-style scale factor fixed at 1 -- confirm before simplifying.
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()
109120

121+
# ------------------------------------------------------------------------------
122+
# Helpers
123+
# ------------------------------------------------------------------------------
124+
def time_cpu(fn, args, iters):
    """Return the wall-clock seconds spent calling ``fn(*args)`` ``iters`` times.

    Args:
        fn: Callable to time.
        args: Sequence of positional arguments unpacked into each call.
        iters: Number of back-to-back calls to make.

    Returns:
        Total elapsed time for all calls, as measured by
        ``time.perf_counter``.
    """
    began = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    finished = time.perf_counter()
    elapsed = finished - began
    return elapsed
130+
131+
def time_cuda(fn, args, iters):
    """Return the seconds between two CUDA events bracketing ``iters`` calls.

    A timing-enabled event is recorded before and after the loop of
    ``fn(*args)`` calls. After ``torch.cuda.synchronize()`` the
    event-to-event time (torch reports it in milliseconds) is divided by
    1e3 to give seconds.

    Args:
        fn: Callable to time.
        args: Sequence of positional arguments unpacked into each call.
        iters: Number of back-to-back calls to make.
    """
    begin_evt = torch.cuda.Event(enable_timing=True)
    finish_evt = torch.cuda.Event(enable_timing=True)

    begin_evt.record()
    for _ in range(iters):
        fn(*args)
    finish_evt.record()

    # Both events must have completed before their timestamps can be read.
    torch.cuda.synchronize()
    elapsed_ms = begin_evt.elapsed_time(finish_evt)
    return elapsed_ms / 1e3
140+
141+
def benchmark_with_timer(fn, args, timer):
    """Estimate the average seconds per call of ``fn(*args)`` using ``timer``.

    Args:
        fn: Callable to measure.
        args: Sequence of positional arguments for ``fn``.
        timer: Callable ``timer(fn, args, iters)`` returning the total
            seconds taken by ``iters`` calls.

    Returns:
        Total measured time divided by the number of measured iterations.

    The iteration count is chosen so the measured run lasts roughly one
    second, based on a single calibration call. It is clamped to at least
    one iteration, so a call that takes longer than a second no longer ends
    in a division by zero.
    """
    # Used only when the calibration call reports a non-positive duration
    # (e.g. below the timer's resolution), where 1.0 / calibration is undefined.
    fallback_iters = 1_000_000

    timer(fn, args, 3)  # result discarded: these calls are not measured
    calibration = timer(fn, args, 1)
    if calibration > 0:
        iters = max(1, int(1.0 / calibration))
    else:
        iters = fallback_iters
    return timer(fn, args, iters) / iters
146+
147+
def benchmark(fn, args):
    """Benchmark ``fn(*args)`` with the timer matching the first argument's device.

    The device type is read from ``args[0].device.type``: ``"cpu"`` selects
    ``time_cpu``; any other value selects ``time_cuda``.

    Returns:
        Whatever ``benchmark_with_timer`` returns for the chosen timer.
    """
    device_kind = args[0].device.type
    if device_kind == "cpu":
        chosen_timer = time_cpu
    else:
        chosen_timer = time_cuda
    return benchmark_with_timer(fn, args, chosen_timer)
150+
151+
def micros(s):
    """Format a duration in seconds as microseconds with one decimal place."""
    microseconds = s * 1e6
    return format(microseconds, ".1f")
153+
110154
# Input builders swept by the scorecard, in run order. Each entry is a
# zero-argument function returning a tuple of operator inputs.
shapes = [
    scalar, small, small_2d, small_broadcast,
    medium, medium2, medium3d, medium_sliced, medium_transpose,
    medium_channels_last, medium_broadcast, medium_broadcast_channels_last,
    large, large_transpose, large_channels_last,
    pathological_broadcast,
]
125172

@@ -133,20 +180,16 @@ def mish(a):
133180
tanh,
134181
log,
135182
exp,
136-
pow,
183+
square,
137184
fma,
138185
hardswish,
139186
native_hardswish,
140187
]
141-
#shapes = [large_transpose]
142-
#operators = [add]
143-
#shapes = [scalar]
144-
#operators = [add]
188+
145189
nope = set()
146190
for shape, operator in itertools.product(shapes, operators):
147191
nargs = len(inspect.signature(operator).parameters)
148192
args = shape()[:nargs]
149-
#print(f"{operator.__name__} {shape.__name__}")
150193

151194
try:
152195
if shape == medium_transpose:
@@ -160,41 +203,13 @@ def mish(a):
160203
ts_op = torch.jit.script(operator)
161204
torch.testing.assert_allclose(operator(*args), ts_op(*args))
162205

163-
def time_cpu(fn, args, iters):
    """Return the wall-clock seconds spent calling ``fn(*args)`` ``iters`` times.

    Args:
        fn: Callable to time.
        args: Sequence of positional arguments unpacked into each call.
        iters: Number of back-to-back calls to make.

    Returns:
        Total elapsed time for all calls, as measured by
        ``time.perf_counter``.
    """
    began = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    finished = time.perf_counter()
    elapsed = finished - began
    return elapsed
169-
170-
def time_cuda(fn, args, iters):
    """Return the seconds between two CUDA events bracketing ``iters`` calls.

    A timing-enabled event is recorded before and after the loop of
    ``fn(*args)`` calls. After ``torch.cuda.synchronize()`` the
    event-to-event time (torch reports it in milliseconds) is divided by
    1e3 to give seconds.

    Args:
        fn: Callable to time.
        args: Sequence of positional arguments unpacked into each call.
        iters: Number of back-to-back calls to make.
    """
    begin_evt = torch.cuda.Event(enable_timing=True)
    finish_evt = torch.cuda.Event(enable_timing=True)

    begin_evt.record()
    for _ in range(iters):
        fn(*args)
    finish_evt.record()

    # Both events must have completed before their timestamps can be read.
    torch.cuda.synchronize()
    elapsed_ms = begin_evt.elapsed_time(finish_evt)
    return elapsed_ms / 1e3
179-
180-
def benchmark_with_timer(fn, args, timer):
    """Estimate the average seconds per call of ``fn(*args)`` using ``timer``.

    Args:
        fn: Callable to measure.
        args: Sequence of positional arguments for ``fn``.
        timer: Callable ``timer(fn, args, iters)`` returning the total
            seconds taken by ``iters`` calls.

    Returns:
        Total measured time divided by the number of measured iterations.

    The iteration count is chosen so the measured run lasts roughly one
    second, based on a single calibration call. It is clamped to at least
    one iteration, so a call that takes longer than a second no longer ends
    in a division by zero.
    """
    # Used only when the calibration call reports a non-positive duration
    # (e.g. below the timer's resolution), where 1.0 / calibration is undefined.
    fallback_iters = 1_000_000

    timer(fn, args, 3)  # result discarded: these calls are not measured
    calibration = timer(fn, args, 1)
    if calibration > 0:
        iters = max(1, int(1.0 / calibration))
    else:
        iters = fallback_iters
    return timer(fn, args, iters) / iters
185-
186-
def benchmark(fn, args):
    """Benchmark ``fn(*args)`` with the timer matching the first argument's device.

    The device type is read from ``args[0].device.type``: ``"cpu"`` selects
    ``time_cpu``; any other value selects ``time_cuda``.

    Returns:
        Whatever ``benchmark_with_timer`` returns for the chosen timer.
    """
    device_kind = args[0].device.type
    if device_kind == "cpu":
        chosen_timer = time_cpu
    else:
        chosen_timer = time_cuda
    return benchmark_with_timer(fn, args, chosen_timer)
189-
190-
def micros(s):
    """Format a duration in seconds as microseconds with one decimal place."""
    microseconds = s * 1e6
    return format(microseconds, ".1f")
192206

207+
print("fuser,device,operator,shape,time")
193208
results = []
194209
for shape, operator in itertools.product(shapes, operators):
195210
nargs = len(inspect.signature(operator).parameters)
196211
args = shape()[:nargs]
197-
212+
198213
result = benchmark(operator, args)
199214
print(",".join(["eager", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
200215
try:
@@ -206,18 +221,9 @@ def micros(s):
206221
result = benchmark(pw_op, args)
207222
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
208223
except Exception:
209-
#print(f"pointwise_operator failed on {operator.__name__}, {shape.__name__}")
210-
#nope.add((operator, shape))
211224
print(",".join(["pointwise", args[0].device.type, operator.__name__, shape.__name__, micros(float("nan"))]))
212225

213226
ts_op = torch.jit.script(operator)
214227
result = benchmark(ts_op, args)
215228
print(",".join(["fuser", args[0].device.type, operator.__name__, shape.__name__, micros(result)]))
216-
217-
# cpu
218-
# parallel cpu
219-
# cuda
220-
221-
# casts
222-
223-
# inplace?
229+
sys.stdout.flush()

benchmarks/process_scorecard.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pandas
2+
import matplotlib.pyplot as plt
3+
4+
# Read the scorecard CSV; the columns used below are "fuser", "operator",
# "shape" and "time".
df = pandas.read_csv("perf.csv")

ops = pandas.unique(df["operator"])
nops = len(ops)

# One row per (operator, shape), one column per fuser, values = time.
pivot_op_shape = df.pivot_table(values="time", index=["operator", "shape"], columns=["fuser"])

# Divide every fuser column by the "eager" column, row by row.
# NOTE: despite the name these are time ratios (time / eager time), so the
# eager column is 1.0 and a value below 1.0 means faster than eager.
pivot_speedups = pivot_op_shape.div(pivot_op_shape["eager"], axis=0)

plt.rcParams["figure.figsize"] = (20, 100)
# squeeze=False always yields a 2-D array of axes; without it a CSV holding a
# single operator would produce a bare Axes object and `axs[idx]` would fail.
fig, axs = plt.subplots(nops, squeeze=False)
axs = axs[:, 0]
plt.subplots_adjust(hspace=0.5)
for idx, op in enumerate(ops):
    # Rows for this operator only: index = shape, columns = fuser.
    op_speedups = pivot_speedups.loc[op]
    op_speedups.plot(ax=axs[idx], kind="bar", ylim=(0, 5), rot=45)
    axs[idx].set_title(op)
    axs[idx].set_xlabel("")
plt.savefig("scorecard.svg")

0 commit comments

Comments
 (0)