Skip to content

Commit 2c9ae23

Browse files
committed
fix: xarray 2023.08.0 compatibility
1 parent 143efcf commit 2c9ae23

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

tests/test_components/test_autograd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,7 +1630,7 @@ def test_custom_methods_grads(self, attr):
16301630
"""Test grads of TidyArrayBox methods implemented in autograd/boxes.py"""
16311631

16321632
def objective(x, attr):
1633-
da = DataArray(x, dims=map(str, range(x.ndim)))
1633+
da = DataArray(x)
16341634
attr_value = getattr(da, attr)
16351635
val = attr_value() if callable(attr_value) else attr_value
16361636
return val.item()
@@ -1643,11 +1643,11 @@ def test_multiply_at_grads(self, rng):
16431643

16441644
def objective(a, b):
16451645
coords = {str(i): np.arange(a.shape[i]) for i in range(a.ndim)}
1646-
da = DataArray(a, coords=coords, dims=map(str, range(a.ndim)))
1646+
da = DataArray(a, coords=coords)
16471647
da_mult = da.multiply_at(b, "0", [0, 1]) ** 2
16481648
return np.sum(da_mult).item()
16491649

16501650
a = rng.uniform(-1, 1, (3, 3))
16511651
b = 1.0
1652-
check_grads(lambda x: objective(x, b), modes=["fwd", "rev"], order=1)(a)
1653-
check_grads(lambda x: objective(a, x), modes=["fwd", "rev"], order=1)(b)
1652+
check_grads(lambda x: objective(x, b), modes=["fwd", "rev"], order=2)(a)
1653+
check_grads(lambda x: objective(a, x), modes=["fwd", "rev"], order=2)(b)

tidy3d/components/autograd/boxes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,21 @@
55
from typing import Any, Callable, Dict, List, Tuple
66

77
import autograd.numpy as anp
8+
from autograd.extend import defjvp
89
from autograd.numpy.numpy_boxes import ArrayBox
10+
from autograd.numpy.numpy_wrapper import _astype
911

1012
TidyArrayBox = ArrayBox # NOT a subclass
1113

1214
_autograd_module_cache = {} # cache for imported autograd modules
1315

16+
defjvp(
17+
_astype,
18+
lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype),
19+
)
20+
21+
anp.astype = _astype
22+
1423

1524
@classmethod
1625
def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox:

0 commit comments

Comments
 (0)