@@ -185,9 +185,9 @@ def get_amplitude(x) -> complex:
185
185
"""Get the complex amplitude out of some data."""
186
186
187
187
if isinstance (x , DataArray ):
188
- x = complex ( x .values )
188
+ x = x .values
189
189
190
- return 1j * complex (x )
190
+ return complex (x )
191
191
192
192
193
193
class AbstractFieldData (MonitorData , AbstractFieldDataset , ABC ):
@@ -2115,7 +2115,7 @@ def _make_adjoint_sources_amps(self, fwidth: float) -> list[ModeSource]:
2115
2115
for mode_index in coords ["mode_index" ]:
2116
2116
amp_single = self .amps .sel (f = freq , direction = direction , mode_index = mode_index )
2117
2117
2118
- if self .get_amplitude (amp_single ) == 0.0 :
2118
+ if abs ( self .get_amplitude (amp_single ) ) == 0.0 :
2119
2119
continue
2120
2120
2121
2121
adjoint_source = self ._adjoint_source_amp (amp = amp_single , fwidth = fwidth )
@@ -2138,7 +2138,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource:
2138
2138
amp_complex = self .get_amplitude (amp )
2139
2139
k0 = 2 * np .pi * freq0 / C_0
2140
2140
grad_const = k0 / 4 / ETA_0
2141
- src_amp = grad_const * amp_complex
2141
+ src_amp = 1j * grad_const * amp_complex
2142
2142
2143
2143
# construct source
2144
2144
src_adj = ModeSource (
@@ -3404,7 +3404,7 @@ def _make_adjoint_sources_amps(self, fwidth: float) -> list[PlaneWave]:
3404
3404
3405
3405
# ignore any amplitudes of 0.0 or nan
3406
3406
amp_complex = self .get_amplitude (amp_single )
3407
- if (amp_complex == 0.0 ) or np .isnan (amp_complex ):
3407
+ if (abs ( amp_complex ) == 0.0 ) or np .isnan (amp_complex ):
3408
3408
continue
3409
3409
3410
3410
# compute a plane wave for this amplitude (if propagating / not None)
0 commit comments