Skip to content

Commit 7ae7ddb

Browse files
committed
Fix bugs in FFTOp and SamplingOp
1 parent 8811386 commit 7ae7ddb

File tree

4 files changed

+21
-21
lines changed

4 files changed

+21
-21
lines changed

ext/LinearOperatorFFTWExt/FFTOp.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,8 @@ function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::
6464
end
6565
end
6666

67-
function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
68-
tmpVec[:] .= x
69-
plan * tmpVec
67+
function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, ::NTuple{D}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
68+
plan * copyto!(tmpVec, x)
7069
res .= factor .* vec(tmpVec)
7170
end
7271

src/SamplingOp.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ function SamplingOpImpl(T::Type{<:Number}, pattern::AbstractArray{Int}, shape::T
3434
end
3535

3636
function SamplingOpImpl(T::Type{<:Number}, pattern::AbstractArray{Bool}; S = Vector{T})
37-
38-
function prod!(res::Vector{U}, x::Vector{V}) where {U,V}
37+
pattern = copyto!(similar(S(undef,0), Bool, size(pattern)...), pattern)
38+
39+
function prod!(res::AbstractArray{U}, x::AbstractArray{V}) where {U,V}
3940
res .= pattern.*x
4041
end
4142

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using Wavelets
77
using NFFT
88
using JLArrays
99

10-
arrayTypes = [Array]
10+
arrayTypes = [Array, JLArray]
1111

1212
@testset "LinearOperatorCollection" begin
1313
include("testNormalOp.jl")

test/testOperators.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ function testDirectionalGradOp(N=64;arrayType = Array)
150150
G2 = GradientOp(eltype(x); shape=size(x), dims=2, S = typeof(xop))
151151
G_1d = Bidiagonal(ones(N),-ones(N-1), :U)[1:N-1,:]
152152

153-
y1 = Array(G1*x)
154-
y2 = Array(G2*x)
153+
y1 = Array(G1*xop)
154+
y2 = Array(G2*xop)
155155
y1_ref = zeros(ComplexF64, N-1,N)
156156
y2_ref = zeros(ComplexF64, N, N-1)
157157
for i=1:N
@@ -183,7 +183,7 @@ function testSampling(N=64;arrayType = Array)
183183
idx = shuffle(collect(1:N^2)[1:N*div(N,2)])
184184
SOp = SamplingOp(ComplexF64, pattern=idx, shape=(N,N), S = typeof(xop))
185185
y = Array(SOp*xop)
186-
x2 = Array(adjoint(SOp)*y)
186+
x2 = Array(adjoint(SOp)*arrayType(y))
187187
# mask-based sampling
188188
msk = zeros(Bool,N*N);msk[idx].=true
189189
SOp2 = SamplingOp(ComplexF64, pattern=msk, S = typeof(xop))
@@ -258,7 +258,7 @@ function testNFFT2d(N=16;arrayType = Array)
258258
# test type stability;
259259
# TODO: Ensure type stability for Trajectory objects and test here
260260
nodes = Float32.(nodes)
261-
F_nfft = NFFTOp(ComplexF32; shape=(N,N), nodes, symmetrize=false, S = typeof(xop))
261+
F_nfft = NFFTOp(ComplexF32; shape=(N,N), nodes, symmetrize=false, S = typeof(ComplexF32.(xop)))
262262

263263
y_nfft = F_nfft * ComplexF32.(xop)
264264
y_adj_nfft = adjoint(F_nfft) * ComplexF32.(xop)
@@ -316,31 +316,31 @@ end
316316
# TODO RadonOp
317317

318318
@testset "Linear Operators" begin
319-
for arrayType in arrayTypes
320-
@info "test DCT-II and DCT-IV Ops"
319+
@testset for arrayType in arrayTypes
320+
@info "test DCT-II and DCT-IV Ops: $arrayType"
321321
for N in [2,8,16,32]
322-
@test testDCT1d(N;arrayType) skip = !isa(arrayType, Array)
322+
@test testDCT1d(N;arrayType) skip = arrayType != Array # Not implemented for GPUs
323323
end
324-
@info "test FFTOp"
324+
@info "test FFTOp: $arrayType"
325325
for N in [8,16,32]
326326
@test testFFT1d(N,false;arrayType)
327327
@test testFFT1d(N,true;arrayType)
328328
@test testFFT2d(N,false;arrayType)
329329
@test testFFT2d(N,true;arrayType)
330330
end
331-
@info "test WeightingOp"
331+
@info "test WeightingOp: $arrayType"
332332
@test testWeighting(512;arrayType)
333-
@info "test GradientOp"
333+
@info "test GradientOp: $arrayType"
334334
@test testGradOp1d(512;arrayType)
335335
@test testGradOp2d(64;arrayType)
336336
@test testDirectionalGradOp(64;arrayType)
337-
@info "test SamplingOp"
337+
@info "test SamplingOp: $arrayType"
338338
@test testSampling(64;arrayType)
339-
@info "test WaveletOp"
339+
@info "test WaveletOp: $arrayType"
340340
@test testWavelet(64,64;arrayType)
341341
@test testWavelet(64,60;arrayType)
342-
@info "test NFFTOp"
343-
@test testNFFT2d(;arrayType) skip = arrayType == JLArray
344-
@test testNFFT3d(;arrayType) skip = arrayType == JLArray
342+
@info "test NFFTOp: $arrayType"
343+
@test testNFFT2d(;arrayType) skip = arrayType == JLArray # JLArray does not have a NFFTPlan
344+
@test testNFFT3d(;arrayType) skip = arrayType == JLArray # JLArray does not have a NFFTPlan
345345
end
346346
end

0 commit comments

Comments
 (0)