Skip to content

Commit 8b9451a

Browse files
committed
Fix NormalOp weighting behaviour
1 parent 08c3ff7 commit 8b9451a

File tree

3 files changed

+49
-18
lines changed

3 files changed

+49
-18
lines changed

src/NormalOp.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
11
export normalOperator
22

3+
"""
4+
NormalOp(T::Type; parent, weights)
5+
6+
Lazy normal operator of `parent` with an optional weighting operator `weights.`
7+
Computes `adjoint(parent) * weights * parent`.
8+
9+
# Required Argument
10+
* `T` - type of elements, .e.g. `Float64` for `ComplexF32`
11+
12+
# Required Keyword argument
13+
* `parent` - Base operator
14+
15+
# Optional Keyword argument
16+
* `weights` - Optional weights for normal operator. Must already be of form `weights = adjoint.(w) .* w`
17+
18+
"""
319
function LinearOperatorCollection.NormalOp(::Type{T}; parent, weights = opEye(eltype(parent), size(parent, 1), S = storage_type(parent))) where T <: Number
420
return NormalOp(T, parent, weights)
521
end
@@ -47,7 +63,6 @@ function NormalOpImpl(parent, weights, tmp)
4763
function produ!(y, parent, weights, tmp, x)
4864
mul!(tmp, parent, x)
4965
mul!(tmp, weights, tmp) # This can be dangerous. We might need to create two tmp vectors
50-
mul!(tmp, weights, tmp)
5166
return mul!(y, adjoint(parent), tmp)
5267
end
5368

@@ -63,6 +78,11 @@ function Base.copy(S::NormalOpImpl)
6378
return NormalOpImpl(copy(S.parent), S.weights, copy(S.tmp))
6479
end
6580

81+
"""
82+
normalOperator(parent (, weights); kwargs...)
83+
84+
Constructs a normal operator of the parent in an opinionated way, i.e. it tries to apply optimisations to the resulting operator.
85+
"""
6686
function normalOperator(parent, weights=opEye(eltype(parent), size(parent, 1), S= storage_type(parent)); kwargs...)
6787
return NormalOp(eltype(storage_type((parent))); parent = parent, weights = weights)
6888
end

src/ProdOp.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ end
121121
# In this case we are converting the left argument into a
122122
# weighting matrix, that is passed to normalOperator
123123
# TODO Port vom MRIOperators drops given weighting matrix, I just left it out for now
124-
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, S.A; kwargs...)
124+
"""
125+
normalOperator(prod::ProdOp{T, <:WeightingOp, matT}; kwargs...)
126+
127+
Fuses weights of `ẀeightingOp` by computing `adjoint.(weights) .* weights`
128+
"""
129+
normalOperator(S::ProdOp{T, <:WeightingOp, matT}; kwargs...) where {T, matT} = normalOperator(S.B, WeightingOp(adjoint.(S.A.weights) .* S.A.weights); kwargs...)
125130
function normalOperator(S::ProdOp, W=opEye(eltype(S),size(S,1), S = storage_type(S)); kwargs...)
126131
arrayType = storage_type(S)
127132
tmp = arrayType(undef, size(S.A, 2))

test/testNormalOp.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11

22
@testset "Normal Operator" begin
33
for arrayType in arrayTypes
4-
@testset "$arrayType" begin
5-
N = 512
4+
for elType in [Float32, ComplexF32]
5+
@testset "$arrayType" begin
6+
N = 512
67

7-
Random.seed!(1234)
8-
x = arrayType(rand(N))
9-
A = arrayType(rand(N,N))
10-
A_adj = arrayType(collect(adjoint(A))) # LinearOperators can't resolve storage_type otherwise
11-
W = WeightingOp(arrayType(rand(N)))
12-
WA = W*A
8+
Random.seed!(1234)
9+
x = arrayType(rand(elType, N))
10+
A = arrayType(rand(elType, N,N))
11+
A_adj = arrayType(collect(adjoint(A))) # LinearOperators can't resolve storage_type otherwise
12+
W = WeightingOp(arrayType(rand(elType, N)))
13+
WA = W*A
14+
WHW = adjoint.(W.weights) .* W.weights
15+
prod = ProdOp(W, A)
1316

14-
y1 = Array(A_adj*W*W*A*x)
15-
y2 = Array(adjoint(WA) * WA * x)
16-
y = Array(normalOperator(A,W)*x)
17+
y1 = Array(A_adj*adjoint(W)*W*A*x)
18+
y2 = Array(adjoint(WA) * WA * x)
19+
y3 = Array(normalOperator(prod) * x)
20+
y4 = Array(normalOperator(A, WHW)*x)
1721

18-
@test norm(y1 - y) / norm(y) 0 atol=0.01
19-
@test norm(y2 - y) / norm(y) 0 atol=0.01
22+
@test norm(y1 - y4) / norm(y4) 0 atol=0.01
23+
@test norm(y2 - y4) / norm(y4) 0 atol=0.01
24+
@test norm(y3 - y4) / norm(y4) 0 atol=0.01
2025

2126

22-
y1 = Array(adjoint(A)*A*x)
23-
y = Array(normalOperator(A)*x)
27+
y1 = Array(adjoint(A)*A*x)
28+
y = Array(normalOperator(A)*x)
2429

25-
@test norm(y1 - y) / norm(y) 0 atol=0.01
30+
@test norm(y1 - y) / norm(y) 0 atol=0.01
31+
end
2632
end
2733
end
2834
end

0 commit comments

Comments
 (0)