Skip to content

Commit dddb844

Browse files
committed
fixed MatrixOp, Eye constructor from BlockVector
1 parent 85a9331 commit dddb844

File tree

12 files changed

+157
-21
lines changed

12 files changed

+157
-21
lines changed

src/AbstractOperators.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ include("linearoperators/Filt.jl")
4242
include("linearoperators/MIMOFilt.jl")
4343
include("linearoperators/Xcorr.jl")
4444
include("linearoperators/LBFGS.jl")
45-
# include("linearoperators/BlkDiagLBFGS.jl")
4645

4746
# Calculus rules
4847

src/calculus/DCAT.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,11 @@ function codomainType(H::DCAT)
181181
return (codomain...)
182182
end
183183

184+
is_eye(L::DCAT) = all(is_eye.(L.A))
184185
is_linear(L::DCAT) = all(is_linear.(L.A))
185186
is_diagonal(L::DCAT) = all(is_diagonal.(L.A))
186187
is_AcA_diagonal(L::DCAT) = all(is_AcA_diagonal.(L.A))
187-
is_Ac_diagonal(L::DCAT) = all(is_Ac_diagonal.(L.A))
188+
is_AAc_diagonal(L::DCAT) = all(is_AAc_diagonal.(L.A))
188189
is_orthogonal(L::DCAT) = all(is_orthogonal.(L.A))
189190
is_invertible(L::DCAT) = all(is_invertible.(L.A))
190191
is_full_row_rank(L::DCAT) = all(is_full_row_rank.(L.A))
@@ -208,3 +209,10 @@ function permute(H::DCAT{N,L,P1,P2}, p::AbstractVector{Int}) where {N,L,P1,P2}
208209

209210
DCAT(H.A,new_part,H.idxC)
210211
end
212+
213+
# special cases
214+
# Eye constructor
215+
Eye(x::A) where {N, A <: NTuple{N,AbstractArray}} = DCAT(Eye.(x)...)
216+
diag(L::DCAT{N,NTuple{N,E}}) where {N, E <: Eye} = 1.
217+
diag_AAc(L::DCAT{N,NTuple{N,E}}) where {N, E <: Eye} = 1.
218+
diag_AcA(L::DCAT{N,NTuple{N,E}}) where {N, E <: Eye} = 1.

