Skip to content

Commit 5f32e0b

Browse files
authored
Merge pull request #557 from mrava87/patch-shifttype
fix: force dtype for shift operator inputs
2 parents 294f951 + f130a96 commit 5f32e0b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pylops/signalprocessing/shift.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def Shift(
109109
shift = _value_or_sized_to_array(shift)
110110

111111
if shift.size == 1:
112-
shift = np.exp(-1j * 2 * np.pi * Fop.f * shift)
112+
shift = np.exp(-1j * 2 * np.pi * Fop.f * shift).astype(Fop.cdtype)
113113
Sop = Diagonal(shift, dims=dimsdiag, axis=axis, dtype=Fop.cdtype)
114114
else:
115115
# add dimensions to shift to match dimensions of model and data
@@ -120,7 +120,7 @@ def Shift(
120120
sdims = np.ones(shift.ndim + 1, dtype=int)
121121
sdims[:axis] = shift.shape[:axis]
122122
sdims[axis + 1 :] = shift.shape[axis:]
123-
shift = np.exp(-1j * 2 * np.pi * f * shift.reshape(sdims))
123+
shift = np.exp(-1j * 2 * np.pi * f * shift.reshape(sdims)).astype(Fop.cdtype)
124124
Sop = Diagonal(shift, dtype=Fop.cdtype)
125125
Op = Fop.H * Sop * Fop
126126
Op.dims = Op.dimsd = Fop.dims

pylops/waveeqprocessing/blending.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
# Define shift operator
112112
self.shifts = (times // self.dt).astype(np.int32)
113113
diff = (times / self.dt - self.shifts) * self.dt
114-
diff = np.repeat(diff[:, np.newaxis], self.nr, axis=1)
114+
diff = np.repeat(diff[:, np.newaxis], self.nr, axis=1).astype(self.dtype)
115115
self.ShiftOp = Shift(
116116
(self.ns, self.nr, self.nt + 1),
117117
diff,

0 commit comments

Comments
 (0)