Commit 90625ca

Minor improvements for _autograd_grad (#750)
I was really annoyed at the fact that we preallocate result tensors for everything and then throw most of them out. New code variant doesn't do that.

Signed-off-by: Edward Z. Yang <[email protected]>
1 parent 105976d

File tree: 1 file changed, +33 -20 lines changed


test/test_ops.py

Lines changed: 33 additions & 20 deletions
@@ -35,33 +35,46 @@
 from functorch._src.eager_transforms import _as_tuple, jvp
 aten = torch.ops.aten
 
-# Version of autograd.grad that handles outputs that don't depend on inputs
 
-
-def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
+# Version of autograd.grad with some differences:
+#   - pytree inputs is allowed (but leaves of the pytree have to all
+#     be tensors)
+#   - if an input is not used as part of derivatives, we will return a
+#     zero-filled tensor for the result
+def _autograd_grad(
+    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
+):
     inputs, inputs_spec = tree_flatten(inputs)
-    result = [torch.zeros_like(inp) for inp in inputs]
-    diff_argnums = tuple(i for i, inp in enumerate(inputs) if inp.requires_grad)
-    inputs = tuple(inputs[i] for i in diff_argnums)
+    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
     if grad_outputs is None:
         diff_outputs = tuple(out for out in outputs if out.requires_grad)
     else:
-        something = [(out, go) for out, go in zip(outputs, grad_outputs)
-                     if out.requires_grad]
-        if len(something) == 0:
+        diff_grad_outputs = [
+            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
+        ]
+        if len(diff_grad_outputs) == 0:
             diff_outputs, grad_outputs = (), ()
         else:
-            diff_outputs, grad_outputs = zip(*something)
-    if len(diff_outputs) == 0:
-        return tuple(torch.zeros_like(inp) for inp in inputs)
-    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
-                                      retain_graph=retain_graph,
-                                      create_graph=create_graph,
-                                      allow_unused=True)
-    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
-                        for gi, inp in zip(grad_inputs, inputs))
-    for idx, grad_inp in zip(diff_argnums, grad_inputs):
-        result[idx] = grad_inp
+            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
+    grad_inputs = torch.autograd.grad(
+        diff_outputs,
+        diff_inputs,
+        grad_outputs,
+        retain_graph=retain_graph,
+        create_graph=create_graph,
+        allow_unused=True,
+    )
+    result = []
+    grad_inputs_iter = iter(grad_inputs)
+    for inp in inputs:
+        if inp.requires_grad:
+            grad_input = next(grad_inputs_iter)
+            if grad_input is None:
+                result.append(torch.zeros_like(inp))
+            else:
+                result.append(grad_input)
+        else:
+            result.append(torch.zeros_like(inp))
     return tree_unflatten(result, inputs_spec)
 
 
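For reference, here is a minimal, self-contained sketch of the behavior the new comment block documents: the pytree of inputs is flattened, gradients are computed only for the differentiable leaves, and zero-filled tensors are materialized lazily for inputs that receive no gradient. The helper name autograd_grad_sketch and the example tensors are illustrative only (they are not part of this commit), and grad_outputs handling is omitted for brevity.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def autograd_grad_sketch(outputs, inputs, retain_graph=False, create_graph=True):
    # Flatten the pytree of inputs; leaves are assumed to be tensors.
    flat_inputs, spec = tree_flatten(inputs)
    diff_inputs = tuple(t for t in flat_inputs if t.requires_grad)
    diff_outputs = tuple(o for o in outputs if o.requires_grad)
    grads = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    grads_iter = iter(grads)
    result = []
    for t in flat_inputs:
        g = next(grads_iter) if t.requires_grad else None
        # Zeros are created only for inputs that received no gradient,
        # instead of preallocating a zero tensor for every input up front.
        result.append(torch.zeros_like(t) if g is None else g)
    return tree_unflatten(result, spec)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)  # y does not feed into the output
out = (x * 2).sum()
gx, gy = autograd_grad_sketch((out,), (x, y))
# gx equals 2 * torch.ones(3); gy is zero-filled because y is unused.

Compared with the old code, which built result = [torch.zeros_like(inp) for inp in inputs] up front and then overwrote most entries, the single pass above only allocates zeros where they are actually needed, which is the point of the commit message.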