Skip to content

Commit 5e7ef85

Browse files
Gregory Robertsyaugenst-flex
authored andcommitted
feature[frontend]: data array method for extrapolating to endpoints outside of coordinate array that is autograd compatible
1 parent 44a7e25 commit 5e7ef85

File tree

5 files changed

+74
-8
lines changed

5 files changed

+74
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3737
- Fixed `ElectromagneticFieldData.to_zbf()` to support single frequency monitors and apply the correct flattening order.
3838
- Bug in `TerminalComponentModeler.get_antenna_metrics_data` when port amplitudes are set to zero.
3939
- Added missing `solver_version` keyword argument to `run_async`.
40+
- Fixed `interpn` data array method to be compatible with extrapolation outside of data array coordinates.
4041

4142
## [2.9.0] - 2025-08-04
4243

tests/test_data/test_data_arrays.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import autograd as ag
88
import autograd.numpy as np
9+
import numpy
910
import pytest
1011
import xarray.testing as xrt
1112
from autograd.test_util import check_grads
@@ -516,3 +517,66 @@ def test_with_updated_data_shape():
516517

517518
with pytest.raises(ValueError):
518519
arr2 = arr._with_updated_data(data=data, coords=coords)
520+
521+
522+
@pytest.mark.parametrize("method", ["nearest", "linear"])
523+
def test_interpn_with_extrapolation(rng, method):
524+
"""Checks that the extrapolation in `interpn` works as expected and that
525+
it is autograd compatible."""
526+
arr = td.SpatialDataArray(
527+
rng.random((1, 3, 4, 5), dtype=np.float64),
528+
coords={"x": [1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "f": [0, 1, 2, 3, 4]},
529+
)
530+
531+
for coord in arr.dims:
532+
endpoints = [
533+
arr.coords[coord].values[0] - 1.0,
534+
arr.coords[coord].values[-1] + 1.0,
535+
]
536+
537+
method_coord = method if (len(arr.coords[coord]) > 1) else "nearest"
538+
539+
offset_interp_coords = arr.coords[coord].values + 0.5
540+
coords_interp = {coord: [endpoints[0], *offset_interp_coords, endpoints[1]]}
541+
542+
extrapolate = arr._ag_interp(
543+
coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"}
544+
)
545+
546+
compare = arr.interp(
547+
coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"}
548+
)
549+
550+
numpy.testing.assert_allclose(
551+
extrapolate.data, compare.data, err_msg="Expected data to be close!"
552+
)
553+
554+
def f(params):
555+
arr = td.SpatialDataArray(
556+
params.reshape((1, 3, 4, 5)),
557+
coords={"x": [1], "y": [1, 2, 3], "z": [2, 3, 4, 5], "f": [0, 1, 2, 3, 4]},
558+
)
559+
560+
result = 0.0
561+
562+
for coord in arr.dims:
563+
endpoints = [
564+
arr.coords[coord].values[0] - 1.0,
565+
arr.coords[coord].values[-1] + 1.0,
566+
]
567+
568+
method_coord = method if (len(arr.coords[coord]) > 1) else "nearest"
569+
570+
offset_interp_coords = arr.coords[coord].values + 0.5
571+
coords_interp = {coord: [endpoints[0], *offset_interp_coords, endpoints[1]]}
572+
573+
interp_data = arr.interp(
574+
coords_interp, method=method_coord, kwargs={"fill_value": "extrapolate"}
575+
)
576+
577+
result += np.sum(interp_data.data)
578+
579+
return result
580+
581+
data = rng.random((1, 3, 4, 5), dtype=np.float64)
582+
check_grads(f, order=1, modes=["rev"])(data)

tidy3d/components/autograd/functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def interpn(
9898
xi: tuple[NDArray[np.float64], ...],
9999
*,
100100
method: InterpolationType = "linear",
101+
**kwargs,
101102
) -> NDArray[np.float64]:
102103
"""Interpolate over a rectilinear grid in arbitrary dimensions.
103104
@@ -137,7 +138,12 @@ def interpn(
137138
else:
138139
raise ValueError(f"Unsupported interpolation method: {method}")
139140

140-
itrp = RegularGridInterpolator(points, values, method=method)
141+
if kwargs.get("fill_value") == "extrapolate":
142+
itrp = RegularGridInterpolator(
143+
points, values, method=method, fill_value=None, bounds_error=False
144+
)
145+
else:
146+
itrp = RegularGridInterpolator(points, values, method=method)
141147

142148
# Prepare the grid for interpolation
143149
# This step reshapes the grid, checks for NaNs and out-of-bounds values

tidy3d/components/data/data_array.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -467,12 +467,7 @@ def _ag_interp_func(var, indexes_coords, method, **kwargs):
467467
data = anp.transpose(var.data, combined_permutation)
468468
xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1)
469469

470-
result = interpn(
471-
[xn.data for xn in x],
472-
data,
473-
xi,
474-
method=method,
475-
)
470+
result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs)
476471

477472
result = anp.moveaxis(result, 0, -1)
478473
result = anp.reshape(result, result.shape[:-1] + new_x[0].shape)

tidy3d/plugins/smatrix/component_modelers/terminal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Literal, Optional, Union
66

7-
import numpy as np
7+
import autograd.numpy as np
88
import pydantic.v1 as pd
99

1010
from tidy3d.components.base import cached_property

0 commit comments

Comments
 (0)