Skip to content

Commit cc52ae8

Browse files
authored
Merge pull request #549 from mrava87/patch-slidingdtype
fix: ensure sliding ops work with fp32
2 parents 99f5aa8 + 01568da commit cc52ae8

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

pylops/signalprocessing/sliding1d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def Sliding1D(
157157

158158
# create tapers
159159
if tapertype is not None:
160-
tap = taper(nwin, nover, tapertype=tapertype)
160+
tap = taper(nwin, nover, tapertype=tapertype).astype(Op.dtype)
161161
tapin = tap.copy()
162162
tapin[:nover] = 1
163163
tapend = tap.copy()
@@ -172,7 +172,9 @@ def Sliding1D(
172172
if tapertype is None:
173173
OOp = BlockDiag([Op for _ in range(nwins)])
174174
else:
175-
OOp = BlockDiag([Diagonal(taps[itap].ravel()) * Op for itap in range(nwins)])
175+
OOp = BlockDiag(
176+
[Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
177+
)
176178

177179
combining = HStack(
178180
[

pylops/signalprocessing/sliding2d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def Sliding2D(
191191

192192
# create tapers
193193
if tapertype is not None:
194-
tap = taper2d(dimsd[1], nwin, nover, tapertype=tapertype)
194+
tap = taper2d(dimsd[1], nwin, nover, tapertype=tapertype).astype(Op.dtype)
195195
tapin = tap.copy()
196196
tapin[:nover] = 1
197197
tapend = tap.copy()
@@ -206,7 +206,9 @@ def Sliding2D(
206206
if tapertype is None:
207207
OOp = BlockDiag([Op for _ in range(nwins)])
208208
else:
209-
OOp = BlockDiag([Diagonal(taps[itap].ravel()) * Op for itap in range(nwins)])
209+
OOp = BlockDiag(
210+
[Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
211+
)
210212

211213
combining = HStack(
212214
[

pylops/signalprocessing/sliding3d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,16 @@ def Sliding3D(
183183

184184
# create tapers
185185
if tapertype is not None:
186-
tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype)
186+
tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype).astype(Op.dtype)
187187

188188
# transform to apply
189189
if tapertype is None:
190190
OOp = BlockDiag([Op for _ in range(nwins)], nproc=nproc)
191191
else:
192-
OOp = BlockDiag([Diagonal(tap.ravel()) * Op for _ in range(nwins)], nproc=nproc)
192+
OOp = BlockDiag(
193+
[Diagonal(tap.ravel(), dtype=Op.dtype) * Op for _ in range(nwins)],
194+
nproc=nproc,
195+
)
193196

194197
hstack = HStack(
195198
[

0 commit comments

Comments
 (0)