Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit ee1ece7

Browse files
committed
cleaned up some eager compilation stuff
1 parent 1fb0a0f commit ee1ece7

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

functorch/_src/eager_compilation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,10 @@ def forward(ctx, *flat_args):
128128

129129
bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
130130
compiled_bw = bw_compiler(bw_module, bw_args)
131-
132131
fw_outs = compiled_fw(*flat_args)
133132
if not isinstance(fw_outs, list):
134133
fw_outs = [fw_outs]
135-
ctx.activations = fw_outs[num_outs:]
134+
ctx.save_for_backward(*fw_outs[num_outs:])
136135
if num_outs == 1:
137136
return fw_outs[0]
138137
return tuple(fw_outs[0:num_outs])
@@ -141,24 +140,28 @@ def forward(ctx, *flat_args):
141140
def backward(ctx, *flat_args):
142141
# hmm... this doesn't feel right. todo
143142
contiguous_args = [t.contiguous() for t in flat_args]
144-
out = compiled_bw(*ctx.activations, *contiguous_args)
143+
out = compiled_bw(*ctx.saved_tensors, *contiguous_args)
145144
if not isinstance(out, list):
146145
out = [out]
147146
out_iter = iter(out)
148147
grad_out = [next(out_iter) if p else None for p in ctx.needs_input_grad]
149148
return tuple(grad_out)
150-
149+
151150
return CompiledFunction
152151

153152

153+
# using this reduces the overhead by about 50%
154+
# import tree
154155
def compiled_function(fn, fw_compiler, bw_compiler, partition_fn=default_partition):
155156
saved_fn = None
156157

157158
def returned_function(*args, **kwargs):
158159
nonlocal saved_fn
159-
flattened_args, args_spec = pytree.tree_flatten((args, kwargs))
160+
# flattened_args = tree.flatten((args, kwargs))
161+
flattened_args, _ = pytree.tree_flatten((args, kwargs))
160162

161163
if saved_fn is None:
164+
flattened_args, args_spec = pytree.tree_flatten((args, kwargs))
162165
def flat_fn(*args):
163166
args, kwargs = pytree.tree_unflatten(args, args_spec)
164167
return fn(*args, **kwargs)

test/test_pythonkey.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _nop_compile(x, _):
222222

223223
def _outs_and_grads(fn, inps):
224224
outs = fn(*inps)
225-
[out.sum().backward() for out in outs]
225+
[out.sum().backward(retain_graph=True) for out in outs]
226226
grads = [inp.grad for inp in inps]
227227
for inp in inps:
228228
inp.grad = None

0 commit comments

Comments
 (0)