Skip to content

Commit af0bf85

Browse files
authored
Merge pull request #3 from kul-forbes/new-lbfgs
Fixed L-BFGS, included new code for BlockArray
2 parents 72a8d6e + baff0de commit af0bf85

File tree

11 files changed

+240
-364
lines changed

11 files changed

+240
-364
lines changed

src/AbstractOperators.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@ __precompile__()
22

33
module AbstractOperators
44

5-
const RealOrComplex{T<:Real} = Union{T, Complex{T}}
6-
75
abstract type AbstractOperator end
86

97
abstract type LinearOperator <: AbstractOperator end
108
abstract type NonLinearOperator <: AbstractOperator end
119

12-
import Base: A_mul_B!, Ac_mul_B!
10+
import Base: A_mul_B!, Ac_mul_B!
1311

1412
export LinearOperator,
1513
NonLinearOperator,
1614
AbstractOperator
1715

18-
# deep stuff
16+
# Block stuff
17+
18+
include("utilities/block.jl")
19+
#include("utilities/deep.jl") # TODO: remove this eventually
1920

20-
include("utilities/deep.jl")
21+
# Predicates and properties
2122

22-
# predicates and properties
2323
include("properties.jl")
2424

2525
# Linear operators
@@ -43,7 +43,7 @@ include("linearoperators/Filt.jl")
4343
include("linearoperators/MIMOFilt.jl")
4444
include("linearoperators/Xcorr.jl")
4545
include("linearoperators/LBFGS.jl")
46-
include("linearoperators/BlkDiagLBFGS.jl")
46+
# include("linearoperators/BlkDiagLBFGS.jl")
4747

4848
# Calculus rules
4949

@@ -67,7 +67,4 @@ include("nonlinearoperators/SoftPlus.jl")
6767
# Syntax
6868
include("syntax.jl")
6969

70-
71-
72-
7370
end

src/linearoperators/BlkDiagLBFGS.jl

Lines changed: 0 additions & 128 deletions
This file was deleted.

src/linearoperators/LBFGS.jl

Lines changed: 60 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,132 +1,125 @@
11
export LBFGS, update!
22

3-
# TODO make Ac_mul_B!
4-
# Edit: Ac_mul_B! is not really needed for this operator
5-
# Edit2: you never known! anyway for completeness would be cool to have it!
"""
    LBFGS(domainType, dim_in::Tuple, M::Integer)
    LBFGS(dim_in::Tuple, M::Integer)
    LBFGS(x::BlockArray, M::Integer)

Construct a Limited-Memory BFGS `LinearOperator` with memory `M`.

The memory of `LBFGS` is updated with the function `update!`, which needs the current
iteration variable and gradient (`x`, `grad`) together with the previous ones
(`x_prev`, `grad_prev`). When `domainType` is omitted, `Float64` is used.

```
julia> L = LBFGS(Float64,(4,),5)
LBFGS ℝ^4 -> ℝ^4

julia> update!(L,x,x_prev,grad,grad_prev); # update memory

julia> d = L*grad; # apply the operator to the gradient

```

# Type parameters
- `R`: real type of the stored scalars (`ys_M`, `alphas`, `H`).
- `T <: BlockArray`: type of the variables and gradients.
- `M`: memory length (number of stored `(s, y)` pairs).
- `I <: Integer`: integer type of the memory counters.
"""
mutable struct LBFGS{R, T <: BlockArray, M, I <: Integer} <: LinearOperator
    currmem::I            # number of pairs currently stored (saturates at M)
    curridx::I            # slot of the most recent pair (circular index, 0 when empty)
    s::T                  # work buffer: x - x_prev
    y::T                  # work buffer: grad - grad_prev
    s_M::Array{T, 1}      # stored `s` vectors, one per memory slot
    y_M::Array{T, 1}      # stored `y` vectors, one per memory slot
    ys_M::Array{R, 1}     # stored values of real(<s, y>) for each slot
    alphas::Array{R, 1}   # scratch coefficients shared by the two recursion loops
    H::R                  # scaling ys/yty applied between the two loops (starts at 1)
end
3834

