 from functorch._src.eager_transforms import _as_tuple, jvp
 aten = torch.ops.aten

-# Version of autograd.grad that handles outputs that don't depend on inputs

-
-def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
+# Version of autograd.grad with some differences:
+#   - pytree inputs is allowed (but leaves of the pytree have to all
+#     be tensors)
+#   - if an input is not used as part of derivatives, we will return a
+#     zero-filled tensor for the result
+def _autograd_grad(
+    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
+):
     inputs, inputs_spec = tree_flatten(inputs)
-    result = [torch.zeros_like(inp) for inp in inputs]
-    diff_argnums = tuple(i for i, inp in enumerate(inputs) if inp.requires_grad)
-    inputs = tuple(inputs[i] for i in diff_argnums)
+    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
     if grad_outputs is None:
         diff_outputs = tuple(out for out in outputs if out.requires_grad)
     else:
-        something = [(out, go) for out, go in zip(outputs, grad_outputs)
-                     if out.requires_grad]
-        if len(something) == 0:
+        diff_grad_outputs = [
+            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
+        ]
+        if len(diff_grad_outputs) == 0:
             diff_outputs, grad_outputs = (), ()
         else:
-            diff_outputs, grad_outputs = zip(*something)
-    if len(diff_outputs) == 0:
-        return tuple(torch.zeros_like(inp) for inp in inputs)
-    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
-                                      retain_graph=retain_graph,
-                                      create_graph=create_graph,
-                                      allow_unused=True)
-    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
-                        for gi, inp in zip(grad_inputs, inputs))
-    for idx, grad_inp in zip(diff_argnums, grad_inputs):
-        result[idx] = grad_inp
+            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
+    grad_inputs = torch.autograd.grad(
+        diff_outputs,
+        diff_inputs,
+        grad_outputs,
+        retain_graph=retain_graph,
+        create_graph=create_graph,
+        allow_unused=True,
+    )
+    result = []
+    grad_inputs_iter = iter(grad_inputs)
+    for inp in inputs:
+        if inp.requires_grad:
+            grad_input = next(grad_inputs_iter)
+            if grad_input is None:
+                result.append(torch.zeros_like(inp))
+            else:
+                result.append(grad_input)
+        else:
+            result.append(torch.zeros_like(inp))
     return tree_unflatten(result, inputs_spec)
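For reference, a minimal usage sketch of the two behaviors the new comment calls out, assuming the `_autograd_grad` from this diff is in scope (it is a private helper, so the import path may vary) and using made-up tensors `x`, `y`, and `out` for illustration: a dict is accepted as a pytree of inputs, and an input the outputs do not depend on comes back as a zero-filled tensor rather than `None`.

    import torch

    # Hypothetical usage; assumes _autograd_grad from the diff above is in scope.
    x = torch.randn(3, requires_grad=True)
    y = torch.randn(3, requires_grad=True)  # never used in the output below

    out = (x * 2).sum()

    # A dict is a valid pytree of inputs (every leaf is a tensor).
    grads = _autograd_grad((out,), {"x": x, "y": y})

    print(grads["x"])  # d(out)/dx, i.e. a tensor of 2s
    print(grads["y"])  # out does not depend on y: zero-filled tensor, not None

Because the inputs are flattened with `tree_flatten` and the result is rebuilt with `tree_unflatten`, `grads` has the same dict structure as the inputs.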