Skip to content

Commit 48d14b8

Browse files
committed
first commit
1 parent 1fc48ea commit 48d14b8

File tree

4 files changed

+189
-148
lines changed

4 files changed

+189
-148
lines changed

src/AbstractOperators.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@ abstract type AbstractOperator end
99
abstract type LinearOperator <: AbstractOperator end
1010
abstract type NonLinearOperator <: AbstractOperator end
1111

12-
import Base: A_mul_B!, Ac_mul_B!
12+
import Base: A_mul_B!, Ac_mul_B!
1313

1414
export LinearOperator,
1515
NonLinearOperator,
1616
AbstractOperator
1717

18-
# deep stuff
18+
# Block stuff
1919

20-
include("utilities/deep.jl")
20+
include("utilities/block.jl")
21+
include("utilities/deep.jl") # TODO: remove this eventually
22+
23+
# Predicates and properties
2124

22-
# predicates and properties
2325
include("properties.jl")
2426

2527
# Linear operators
@@ -43,7 +45,7 @@ include("linearoperators/Filt.jl")
4345
include("linearoperators/MIMOFilt.jl")
4446
include("linearoperators/Xcorr.jl")
4547
include("linearoperators/LBFGS.jl")
46-
include("linearoperators/BlkDiagLBFGS.jl")
48+
# include("linearoperators/BlkDiagLBFGS.jl")
4749

4850
# Calculus rules
4951