39-
# Constructors
40-
#default
41-
function LBFGS(T::Type, dim::NTuple{N,Int}, M::Int) where {N}
42-
s_m = tuple([deepzeros(T,dim) for i = 1:M]...)
43-
y_m = tuple([deepzeros(T,dim) for i = 1:M]...)
44-
s = deepzeros(T,dim)
45-
y = deepzeros(T,dim)
46-
R = real(T)
47-
ys_m = zeros(R, M)
35+
#default constructor
36+
37+
# Default constructor: allocate zeroed work buffers and `M` zeroed memory slots for a
# variable described by `domainType` and `dim_in`, with an empty memory and `H = 1`.
# `domainType` may be a single type or a tuple of types; in the tuple case the real
# scalar type `R` is taken from its first entry.
function LBFGS(domainType, dim_in, M::I) where {I <: Integer}
    s = blockzeros(domainType, dim_in)
    y = blockzeros(domainType, dim_in)
    T = typeof(s)
    s_mem = [blockzeros(domainType, dim_in) for k = 1:M]
    y_mem = [blockzeros(domainType, dim_in) for k = 1:M]
    R = isa(domainType, Tuple) ? real(domainType[1]) : real(domainType)
    return LBFGS{R, T, M, I}(0, 0, s, y, s_mem, y_mem, zeros(R, M), zeros(R, M), one(R))
end
5148

52-
LBFGS(x::AbstractArray,M::Int) = LBFGS(eltype(x),size(x),M)
49+
# Convenience constructor defaulting the element type to `Float64`.
# If `dim_in` holds integers it describes a single array, so one `Float64` is used;
# otherwise it is treated as one size per block and a tuple with one `Float64` per
# entry of `dim_in` is built.
function LBFGS(dim_in, M::I) where {I <: Integer}
    # `ntuple` replaces the old `([...]...)` form, which is only a tuple on Julia <= 0.6
    # (it needs a trailing comma on 0.7 and later).
    domainType = eltype(dim_in) <: Integer ? Float64 :
                 ntuple(i -> Float64, length(dim_in))
    return LBFGS(domainType, dim_in, M)
end
53+
54+
# Build an `LBFGS` operator shaped like `x`: the element type(s) and size(s) are read
# from `x` with `blockeltype` and `blocksize`, and memory length is `M`.
LBFGS(x::T, M::I) where {T <: BlockArray, I <: Integer} =
    LBFGS(blockeltype(x), blocksize(x), M)
5359

"""
    update!(L::LBFGS, x, x_prev, gradx, gradx_prev)

Update the memory of `L` in place from the current point and gradient (`x`, `gradx`)
and the previous ones (`x_prev`, `gradx_prev`), and return `L`.

The differences `s = x - x_prev` and `y = gradx - gradx_prev` are always written to
`L.s` and `L.y`. The pair is stored in the memory, and the scaling `L.H` refreshed,
only when `real(<s, y>) > 0`; otherwise the stored memory is left untouched.

See the documentation for `LBFGS`.
"""
function update!(
    L::LBFGS{R, T, M, I}, x::T, x_prev::T, gradx::T, gradx_prev::T
) where {R, T, M, I}
    L.s .= x .- x_prev
    L.y .= gradx .- gradx_prev
    ys = real(blockvecdot(L.s, L.y))
    if ys > 0
        # Advance the circular write position, wrapping back to slot 1 after slot M.
        L.curridx += 1
        if L.curridx > M
            L.curridx = 1
        end
        # The number of valid pairs grows until the memory is full.
        L.currmem += 1
        if L.currmem > M
            L.currmem = M
        end
        L.ys_M[L.curridx] = ys
        blockcopy!(L.s_M[L.curridx], L.s)
        blockcopy!(L.y_M[L.curridx], L.y)
        yty = real(blockvecdot(L.y, L.y))
        L.H = ys/yty
    end
    return L
end
8283

83-
function update_s_y(L::LBFGS{M,N,R,T,A}, x::A, x_prev::A, gradx::A, gradx_prev::A) where {M,N,R,T,A}
84-
L.s .= (-).(x, x_prev)
85-
L.y .= (-).(gradx, gradx_prev)
86-
ys = real(vecdot(L.s,L.y))
87-
return ys
88-
end
84+
# LBFGS operators are symmetric
8985

