
Commit 71a446a

Re-land the compile cache (#169)
* Re-land the compile cache

  There were a few minor problems, but the major thing was that handle_torch_function_no_python_arg_parser wasn't exposed via TORCH_API.

* make autograd ops work
1 parent 0b319e5 · commit 71a446a
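For context (not part of the diff): the benchmark added below exercises functorch's pointwise_operator decorator, which, judging from the benchmark cases, turns a scalar Python expression into a single fused pointwise kernel with broadcasting, out=, and dtype-promotion support. A minimal usage sketch, assuming a functorch build that includes this commit:

    import torch
    from functorch import pointwise_operator

    @pointwise_operator
    def nnc_add(a, b):
        return a + b

    x = torch.randn(512, 512)
    y = torch.randn(512, 1)    # broadcasting, as in the "broadcast2" benchmark case
    z = nnc_add(x, y)          # out-of-place call
    nnc_add(x, y, out=z)       # out= variant, as exercised by test_out below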

File tree: 7 files changed, +1717 -0 lines changed


benchmarks/operator_authoring.py

Lines changed: 260 additions & 0 deletions
@@ -0,0 +1,260 @@
from functools import partial
import numpy as np
import pandas as pd
import timeit
import torch
from functorch import pointwise_operator

WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]
NUMBER = [100, 10, 1, 1]
REPEAT = 20


@pointwise_operator
def nnc_add(a, b):
    return a + b


@pointwise_operator
def nnc_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def eager_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def inplace_addnorm(a, b, mean, std, out):
    out = torch.add(a, b, out=out)
    torch.sub(out, mean, out=out)
    torch.div(out, std, out=out)
    return out


ts_addnorm = torch.jit.script(eager_addnorm)
ts_ip_addnorm = torch.jit.script(inplace_addnorm)


def maybe_synced(fn):
    if CUDA:
        synchronize = torch.cuda.synchronize
        synchronize()  # warmup

        def _fn():
            result = fn()
            synchronize()
            return result

        return _fn
    return fn


def benchmark_loop(setup):
    result = np.zeros((REPEAT, len(SIZES), 2), dtype=np.float64)
    for s, n in enumerate(SIZES):
        nnc, aten = setup(n)
        nnc = maybe_synced(nnc)
        aten = maybe_synced(aten)

        for r in range(result.shape[0]):
            result[r, s, 0] = timeit.timeit(nnc, number=NUMBER[s])
            result[r, s, 1] = timeit.timeit(aten, number=NUMBER[s])

    result = np.median(result, axis=0)
    assert result.shape == (len(SIZES), 2)
    result = result[:, 1] / result[:, 0]
    print(result)
    return result


def test(make_args, nnc=nnc_add, aten=torch.add):
    def setup(n):
        args = make_args(n)
        result_aten = aten(*args)
        result_nnc = nnc(*args)
        assert result_nnc.dtype == result_aten.dtype
        assert result_nnc.size() == result_aten.size()
        assert result_nnc.stride() == result_aten.stride()
        torch.testing.assert_allclose(result_aten, result_nnc)
        return (lambda: nnc(*args), lambda: aten(*args))

    return benchmark_loop(setup)


def test_inplace(make_args, nnc=nnc_add, aten=torch.add):
    def inplace_setup(n):
        a, b = make_args(n)
        result_aten = torch.clone(a)
        result_nnc = torch.clone(a)
        nnc(result_nnc, b, out=result_nnc)
        aten(result_aten, b, out=result_aten)
        torch.testing.assert_allclose(result_aten, result_nnc)
        return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a))

    return benchmark_loop(inplace_setup)


def test_out(make_args, out, nnc=nnc_add, aten=torch.add):
    def out_setup(n):
        args = make_args(n)
        result_aten = out(n)
        result_nnc = out(n)
        aten(*args, out=result_aten)
        nnc(*args, out=result_nnc)
        torch.testing.assert_allclose(result_aten, result_nnc)
        result = out(n)
        return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result))

    return benchmark_loop(out_setup)


def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
    def backwards_setup(n):
        args = make_args(n)
        (grad_var,) = [a for a in args if a.requires_grad]
        aten(*args).sum().backward()
        correct = grad_var.grad.clone()
        grad_var.grad.zero_()
        nnc(*args).sum().backward()
        torch.testing.assert_allclose(correct, grad_var.grad)
        return (
            lambda: nnc(*args).sum().backward(),
            lambda: aten(*args).sum().backward(),
        )

    return benchmark_loop(backwards_setup)


def main():
    torch.set_num_threads(1)  # TODO(jansel): add parallel support
    torch._C._jit_override_can_fuse_on_cpu(True)

    device = "cuda" if CUDA else "cpu"
    I = partial(torch.randint, 0, 100, device=device)
    R = partial(torch.randn, device=device)

    results = [
        ("add", test(lambda n: (R(n, n), R(n, n)))),
        ("broadcast1", test(lambda n: (R(n, n), R(1)))),
        ("broadcast2", test(lambda n: (R(n, n), R(n, 1)))),
        ("broadcast3", test(lambda n: (R(n, 1), R(1, n)))),
        ("inplace", test_inplace(lambda n: (R(n, n), R(n, 1)))),
        ("out=", test_out(lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n))),
        ("transposed1", test(lambda n: (R(n, n), R(n, n).transpose(0, 1)))),
        (
            "transposed2",
            test(lambda n: (R(n, n).transpose(0, 1), R(n, n).transpose(0, 1))),
        ),
        ("slice1", test(lambda n: (R(n + 1, n + 1, 2)[:n, :n, 0], R(n, n)))),
        ("slice2", test(lambda n: (R(n, n, 2)[:, :, 0], R(n, n, 2)[:, :, 0]))),
        (
            "strided out",
            test_out(
                lambda n: (R(n, n), R(n, n)),
                out=lambda n: R(n + 1, n + 1, 2)[:n, :n, 0],
            ),
        ),
        (
            "out convert",
            test_out(
                lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n, dtype=torch.float64)
            ),
        ),
        ("issue #57611 (n,32,32,2)", test(lambda n: (R(1, 32, 32, 2), R(n, 1, 1, 2)))),
        ("float+double", test(lambda n: (R(n, n), R(n, n, dtype=torch.float64)))),
        (
            "int+long",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "int+short",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int16))
            ),
        ),
        (
            "float+int",
            test(
                lambda n: (R([n, n], dtype=torch.float32), I([n, n], dtype=torch.int32))
            ),
        ),
        (
            "double+long",
            test(
                lambda n: (R([n, n], dtype=torch.float64), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "fused addnorm",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm (vs TS)",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
        (
            "fused addnorm out=",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=inplace_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm out= (vs TS)",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_ip_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm backward",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm backward (vs TS)",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
    ]

    df = pd.DataFrame(
        np.stack([r for n, r in results]),
        columns=[f"{n}x{n}".rjust(9) for n in SIZES],
        index=[n for n, r in results],
    )

    if WRITE_CSV:
        df.to_csv("../operator_authoring_results.csv")
        print("wrote ../operator_authoring_results.csv")

    print()
    print("Speedups over aten")
    pd.options.display.float_format = "{:.2f}x".format
    print(df)


if __name__ == "__main__":
    main()
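The commit message also notes "make autograd ops work"; test_backwards above checks exactly that path. A minimal sketch of the same pattern outside the benchmark (illustrative, not part of the diff, mirroring the "fused addnorm backward" case):

    import torch
    from functorch import pointwise_operator

    @pointwise_operator
    def nnc_addnorm(a, b, mean, std):
        return (a + b - mean) / std

    a = torch.randn(512, 512)
    b = torch.randn(512, 512, requires_grad=True)  # gradient flows back to this input
    mean = torch.randn(512, 512)
    std = torch.randn(512, 512)

    nnc_addnorm(a, b, mean, std).sum().backward()
    print(b.grad.shape)  # torch.Size([512, 512])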

functorch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@
 from ._src.python_key import wrap_key, PythonTensor, pythonkey_trace, make_fx, nnc_jit, make_nnc
 from ._src.nnc_compile import nnc_compile, get_ops
 from ._src.eager_compilation import compiled_function, compiled_module, tvm_compile, draw_joint_graph, default_partition
+from ._src.operator_authoring import pointwise_operator
+
 
 # Monkeypatching lol
 _old_cross_entropy = torch.nn.functional.cross_entropy
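With this export, the pointwise_operator used by the benchmark is available directly from the top-level package (from functorch import pointwise_operator) rather than only from the private _src.operator_authoring module.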

0 commit comments