Lines changed: 128 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,128 @@
1-
2-
mutable struct BlkDiagLBFGS{M, N, R<:Real, A <:NTuple{N,AbstractArray}, B <:NTuple{M,A}} <: LinearOperator
3-
currmem::Int
4-
curridx::Int
5-
s::A
6-
y::A
7-
s_m::B
8-
y_m::B
9-
ys_m::Array{Float64,1}
10-
alphas::Array{Float64,1}
11-
H::R
12-
end
13-
14-
#constructors
15-
#default
16-
function LBFGS(T::NTuple{N,Type}, dim::NTuple{N,NTuple}, M::Int) where {N}
17-
s_m = tuple([deepzeros(T,dim) for i = 1:M]...)
18-
y_m = tuple([deepzeros(T,dim) for i = 1:M]...)
19-
s = deepzeros(T,dim)
20-
y = deepzeros(T,dim)
21-
R = real(T[1])
22-
ys_m = zeros(R, M)
23-
alphas = zeros(R, M)
24-
BlkDiagLBFGS{M,N,R,typeof(s),typeof(s_m)}(0, 0, s, y, s_m, y_m, ys_m, alphas, 1.)
25-
end
26-
27-
LBFGS(x::NTuple{N,AbstractArray},M::Int64) where {N} = LBFGS(eltype.(x),size.(x),M)
28-
29-
#mappings
30-
31-
@generated function update!(L::BlkDiagLBFGS{M,N,R,A,B},
32-
x::A,
33-
x_prev::A,
34-
gradx::A,
35-
gradx_prev::A) where {M,N,R,A,B}
36-
37-
ex = :(ys = 0.)
38-
for i = 1:N
39-
ex = :($ex; ys += update_s_y(L.s[$i],L.y[$i],x[$i],x_prev[$i],gradx[$i],gradx_prev[$i]))
40-
end
41-
ex2 = :(yty = 0.)
42-
for i = 1:N
43-
ex2 = :($ex2; yty += update_s_m_y_m(L.s_m[L.curridx][$i],L.y_m[L.curridx][$i],
44-
L.s[$i],L.y[$i] ))
45-
end
46-
47-
ex = quote
48-
$ex
49-
if ys > 0
50-
L.curridx += 1
51-
if L.curridx > M L.curridx = 1 end
52-
L.currmem += 1
53-
if L.currmem > M L.currmem = M end
54-
55-
$ex2
56-
L.ys_m[L.curridx] = ys
57-
L.H = ys/sum(yty)
58-
end
59-
return L
60-
end
61-
end
62-
63-
function update_s_y(s::A, y::A, x::A, x_prev::A, gradx::A, gradx_prev::A) where {A}
64-
s .= (-).(x, x_prev)
65-
y .= (-).(gradx, gradx_prev)
66-
ys = real(vecdot(s,y))
67-
return ys
68-
end
69-
70-
function update_s_m_y_m(s_m::A,y_m::A,s::A,y::A) where {A}
71-
s_m .= s
72-
y_m .= y
73-
74-
yty = real(vecdot(y,y))
75-
return yty
76-
end
77-
78-
function A_mul_B!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}, gradx::A) where {M, N, R, A, B}
79-
for i = 1:N
80-
d[i] .= (-).(gradx[i])
81-
end
82-
idx = loop1!(d,L)
83-
for i = 1:N
84-
d[i] .= (*).(L.H, d[i])
85-
end
86-
d = loop2!(d,idx,L)
87-
end
88-
89-
function loop1!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B}
90-
idx = L.curridx
91-
for i=1:L.currmem
92-
L.alphas[idx] = sum(real.(vecdot.(L.s_m[idx], d)))
93-
94-
L.alphas[idx] /= L.ys_m[idx]
95-
for ii = 1:N
96-
d[ii] .-= L.alphas[idx].*L.y_m[idx][ii]
97-
end
98-
idx -= 1
99-
if idx == 0 idx = M end
100-
end
101-
return idx
102-
end
103-
104-
function loop2!(d::A, idx::Int, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B}
105-
for i=1:L.currmem
106-
idx += 1
107-
if idx > M idx = 1 end
108-
beta = sum(real.(vecdot.(L.y_m[idx], d)))
109-
beta /= L.ys_m[idx]
110-
for ii = 1:N
111-
d[ii] .-= (beta-L.alphas[idx]).*L.s_m[idx][ii]
112-
end
113-
end
114-
return d
115-
end
116-
117-
function reset(L::BlkDiagLBFGS)
118-
L.currmem = 0
119-
L.curridx = 0
120-
end
121-
122-
# Properties
123-
124-
domainType(L::BlkDiagLBFGS) = eltype.(L.s)
125-
codomainType(L::BlkDiagLBFGS) = eltype.(L.s)
126-
127-
fun_name(A::BlkDiagLBFGS) = "LBFGS"
128-
size(A::BlkDiagLBFGS) = (size.(A.s), size.(A.s))
1+
#
2+
# mutable struct BlkDiagLBFGS{M, N, R<:Real, A <:NTuple{N,AbstractArray}, B <:NTuple{M,A}} <: LinearOperator
3+
# currmem::Int
4+
# curridx::Int
5+
# s::A
6+
# y::A
7+
# s_m::B
8+
# y_m::B
9+
# ys_m::Array{Float64,1}
10+
# alphas::Array{Float64,1}
11+
# H::R
12+
# end
13+
#
14+
# #constructors
15+
# #default
16+
# function LBFGS(T::NTuple{N,Type}, dim::NTuple{N,NTuple}, M::Int) where {N}
17+
# s_m = tuple([deepzeros(T,dim) for i = 1:M]...)
18+
# y_m = tuple([deepzeros(T,dim) for i = 1:M]...)
19+
# s = deepzeros(T,dim)
20+
# y = deepzeros(T,dim)
21+
# R = real(T[1])
22+
# ys_m = zeros(R, M)
23+
# alphas = zeros(R, M)
24+
# BlkDiagLBFGS{M,N,R,typeof(s),typeof(s_m)}(0, 0, s, y, s_m, y_m, ys_m, alphas, 1.)
25+
# end
26+
#
27+
# LBFGS(x::NTuple{N,AbstractArray},M::Int64) where {N} = LBFGS(eltype.(x),size.(x),M)
28+
#
29+
# #mappings
30+
#
31+
# @generated function update!(L::BlkDiagLBFGS{M,N,R,A,B},
32+
# x::A,
33+
# x_prev::A,
34+
# gradx::A,
35+
# gradx_prev::A) where {M,N,R,A,B}
36+
#
37+
# ex = :(ys = 0.)
38+
# for i = 1:N
39+
# ex = :($ex; ys += update_s_y(L.s[$i],L.y[$i],x[$i],x_prev[$i],gradx[$i],gradx_prev[$i]))
40+
# end
41+
# ex2 = :(yty = 0.)
42+
# for i = 1:N
43+
# ex2 = :($ex2; yty += update_s_m_y_m(L.s_m[L.curridx][$i],L.y_m[L.curridx][$i],
44+
# L.s[$i],L.y[$i] ))
45+
# end
46+
#
47+
# ex = quote
48+
# $ex
49+
# if ys > 0
50+
# L.curridx += 1
51+
# if L.curridx > M L.curridx = 1 end
52+
# L.currmem += 1
53+
# if L.currmem > M L.currmem = M end
54+
#
55+
# $ex2
56+
# L.ys_m[L.curridx] = ys
57+
# L.H = ys/sum(yty)
58+
# end
59+
# return L
60+
# end
61+
# end
62+
#
63+
# function update_s_y(s::A, y::A, x::A, x_prev::A, gradx::A, gradx_prev::A) where {A}
64+
# s .= (-).(x, x_prev)
65+
# y .= (-).(gradx, gradx_prev)
66+
# ys = real(vecdot(s,y))
67+
# return ys
68+
# end
69+
#
70+
# function update_s_m_y_m(s_m::A,y_m::A,s::A,y::A) where {A}
71+
# s_m .= s
72+
# y_m .= y
73+
#
74+
# yty = real(vecdot(y,y))
75+
# return yty
76+
# end
77+
#
78+
# function A_mul_B!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}, gradx::A) where {M, N, R, A, B}
79+
# for i = 1:N
80+
# d[i] .= (-).(gradx[i])
81+
# end
82+
# idx = loop1!(d,L)
83+
# for i = 1:N
84+
# d[i] .= (*).(L.H, d[i])
85+
# end
86+
# d = loop2!(d,idx,L)
87+
# end
88+
#
89+
# function loop1!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B}
90+
# idx = L.curridx
91+
# for i=1:L.currmem
92+
# L.alphas[idx] = sum(real.(vecdot.(L.s_m[idx], d)))
93+
#
94+
# L.alphas[idx] /= L.ys_m[idx]
95+
# for ii = 1:N
96+
# d[ii] .-= L.alphas[idx].*L.y_m[idx][ii]
97+
# end
98+
# idx -= 1
99+
# if idx == 0 idx = M end
100+
# end
101+
# return idx
102+
# end
103+
#
104+
# function loop2!(d::A, idx::Int, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B}
105+
# for i=1:L.currmem
106+
# idx += 1
107+
# if idx > M idx = 1 end
108+
# beta = sum(real.(vecdot.(L.y_m[idx], d)))
109+
# beta /= L.ys_m[idx]
110+
# for ii = 1:N
111+
# d[ii] .-= (beta-L.alphas[idx]).*L.s_m[idx][ii]
112+
# end
113+
# end
114+
# return d
115+
# end
116+
#
117+
# function reset(L::BlkDiagLBFGS)
118+
# L.currmem = 0
119+
# L.curridx = 0
120+
# end
121+
#
122+
# # Properties
123+
#
124+
# domainType(L::BlkDiagLBFGS) = eltype.(L.s)
125+
# codomainType(L::BlkDiagLBFGS) = eltype.(L.s)
126+
#
127+
# fun_name(A::BlkDiagLBFGS) = "LBFGS"
128+
# size(A::BlkDiagLBFGS) = (size.(A.s), size.(A.s))

