
Commit b01733e

yaugenst-flex and txdai authored
fix(pytorch): Correct gradient for array-valued functions in wrapper (#2608)
* fix(pytorch): Correct gradient for array-valued functions in wrapper

  The `to_torch` wrapper, which connects `autograd` functions to PyTorch's autograd system, failed to compute correct gradients for functions that returned multi-element arrays. The root cause was in the `_Wrapper.backward` method:

  1. The vector-Jacobian product function (`vjp`) was called with an array of ones instead of the true upstream gradient (`grad_output`).
  2. The result was then incorrectly multiplied by `grad_output` again.

  This worked by coincidence for scalar outputs, where the upstream gradient is often `1.0`, but produced incorrect gradients for array outputs.

  This commit corrects the implementation by passing the NumPy-converted `grad_output` directly to the `vjp` function and removing the subsequent redundant multiplication. The wrapper now correctly supports differentiation through functions that return tensors of any shape.

* test(pytorch): Add test for array-valued function gradients

---------

Co-authored-by: Tianxiang Dai <[email protected]>
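To see the failure mode concretely, here is a minimal sketch (not part of the commit; the function and shapes are illustrative) using only autograd's public `make_vjp` API. For an array-valued `f`, `vjp(g)` contracts the upstream gradient `g` with the Jacobian, so seeding it with ones sums the wrong entries; the old follow-up multiply by `grad_output` then also broadcasts to the output shape rather than the input shape:

```python
# Minimal sketch (not from the repo) of the bug described above,
# using only autograd's public make_vjp API; names are illustrative.
import autograd.numpy as anp
from autograd import make_vjp

def f(x):
    return anp.stack([x, 2.0 * x])  # array-valued: (n,) -> (2, n)

x = anp.array([1.0, 2.0, 3.0])
vjp, ans = make_vjp(f)(x)

g = anp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # upstream gradient, shape (2, 3)

print(vjp(g))                   # correct: g contracted with J -> [1. 0. 2.]
print(vjp(anp.ones_like(ans)))  # what the old code fed in     -> [3. 3. 3.]
```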
1 parent 3a467ab · commit b01733e

File tree

- CHANGELOG.md
- tests/test_plugins/pytorch/test_wrapper.py
- tidy3d/plugins/pytorch/wrapper.py

3 files changed: +49 −9 lines

CHANGELOG.md

Lines changed: 7 additions & 6 deletions
```diff
@@ -11,16 +11,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
 - Add Nunley variant to germanium material library based on Nunley et al. 2016 data.
 
-### Fixed
-- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.
-- Bug in `PlaneWave` defined with a negative `angle_theta` which would lead to wrong injection.
-- Plots of objects defined by shape intersection logic will no longer display thin line artifacts.
-
 ### Changed
 - Switched to an analytical gradient calculation for spatially-varying pole-residue models (`CustomPoleResidue`).
-- Significantly improved performance of the `tidy3d.plugins.autograd.grey_dilation` morphological operation and its gradient calculation. The new implementation is orders of magnitude faster, especially for large arrays and kernel sizes.
 - `GaussianBeam` and `AstigmaticGaussianBeam` default `num_freqs` reset to 1 (it was set to 3 in v2.8.0) and a warning is issued for a broadband, angled beam for which `num_freqs` may not be sufficiently large.
 - Set the maximum `num_freqs` to 20 for all broadband sources (we have been warning about the introduction of this hard limit for a while).
+- Significantly improved performance of the `tidy3d.plugins.autograd.grey_dilation` morphological operation and its gradient calculation. The new implementation is orders of magnitude faster, especially for large arrays and kernel sizes.
+
+### Fixed
+- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.
+- Bug in `PlaneWave` defined with a negative `angle_theta` which would lead to wrong injection.
+- Plots of objects defined by shape intersection logic will no longer display thin line artifacts.
+- Fixed incorrect gradient computation in PyTorch plugin (`to_torch`) for functions returning multi-element arrays.
 
 ## [2.9.0rc1] - 2025-06-10
 
```

tests/test_plugins/pytorch/test_wrapper.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -44,3 +44,42 @@ def f_np(x, y):
     expected_grad = elementwise_grad(f_np, argnum=[0, 1])(x_np, y_np)
 
     assert_allclose(grad, expected_grad)
+
+
+def test_to_torch_array_valued_function(rng):
+    """Test that gradients are computed correctly for functions returning arrays with different shapes than input."""
+    x_np = rng.uniform(-1, 1, (2, 2)).astype("f4")
+    x_torch = torch.tensor(x_np, requires_grad=True)
+
+    # define a function that returns a different shape than input
+    # this function maps (2,2) -> (2,3)
+    def f_np(x):
+        return anp.stack([x.sum(axis=1), x.mean(axis=1) * 2, x[:, 0] * x[:, 1]], axis=1)
+
+    f_torch = to_torch(f_np)
+
+    output = f_torch(x_torch)
+    assert output.shape == (2, 3)
+
+    # create upstream gradient (simulating backprop from a loss)
+    grad_output = torch.ones_like(output)  # shape (2, 3)
+
+    output.backward(grad_output)
+
+    h = 1e-5
+    expected_grad = anp.zeros_like(x_np)
+
+    for i in range(x_np.shape[0]):
+        for j in range(x_np.shape[1]):
+            x_plus = x_np.copy()
+            x_plus[i, j] += h
+            x_minus = x_np.copy()
+            x_minus[i, j] -= h
+
+            f_plus = f_np(x_plus)
+            f_minus = f_np(x_minus)
+
+            expected_grad[i, j] = anp.sum((f_plus - f_minus) / (2 * h))
+
+    computed_grad = x_torch.grad.numpy()
+    assert_allclose(computed_grad, expected_grad, rtol=1e-3, atol=1e-3)
```
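For context, here is a hedged usage sketch (not from the test file) of the fixed wrapper in ordinary PyTorch code, where a downstream loss makes `grad_output` non-trivial. The import path mirrors this commit's file layout (`tidy3d/plugins/pytorch/wrapper.py`); the function itself is illustrative:

```python
# Hedged usage sketch: composing a to_torch-wrapped autograd function with a
# real PyTorch loss, so backward sends a grad_output that is not all ones.
import autograd.numpy as anp
import torch
from tidy3d.plugins.pytorch.wrapper import to_torch  # adjust if re-exported elsewhere

@to_torch
def features(x):
    return anp.stack([anp.sin(x), x**2])  # array-valued: (n,) -> (2, n)

x = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
loss = features(x).square().sum()  # upstream gradient is 2 * features(x)
loss.backward()
print(x.grad)  # correct only with the fixed vjp(grad_output) call
```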

tidy3d/plugins/pytorch/wrapper.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -4,7 +4,6 @@
 
 import torch
 from autograd import make_vjp
-from autograd.extend import vspace
 
 
 def to_torch(fun):
@@ -79,10 +78,11 @@ def forward(ctx, *args):
 
         @staticmethod
        def backward(ctx, grad_output):
-            _grads = ctx.vjp(vspace(grad_output.detach().cpu().numpy()).ones())
+            numpy_grad_output = grad_output.detach().cpu().numpy()
+            _grads = ctx.vjp(numpy_grad_output)
             grads = [None] * ctx.num_args
             for idx, grad in zip(ctx.grad_argnums, _grads):
-                grads[idx] = torch.as_tensor(grad, device=ctx.device) * grad_output
+                grads[idx] = torch.as_tensor(grad, device=ctx.device)
             return tuple(grads)
 
     def apply(*args, **kwargs):
```
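The contract the fix relies on: `make_vjp(fun)(x)` returns `(vjp, ans)`, where `vjp(g)` evaluates the vector-Jacobian product of the upstream gradient `g` with the Jacobian of `fun` at `x` in a single reverse pass. So `backward` only needs to hand the true upstream gradient to `vjp` and convert the result back to a tensor. Below is a self-contained sketch of that pattern as a standalone `torch.autograd.Function` (illustrative, not the plugin's actual `_Wrapper` class):

```python
# Self-contained sketch of the corrected backward pattern: pass the true
# upstream gradient to autograd's vjp, with no extra multiplication after.
import autograd.numpy as anp
import torch
from autograd import make_vjp

class _Sketch(torch.autograd.Function):
    """Standalone illustration; the real plugin wraps arbitrary functions."""

    @staticmethod
    def forward(ctx, x):
        fun = lambda a: anp.stack([a, a**2])  # array-valued: (n,) -> (2, n)
        vjp, ans = make_vjp(fun)(x.detach().cpu().numpy())
        ctx.vjp, ctx.device = vjp, x.device
        return torch.as_tensor(ans, device=x.device)

    @staticmethod
    def backward(ctx, grad_output):
        # The fix: feed the actual upstream gradient into the VJP.
        grad = ctx.vjp(grad_output.detach().cpu().numpy())
        return torch.as_tensor(grad, device=ctx.device)

x = torch.tensor([1.0, 2.0], requires_grad=True)
_Sketch.apply(x).sum().backward()
print(x.grad)  # tensor([3., 5.]) == 1 + 2*x
```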