90-
function update_s_m_y_m(L::LBFGS{M,N,R,T,A}, curridx::Int) where {M,N,R,T,A}
91-
L.s_m[curridx] .= L.s
92-
L.y_m[curridx] .= L.y
86+
# Adjoint product: delegates to `A_mul_B!`, i.e. the operator is applied unchanged.
function Ac_mul_B!(x::T, L::LBFGS{R, T, M, I}, y::T) where {R, T, M, I}
    return A_mul_B!(x, L, y)
end
9387

94-
yty = real(vecdot(L.y,L.y))
95-
return yty
96-
end
88+
# Two-loop recursion
9789

98-
function A_mul_B!(d::A, L::LBFGS{M,N,R,T,A}, gradx::A) where {M,N,R,T,A}
99-
d .= (-).(gradx)
90+
# Apply `L` to `gradx`, writing the result into `d` and returning it.
# Steps: copy `gradx` into `d`, run the first loop over the stored pairs, scale by
# `L.H`, then run the second loop starting from the index where the first one stopped.
# No sign change is applied to `gradx`.
function A_mul_B!(d::T, L::LBFGS{R, T, M, I}, gradx::T) where {R, T, M, I}
    d .= gradx
    start = loop1!(d, L)
    d .= L.H .* d
    return loop2!(d, start, L)
end
10496

105-
function loop1!(d::A, L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A}
97+
# First loop of the two-loop recursion. Starting at the most recent slot and walking
# backwards (circularly) over the `L.currmem` stored pairs, it stores
# `alpha = real(<s, d>)/ys` for each slot in `L.alphas` and subtracts `alpha * y`
# from `d`. Returns the index reached after the last step, which `loop2!` starts from.
function loop1!(d::T, L::LBFGS{R, T, M, I}) where {R, T, M, I}
    k = L.curridx
    for step = 1:L.currmem
        L.alphas[k] = real(blockvecdot(L.s_M[k], d))/L.ys_M[k]
        d .-= L.alphas[k] .* L.y_M[k]
        # Step back one slot, wrapping from slot 1 to slot M.
        k = k == 1 ? M : k - 1
    end
    return k
end
115107

116-
function loop2!(d::A, idx::Int, L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A}
117-
for i=1:L.currmem
108+
# Second loop of the two-loop recursion. Starting just after `idx` and walking forwards
# (circularly) over the `L.currmem` stored pairs, it computes `beta = real(<y, d>)/ys`
# and adds `(alpha - beta) * s` to `d`, using the `alpha` values left in `L.alphas`
# by `loop1!`. Returns `d`.
#
# `idx` accepts any `Integer`: the counters of `LBFGS` have type `I <: Integer`, and
# `loop1!` returns `L.curridx::I` when the memory is empty, so requiring `Int` here
# would raise a `MethodError` for operators built with e.g. an `Int32` memory length.
function loop2!(d::T, idx::Integer, L::LBFGS{R, T, M, I}) where {R, T, M, I}
    for i = 1:L.currmem
        idx += 1
        if idx > M
            idx = 1
        end
        beta = real(blockvecdot(L.y_M[idx], d))/L.ys_M[idx]
        d .+= (L.alphas[idx] - beta) .* L.s_M[idx]
    end
    return d
end
125117

126118
# Properties
127-
domainType(L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A} = T
128-
codomainType(L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A} = T
129119

130-
size(A::LBFGS) = (size(A.s), size(A.s))
120+
# Element type(s) of the operator's domain, read from the work buffer `L.s`. That
# buffer always exists, whereas `L.y_M[1]` is out of bounds when the memory length is 0.
domainType(L::LBFGS) = blockeltype(L.s)
121+
# Element type(s) of the operator's codomain, read from the work buffer `L.s`. That
# buffer always exists, whereas `L.y_M[1]` is out of bounds when the memory length is 0.
codomainType(L::LBFGS) = blockeltype(L.s)
122+
123+
# Operator size as `(codomain size, domain size)`; both equal the size of `A.s`.
function size(A::LBFGS)
    dims = blocksize(A.s)
    return (dims, dims)
end
131124

132125
# Name of the operator, as shown in the `LBFGS ℝ^4 -> ℝ^4` style summary.
function fun_name(A::LBFGS)
    return "LBFGS"
end

0 commit comments

Comments
 (0)