Skip to content

Commit 164b4b8

Browse files
yaugenst-flexmomchil-flex
authored andcommitted
feat: Autograd support for FieldProjectionKSpaceMonitor
1 parent b913b70 commit 164b4b8

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- Autograd support for local field projections using `FieldProjectionKSpaceMonitor`.
12+
1013
### Fixed
1114
- Regression in local field projection leading to incorrect results for `far_field_approx=True`.
1215

tests/test_components/test_autograd.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,7 @@ def objective(args):
10031003

10041004

10051005
@pytest.mark.parametrize("far_field_approx", [True, False])
1006-
@pytest.mark.parametrize("projection_type", ["angular", "cartesian"])
1006+
@pytest.mark.parametrize("projection_type", ["angular", "cartesian", "kspace"])
10071007
@pytest.mark.parametrize("sim_2d", [True, False])
10081008
class TestFieldProjection:
10091009
@staticmethod
@@ -1047,6 +1047,20 @@ def setup(far_field_approx, projection_type, sim_2d):
10471047
far_field_approx=far_field_approx,
10481048
name="far_field",
10491049
)
1050+
elif projection_type == "kspace":
1051+
ux = np.linspace(-0.7, 0.7, 2)
1052+
uy = np.linspace(-0.7, 0.7, 3)
1053+
monitor_far = td.FieldProjectionKSpaceMonitor(
1054+
center=monitor.center,
1055+
size=monitor.size,
1056+
freqs=monitor.freqs,
1057+
ux=ux,
1058+
uy=uy,
1059+
proj_axis=1,
1060+
proj_distance=r_proj,
1061+
far_field_approx=far_field_approx,
1062+
name="far_field",
1063+
)
10501064

10511065
sim = SIM_BASE.updated_copy(monitors=[monitor])
10521066

tidy3d/components/field_projection.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def _project_fields_kspace(
759759

760760
# compute projected fields for the dataset associated with each monitor
761761
field_names = ("Er", "Etheta", "Ephi", "Hr", "Htheta", "Hphi")
762-
fields = [np.zeros((len(ux), len(uy), 1, len(freqs)), dtype=complex) for _ in field_names]
762+
fields = np.zeros((len(field_names), len(ux), len(uy), 1, len(freqs)), dtype=complex)
763763

764764
medium = monitor.medium if monitor.medium else self.medium
765765
k = AbstractFieldProjectionData.wavenumber(medium=medium, frequency=freqs)
@@ -793,16 +793,17 @@ def _project_fields_kspace(
793793
currents=currents,
794794
medium=medium,
795795
)
796-
for field, _field in zip(fields, _fields):
797-
field = add_at(field, [i, j, 0, idx_f], _field * phase[idx_f])
798-
796+
where = (slice(None), i, j, 0, idx_f)
797+
_fields = anp.reshape(_fields, fields[where].shape)
798+
fields = add_at(fields, where, _fields * phase[idx_f])
799799
else:
800800
_x, _y, _z = monitor.sph_2_car(monitor.proj_distance, theta, phi)
801801
_fields = self._fields_for_surface_exact(
802802
x=_x, y=_y, z=_z, surface=surface, currents=currents, medium=medium
803803
)
804-
for field, _field in zip(fields, _fields):
805-
field = add_at(field, [i, j, 0], _field)
804+
where = (slice(None), i, j, 0)
805+
_fields = anp.reshape(_fields, fields[where].shape)
806+
fields = add_at(fields, where, _fields)
806807

807808
coords = {
808809
"ux": np.array(monitor.ux),

0 commit comments

Comments
 (0)