src/linearoperators/Eye.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Eye(DomainType::Type, DomainDim::NTuple{N,Int}) where {N} = Eye{DomainType,N}(Do
3333
Eye(t::Type, dims::Vararg{Integer}) = Eye(t,dims)
3434
Eye(dims::NTuple{N, Integer}) where {N} = Eye(Float64,dims)
3535
Eye(dims::Vararg{Integer}) = Eye(Float64,dims)
36+
Eye(x::A) where {A <: AbstractArray} = Eye(eltype(x), size(x))
3637

3738
# Mappings
3839

src/linearoperators/LBFGS.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export LBFGS, update!
1+
export LBFGS, update!, reset!
22

33
"""
44
`LBFGS(domainType::Type,dim_in::Tuple, M::Integer)`
@@ -18,6 +18,9 @@ julia> update!(L,x,x_prev,grad,grad_prev); # update memory
1818
julia> d = L*grad; # compute new direction
1919
2020
```
21+
22+
Use `reset!(L)` to cancel the memory of the operator.
23+
2124
"""
2225

2326
mutable struct LBFGS{R, T <: BlockArray, M, I <: Integer} <: LinearOperator
@@ -81,6 +84,17 @@ function update!(L::LBFGS{R, T, M, I}, x::T, x_prev::T, gradx::T, gradx_prev::T)
8184
return L
8285
end
8386

87+
"""
88+
`reset!(L::LBFGS)`
89+
90+
Cancels the memory of `L`.
91+
"""
92+
93+
function reset!(L::LBFGS)
94+
L.currmem, L.curridx = 0, 0
95+
L.H = 1.0
96+
end
97+
8498
# LBFGS operators are symmetric
8599

86100
Ac_mul_B!(x::T, L::LBFGS{R, T, M, I}, y::T) where {R, T, M, I} = A_mul_B!(x, L, y)

src/linearoperators/MatrixOp.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ julia> MatrixOp(randn(20,10),4)
2828
2929
"""
3030

31-
struct MatrixOp{T, M <: AbstractMatrix{T}} <: LinearOperator
31+
struct MatrixOp{D, T, M <: AbstractMatrix{T}} <: LinearOperator
3232
A::M
3333
n_col_in::Integer
3434
end
@@ -37,35 +37,43 @@ end
3737

3838
##TODO decide what to do when domainType is given, with conversion one loses pointer to data...
3939
###standard constructor Operator{N}(DomainType::Type, DomainDim::NTuple{N,Int})
40-
function MatrixOp(DomainType::Type, DomainDim::NTuple{N,Int}, A::M) where {N, M <: AbstractMatrix}
40+
function MatrixOp(DomainType::Type, DomainDim::NTuple{N,Int}, A::M) where {N, T, M <: AbstractMatrix{T}}
4141
N > 2 && error("cannot multiply a Matrix by a n-dimensional Variable with n > 2")
4242
size(A,2) != DomainDim[1] && error("wrong input dimensions")
4343
if N == 1
44-
MatrixOp{DomainType, M}(A, 1)
44+
MatrixOp{DomainType, T, M}(A, 1)
4545
else
46-
MatrixOp{DomainType, M}(A, DomainDim[2])
46+
MatrixOp{DomainType, T, M}(A, DomainDim[2])
4747
end
4848
end
4949
###
5050

51-
MatrixOp(A::M) where {M <: AbstractMatrix} = MatrixOp{eltype(A), M}(A, 1)
52-
MatrixOp(T::Type, A::M) where {M <: AbstractMatrix} = MatrixOp{T, M}(A, 1)
53-
MatrixOp(A::M, n::Integer) where {M <: AbstractMatrix} = MatrixOp{eltype(A), M}(A, n)
54-
MatrixOp(T::Type, A::M, n::Integer) where {M <: AbstractMatrix} = MatrixOp{T, M}(A, n)
51+
MatrixOp(A::M) where {M <: AbstractMatrix} = MatrixOp(eltype(A), (size(A,2),), A)
52+
MatrixOp(D::Type, A::M) where {M <: AbstractMatrix} = MatrixOp(D, (size(A,2),), A)
53+
MatrixOp(A::M, n::Integer) where {M <: AbstractMatrix} = MatrixOp(eltype(A), (size(A,2), n), A)
54+
MatrixOp(D::Type, A::M, n::Integer) where {M <: AbstractMatrix} = MatrixOp(D, (size(A,2), n), A)
5555

5656
import Base: convert
57-
convert(::Type{LinearOperator}, L::M) where {T,M<:AbstractMatrix{T}} = MatrixOp{T,M}(L,1)
58-
convert(::Type{LinearOperator}, L::M, n::Integer) where {T,M<:AbstractMatrix{T}} = MatrixOp{T,M}(L, n)
57+
convert(::Type{LinearOperator}, L::M) where {T, M<:AbstractMatrix{T}} = MatrixOp{T, T, M}(L,1)
58+
convert(::Type{LinearOperator}, L::M, n::Integer) where {T, M<:AbstractMatrix{T}} = MatrixOp{T, T, M}(L, n)
59+
convert(::Type{LinearOperator}, dom::Type, dim_in::Tuple, L::AbstractMatrix) = MatrixOp(dom, dim_in, L)
5960

6061
# Mappings
6162

62-
A_mul_B!(y::AbstractArray, L::MatrixOp{M, T}, b::AbstractArray) where {M, T} = A_mul_B!(y, L.A, b)
63-
Ac_mul_B!(y::AbstractArray, L::MatrixOp{M, T}, b::AbstractArray) where {M, T} = Ac_mul_B!(y, L.A, b)
63+
A_mul_B!(y::AbstractArray, L::MatrixOp{D, T, M}, b::AbstractArray) where {D, T, M} = A_mul_B!(y, L.A, b)
64+
Ac_mul_B!(y::AbstractArray, L::MatrixOp{D, T, M}, b::AbstractArray) where {D, T, M} = Ac_mul_B!(y, L.A, b)
65+
66+
# Special Case, real b, complex matrix
67+
function Ac_mul_B!(y::AbstractArray, L::MatrixOp{D, T}, b::AbstractArray) where {D <: Real , T <: Complex}
68+
yc = zeros(T,size(y))
69+
Ac_mul_B!(yc, L.A, b)
70+
y .= real.(yc)
71+
end
6472

6573
# Properties
6674

67-
domainType(L::MatrixOp{T, M}) where {T, M} = T
68-
codomainType(L::MatrixOp{T, M}) where {T, M} = T
75+
domainType(L::MatrixOp{D, T}) where {D, T} = D
76+
codomainType(L::MatrixOp{D, T}) where {D, T} = D <: Real && T <: Complex ? T : D
6977

7078
function size(L::MatrixOp)
7179
if L.n_col_in == 1

src/properties.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ julia> codomainType(vcat(Eye(Complex{Float64},(10,)),DFT(Complex{Float64},10)))
4848
"""
4949
codomainType
5050

51-
5251
"""
5352
`size(A::AbstractOperator, [dom,])`
5453
@@ -254,6 +253,15 @@ false
254253
"""
255254
is_full_column_rank(L::AbstractOperator) = false
256255

256+
257+
import Base: convert
258+
function convert(::Type{T}, dom::Type, dim_in::Tuple, L::T) where {T <: AbstractOperator}
259+
domainType(L) != dom && error("cannot convert operator with domain $(domainType(L)) to operator with domain $dom ")
260+
size(L,1) != dim_in && error("cannot convert operator with size $(size(L,1)) to operator with domain $dim_in ")
261+
return L
262+
end
263+
264+
257265
#printing
258266
function Base.show(io::IO, L::AbstractOperator)
259267
print(io, fun_name(L)" "*fun_space(L) )

src/utilities/block.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export RealOrComplex,
1414
blockset!,
1515
blockvecdot,
1616
blockzeros,
17+
blockones,
1718
blockaxpy!,
1819
blockiszero
1920

@@ -53,8 +54,8 @@ blockcopy!(y::AbstractArray, x::AbstractArray) = copy!(y, x)
5354
blockset!(y::Tuple, x) = blockset!.(y, x)
5455
blockset!(y::AbstractArray, x) = (y .= x)
5556

56-
blockvecdot(x::T, y::T) where {T <: Tuple} = sum(blockvecdot.(x,y))
57-
blockvecdot(x::AbstractArray{R}, y::AbstractArray{R}) where {R <: Number} = real(vecdot(x, y))
57+
blockvecdot(x::T1, y::T2) where {T1 <: Tuple, T2 <: Tuple} = sum(blockvecdot.(x,y))
58+
blockvecdot(x::AbstractArray{R1}, y::AbstractArray{R2}) where {R1 <: Number, R2 <: Number} = real(vecdot(x, y))
5859
# inner product must be always real see section 4.2 of TFOCS manual
5960

6061
blockzeros(t::Tuple, s::Tuple) = blockzeros.(t, s)
@@ -64,6 +65,13 @@ blockzeros(n::NTuple{N, Integer} where {N}) = zeros(n)
6465
blockzeros(n::Integer) = zeros(n)
6566
blockzeros(a::AbstractArray) = zeros(a)
6667

68+
blockones(t::Tuple, s::Tuple) = blockones.(t, s)
69+
blockones(t::Type, n::NTuple{N, Integer} where {N}) = ones(t, n)
70+
blockones(t::Tuple) = blockones.(t)
71+
blockones(n::NTuple{N, Integer} where {N}) = ones(n)
72+
blockones(n::Integer) = ones(n)
73+
blockones(a::AbstractArray) = ones(a)
74+
6775
blockaxpy!(z::Tuple, x, alpha::Real, y::Tuple) = blockaxpy!.(z, x, alpha, y)
6876
blockaxpy!(z::AbstractArray, x, alpha::Real, y::AbstractArray) = (z .= x .+ alpha.*y)
6977

test/test_block.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ yb = blockzeros(blockeltype(xb),blocksize(xb))
6666
@test y == zeros(2)
6767
@test yb == (zeros(2),zeros(3)+im*zeros(3),zeros(2,3))
6868

69+
y = blockones(x)
70+
yb = blockones(xb)
71+
@test y == ones(2)
72+
@test yb == (ones(2),ones(3)+im*zeros(3),ones(2,3))
73+
74+
y = blockones(blocksize(x))
75+
yb = blockones(blocksize(xb))
76+
@test y == ones(2)
77+
@test yb == (ones(2),ones(3)+im*zeros(3),ones(2,3))
78+
79+
y = blockones(blockeltype(x),blocksize(x))
80+
yb = blockones(blockeltype(xb),blocksize(xb))
81+
@test y == ones(2)
82+
@test yb == (ones(2),ones(3)+im*zeros(3),ones(2,3))
83+
6984
blockaxpy!(y,x,2,x2)
7085
blockaxpy!(yb,xb,2,x2b)
7186

test/test_lbfgs.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,17 @@ for i = 1:5
7676

7777
end
7878

79+
@test blockones(size(H,1)) != H*blockones(size(H,1))
80+
@test blockones(size(HH,1)) != HH*blockones(size(HH,1))
81+
82+
#testing reset
83+
84+
reset!(H)
85+
reset!(HH)
86+
87+
@test blockones(size(H,1)) == H*blockones(size(H,1))
88+
@test blockones(size(HH,1)) == HH*blockones(size(HH,1))
89+
7990
end
8091

8192
test_lbfgs()

test/test_linear_operators.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ y1 = test_op(op, x1, randn(n), verb)
282282
op = Eye(Float64, (n,))
283283
op = Eye((n,))
284284
op = Eye(n)
285+
op = Eye(x1)
285286

286287
#properties
287288
@test is_linear(op) == true
@@ -472,21 +473,47 @@ op = GetIndex(Float64,(n,),(1:k,))
472473

473474
######## MatrixOp ############
474475

476+
# real matrix, real input
475477
n,m = 5,4
476478
A = randn(n,m)
477479
op = MatrixOp(Float64,(m,),A)
478480
x1 = randn(m)
479481
y1 = test_op(op, x1, randn(n), verb)
480482
y2 = A*x1
481483

484+
# real matrix, complex input
485+
n,m = 5,4
486+
A = randn(n,m)
487+
op = MatrixOp(Complex{Float64},(m,),A)
488+
x1 = randn(m)+im.*randn(m)
489+
y1 = test_op(op, x1, randn(n)+im*randn(n), verb)
490+
y2 = A*x1
491+
492+
# complex matrix, complex input
493+
n,m = 5,4
494+
A = randn(n,m)+im*randn(n,m)
495+
op = MatrixOp(Complex{Float64},(m,),A)
496+
x1 = randn(m)+im.*randn(m)
497+
y1 = test_op(op, x1, randn(n)+im*randn(n), verb)
498+
y2 = A*x1
499+
500+
# complex matrix, real input
501+
n,m = 5,4
502+
A = randn(n,m)+im*randn(n,m)
503+
op = MatrixOp(Float64,(m,),A)
504+
x1 = randn(m)
505+
y1 = test_op(op, x1, randn(n)+im*randn(n), verb)
506+
y2 = A*x1
507+
482508
@test all(vecnorm.(y1 .- y2) .<= 1e-12)
483509

510+
# complex matrix, real matrix input
484511
c = 3
485512
op = MatrixOp(Float64,(m,c),A)
486513
@test_throws ErrorException op = MatrixOp(Float64,(m,c,3),A)
487514
@test_throws MethodError op = MatrixOp(Float64,(m,c),randn(n,m,2))
488515
x1 = randn(m,c)
489-
y1 = test_op(op, x1, randn(n,c), verb)
516+
y1 = test_op(op, x1, randn(n,c).+randn(n,c), verb)
490517
y2 = A*x1
491518

492519
# other constructors
@@ -497,6 +524,7 @@ op = MatrixOp(Float64, A, c)
497524

498525
op = convert(LinearOperator,A)
499526
op = convert(LinearOperator,A,c)
527+
op = convert(LinearOperator, Complex{Float64}, size(x1), A)
500528

501529
##properties
502530
@test is_linear(op) == true

0 commit comments

Comments
 (0)