
Commit 3fa6ea9

Added support for output pytrees in AOTAutograd (#332)

* Added autominifier
* Added support for output pytrees
* Added some tests

1 parent: 7c4453d
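In user-facing terms, this change lets a function compiled through AOTAutograd return an arbitrary pytree (nested tuples, lists, and dicts of tensors) instead of only a flat tuple; the flat outputs of the compiled graph are reassembled into the original structure on the way out. Below is a minimal sketch of the behavior the new tests exercise. The `nop_compile` helper is an assumption standing in for the tests' `_nop_compile`; it simply returns the traced `fx.GraphModule` unchanged.

```python
import torch
from functorch.compile import compiled_function

def nop_compile(fx_module, example_inputs):
    # Hypothetical stand-in for the tests' _nop_compile: return the traced module as-is.
    return fx_module

def f(x, y):
    # The output is a dict, i.e. a pytree rather than a flat tuple of tensors.
    return {'a': x, 'b': x + y}

compiled_f = compiled_function(f, nop_compile, nop_compile)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
out = compiled_f(x, y)       # out is a dict again; PytreeThunk.unflatten restores the structure
out['b'].sum().backward()    # gradients flow through the compiled autograd.Function
print(sorted(out.keys()), x.grad.shape)
```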

File tree (3 files changed: +87 -35 lines)

* functorch/_src/aot_autograd.py
* functorch/_src/fx_minifier.py
* test/test_pythonkey.py

functorch/_src/aot_autograd.py
38 additions & 13 deletions

@@ -212,7 +212,7 @@ def joint_forward_backward(primals, tangents):
         out = fn(*primals)
         primals = [p for p in pytree.tree_flatten(primals)[0] if p.requires_grad]
         backward_out = []
-        if primals:
+        if primals:  # todo(chilli): Make it support it if not all outputs have gradients
             backward_out = torch.autograd.grad(out, primals, grad_outputs=tangents, allow_unused=True)
         return out, backward_out
     return joint_forward_backward

@@ -275,8 +275,6 @@ def forward(ctx, *flat_args):
                 compiled_bw = bw_compiler(bw_module, bw_args)
             fw_outs = normalize_as_list(compiled_fw(*flat_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
-            if num_outs == 1:
-                return fw_outs[0]
             return tuple(fw_outs[0:num_outs])

         @staticmethod

@@ -303,46 +301,73 @@ class _CompileCache(CompileCache):
 HAS_TREE = False
 compile_cache = None

+# Inspired by autodidax (thanks!)
+class PytreeThunk:
+    spec = None
+    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
+    is_simple = None  # if the output spec is a tuple/list, we won't bother unflattening it.
+    is_really_simple = None  # if the output spec is a LeafSpec
+
+    def set(self, spec):
+        assert self.spec is None or self.spec == spec
+        self.spec = spec
+        if type(self.spec) in [tuple, list] and all([isinstance(i, pytree.LeafSpec) for i in spec.children_specs]):
+            self.is_simple = True
+        if isinstance(self.spec, pytree.LeafSpec):
+            self.is_really_simple = True
+
+    def unflatten(self, x):
+        if self.is_really_simple:
+            return x[0]
+        if self.is_simple:
+            return x
+        return pytree.tree_unflatten(x, self.spec)

 def compiled_function(
     fn, fw_compiler, bw_compiler, partition_fn=default_partition, decompose=False, hasher_type="StaticShapeHasher"
 ):
     global compile_cache
     if compile_cache is None:
         compile_cache = CompileCache()
-    cached_fn = None
+    cached_res = None

     fn_id = id(fn)

     def returned_function(*args, **kwargs):
         global compile_cache
-        nonlocal cached_fn
+        nonlocal cached_res
         if HAS_TREE:
             flattened_args = tree.flatten((args, kwargs))
         else:
             flattened_args, _ = pytree.tree_flatten((args, kwargs))
         num_args = len(flattened_args)
         # Check if the fn is already compiled
-        cached_fn = compile_cache.at(fn_id, num_args, hasher_type, *flattened_args)
+        cached_res = compile_cache.at(fn_id, num_args, hasher_type, *flattened_args)

         # Compile the function and save it in the cache
-        if cached_fn is None:
+        if cached_res is None:
             # Compile a new function
             flattened_args, args_spec = pytree.tree_flatten((args, kwargs))
+            out_spec = PytreeThunk()
             def flat_fn(*args):
+                nonlocal out_spec
                 args, kwargs = pytree.tree_unflatten(args, args_spec)
-                return fn(*args, **kwargs)
-
-            cached_fn = create_compiled_function(
+                tree_out = fn(*args, **kwargs)
+                flat_out = pytree.tree_flatten(tree_out)
+                out_spec.set(flat_out[1])
+                return flat_out[0]
+            compiled_fn = create_compiled_function(
                 flat_fn, fw_compiler, bw_compiler, partition_fn, decompose
             ).apply
-
+            cached_res = (compiled_fn, out_spec)
             # Save the compiled_fn in the cache
             compile_cache.insert(
-                fn_id, num_args, hasher_type, cached_fn, *flattened_args
+                fn_id, num_args, hasher_type, cached_res, *flattened_args
            )

-        return cached_fn(*flattened_args)
+        cached_fn, out_spec = cached_res
+        out = cached_fn(*flattened_args)
+        return out_spec.unflatten(out)

     return returned_function
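The PytreeThunk above is essentially a cached TreeSpec: on the first call, flat_fn flattens whatever pytree the user function returns and records its spec, and unflatten later rebuilds that structure from the flat list of tensors the compiled function produces (with fast paths for plain tuples/lists and single-leaf outputs). A standalone sketch of that round trip, using torch.utils._pytree (assumed here to be the `pytree` module this file references):

```python
import torch
import torch.utils._pytree as pytree

tree_out = {'a': torch.ones(2), 'b': (torch.zeros(3), torch.ones(1))}

# What flat_fn does on the first call: flatten the pytree output and remember its spec.
flat_out, spec = pytree.tree_flatten(tree_out)   # leaf tensors + TreeSpec describing the nesting

# What PytreeThunk.unflatten does with the compiled function's flat outputs.
rebuilt = pytree.tree_unflatten(flat_out, spec)
assert isinstance(rebuilt, dict) and set(rebuilt) == {'a', 'b'}

# The "really simple" fast path: a lone tensor flattens to a single leaf with a LeafSpec,
# so unflatten can just return x[0] without calling tree_unflatten.
leaves, leaf_spec = pytree.tree_flatten(torch.randn(4))
assert isinstance(leaf_spec, pytree.LeafSpec) and len(leaves) == 1
```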

functorch/_src/fx_minifier.py
2 additions & 2 deletions

@@ -193,8 +193,8 @@ def delta_debugging(cur_graph: fx.Graph, cur_inps):
     print([i.shape for i in inps])
     return failing_fx, inps

-import subprocess
 def check_nvfuser_subprocess(f, inps):
+    import subprocess
     f.to_folder("temp")
     with open("_temp.py", 'w') as fil:
         fil.write(f'''

@@ -213,4 +213,4 @@ def check_nvfuser_subprocess(f, inps):
     except Exception as e:
         print(e)
         return True
-    return False
+    return False

test/test_pythonkey.py
47 additions & 20 deletions

@@ -19,6 +19,7 @@
     skipCUDAIfNoMagma, onlyCPU
 import types
 from functools import partial, wraps
+import copy

 import functorch
 from functorch import (

@@ -27,7 +28,7 @@
 )
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
-    partition_with_recompute_fwd_in_bwd, pythonkey_decompose, decomposition_table
+    partition_with_recompute_fwd_in_bwd, pythonkey_decompose, decomposition_table, aot_function, aot_module
 )

 from torch.testing._internal.common_device_type import ops, onlyCPU

@@ -257,42 +258,68 @@ def _nop_compile(x, _):

 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
-    [out.sum().backward(retain_graph=True) for out in outs]
-    grads = [inp.grad for inp in inps]
-    for inp in inps:
+    [out.sum().backward(retain_graph=True) for out in pytree.tree_flatten(outs)[0]]
+    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
+    for inp in pytree.tree_flatten(inps)[0]:
         inp.grad = None
     return outs, grads

-class TestEagerFusion(TestCase):
-    def test_single_output(self):
-        def f(a, b):
-            return a + b
-        compiled_f = compiled_function(f, _nop_compile, _nop_compile)
-        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
+
+
+class TestAOTAutograd(TestCase):
+    def verify_aot_autograd(self, f, inp):
+        if isinstance(f, nn.Module):
+            compiled_f = aot_module(f, _nop_compile, _nop_compile)
+        else:
+            compiled_f = aot_function(f, _nop_compile, _nop_compile)
         ref_out, ref_grad = _outs_and_grads(f, inp)
         test_out, test_grad = _outs_and_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)

+    def test_single_output(self):
+        def f(a, b):
+            return a + b
+        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
+        self.verify_aot_autograd(f, inp)
+
     def test_multi_output(self):
         def f(a, b):
             return a + b, a - b
-        compiled_f = compiled_function(f, _nop_compile, _nop_compile)
         inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
-        self.assertEqual(ref_out, test_out)
-        self.assertEqual(ref_grad, test_grad)
+        self.verify_aot_autograd(f, inp)

     def test_multi_output_list(self):
         def f(a, b):
             return [a + b, a - b]
-        compiled_f = compiled_function(f, _nop_compile, _nop_compile)
         inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
-        self.assertEqual(ref_out, test_out)
-        self.assertEqual(ref_grad, test_grad)
+        self.verify_aot_autograd(f, inp)
+
+    def test_multi_output_list(self):
+        def f(a, b):
+            return [a + b, a - b]
+        inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
+        self.verify_aot_autograd(f, inp)
+
+    def test_output_dict(self):
+        def f(x):
+            return {'a': x, 'b': x}
+        inp = [torch.randn(3, 3, requires_grad=True)]
+        self.verify_aot_autograd(f, inp)
+
+        def f(x, y):
+            return {'a': x, 'b': y + x}
+        inp = [torch.randn(3, requires_grad=True), torch.randn(3)]
+        self.verify_aot_autograd(f, inp)
+
+        def f(x):
+            new_d = {}
+            for k in x:
+                new_d[k] = x[k] * 2
+            return new_d
+        inp = [{'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}]
+        self.verify_aot_autograd(f, inp)
+

     def test_module(self):
         mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
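Unrolled outside the test harness, verify_aot_autograd amounts to roughly the following check for the dict-returning case. This is a sketch under the same assumptions as above: `nop_compile` is a hypothetical stand-in for `_nop_compile`, and `aot_function` is the entry point the updated tests import from `functorch.compile`.

```python
import torch
import torch.utils._pytree as pytree
from functorch.compile import aot_function

def nop_compile(fx_module, example_inputs):
    # Hypothetical stand-in for _nop_compile: hand back the traced module unchanged.
    return fx_module

def f(x):
    return {k: v * 2 for k, v in x.items()}

inp = {'a': torch.randn(3, requires_grad=True), 'b': torch.randn(3, requires_grad=True)}

def outs_and_grads(fn):
    # Mirrors the updated _outs_and_grads: flatten pytree outputs/inputs before touching .grad.
    outs = fn(inp)
    for out in pytree.tree_flatten(outs)[0]:
        out.sum().backward(retain_graph=True)
    grads = [t.grad for t in pytree.tree_flatten(inp)[0]]
    for t in pytree.tree_flatten(inp)[0]:
        t.grad = None
    return outs, grads

ref_out, ref_grad = outs_and_grads(f)
test_out, test_grad = outs_and_grads(aot_function(f, nop_compile, nop_compile))

flat = lambda t: pytree.tree_flatten(t)[0]
assert all(torch.allclose(r, t) for r, t in zip(flat(ref_out), flat(test_out)))
assert all(torch.allclose(r, t) for r, t in zip(ref_grad, test_grad))
```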
