Skip to content

Commit b6604ad

Browse files
committed
feat(autograd): add primitive for np.unwrap
1 parent 82bf4ba commit b6604ad

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Add support for `np.unwrap` in `tidy3d.plugins.autograd`.
12+
1013
### Fixed
1114
- Arrow lengths are now scaled consistently in the X and Y directions,
1215
and their lengths no longer exceed the height of the plot window.

tests/test_plugins/autograd/primitives/test_misc.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import autograd.numpy as np
34
import pytest
45
from autograd.test_util import check_grads
56

@@ -22,3 +23,19 @@
2223
def test_gaussian_filter_grad(rng, size, ndim, sigma, mode):
2324
x = rng.random((size,) * ndim)
2425
check_grads(lambda x: gaussian_filter(x, sigma=sigma, mode=mode), modes=["rev"], order=2)(x)
26+
27+
28+
@pytest.mark.parametrize("shape, axis", [((100,), -1), ((10, 12), 0), ((10, 12), 1)])
29+
@pytest.mark.parametrize("period", [np.pi, 2 * np.pi])
30+
@pytest.mark.parametrize("discont", [None, 0.6])
31+
def test_unwrap_grad(rng, shape, axis, period, discont):
32+
"""Test the gradient of the unwrap function with various arguments."""
33+
if discont is not None:
34+
# discont must be > period / 2 to have an effect
35+
discont = discont * period
36+
37+
x = rng.uniform(-4 * period, 4 * period, shape)
38+
39+
check_grads(
40+
lambda x: np.unwrap(x, discont=discont, axis=axis, period=period), modes=["fwd", "rev"]
41+
)(x)
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from __future__ import annotations
22

3+
import autograd.numpy as np
34
import scipy.ndimage
4-
from autograd.extend import defvjp, primitive
5+
from autograd.extend import defjvp, defvjp, primitive
56

67
gaussian_filter = primitive(scipy.ndimage.gaussian_filter)
78
defvjp(
89
gaussian_filter,
910
lambda ans, x, *args, **kwargs: lambda g: gaussian_filter(g, *args, **kwargs),
1011
)
12+
13+
np.unwrap = primitive(np.unwrap)
14+
defjvp(np.unwrap, lambda g, ans, x, *args, **kwargs: g)
15+
defvjp(np.unwrap, lambda ans, x, *args, **kwargs: lambda g: g)

0 commit comments

Comments
 (0)