src/linearoperators/LBFGS.jl

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
export LBFGS, update!
22

3-
# TODO make Ac_mul_B!
4-
# Edit: Ac_mul_B! is not really needed for this operator
5-
# Edit2: you never known! anyway for completeness would be cool to have it!
63
"""
74
`LBFGS(T::Type, dim::Tuple, Memory::Int)`
85
96
`LBFGS{N}(T::NTuple{N,Type}, dim::NTuple{N,Tuple}, M::Int)`
107
118
`LBFGS(x::AbstractArray, Memory::Int)`
129
13-
Construct a Limited-Memory BFGS `LinearOperator` with memory `M`. The memory of `LBFGS` can be updated using the function `update!`, where the current iteration variable and gradient (`x`, `grad`) and the previous ones (`x_prev` and `grad_prev`) are needed:
10+
Construct a Limited-Memory BFGS `LinearOperator` with memory `M`. The memory of `LBFGS` can be updated using the function `update!`, where the current iteration variable and gradient (`x`, `grad`) and the previous ones (`x_prev` and `grad_prev`) are needed:
1411
1512
```
1613
julia> L = LBFGS(Float64,(4,),5)
@@ -21,23 +18,22 @@ julia> update!(L,x,x_prev,grad,grad_prev); #update memory
2118
julia> d = L*x; #compute new direction
2219
2320
```
24-
2521
"""
2622

27-
mutable struct LBFGS{M, N, R <: Real, T <: Union{R, Complex{R}}, A<:AbstractArray{T,N}} <: LinearOperator
23+
mutable struct LBFGS{R, T <: BlockArray, M} <: LinearOperator
2824
currmem::Int
2925
curridx::Int
30-
s::A
31-
y::A
32-
s_m::NTuple{M, A}
33-
y_m::NTuple{M, A}
26+
s::T
27+
y::T
28+
s_m::NTuple{M, T}
29+
y_m::NTuple{M, T}
3430
ys_m::Array{R, 1}
3531
alphas::Array{R, 1}
3632
H::R
3733
end
3834

3935
# Constructors
40-
#default
36+
4137
function LBFGS(T::Type, dim::NTuple{N,Int}, M::Int) where {N}
4238
s_m = tuple([deepzeros(T,dim) for i = 1:M]...)
4339
y_m = tuple([deepzeros(T,dim) for i = 1:M]...)
@@ -46,10 +42,12 @@ function LBFGS(T::Type, dim::NTuple{N,Int}, M::Int) where {N}
4642
R = real(T)
4743
ys_m = zeros(R, M)
4844
alphas = zeros(R, M)
49-
LBFGS{M,N,R,T,typeof(s)}(0, 0, s, y, s_m, y_m, ys_m, alphas, zero(R))
45+
LBFGS{M,N,R,T,typeof(s)}(0, 0, s, y, s_m, y_m, ys_m, alphas, one(R))
5046
end
5147

52-
LBFGS(x::AbstractArray,M::Int) = LBFGS(eltype(x),size(x),M)
48+
function LBFGS(x::T, M::Int)
49+
50+
end
5351

5452
"""
5553
`update!(L::LBFGS, x, x_prex, grad, grad_prev)`
@@ -124,8 +122,8 @@ function loop2!(d::A, idx::Int, L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A}
124122
end
125123

126124
# Properties
127-
domainType(L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A} = T
128-
codomainType(L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A} = T
125+
domainType(L::LBFGS{R, T, M}) where {R, T, M} = T
126+
codomainType(L::LBFGS{R, T, M}) where {R, T, M} = T
129127

130128
size(A::LBFGS) = (size(A.s), size(A.s))
131129

0 commit comments

Comments
 (0)