Skip to content

Commit 6adb8c9

Browse files
committed
LBFGS should be fixed
1 parent 48d14b8 commit 6adb8c9

File tree

8 files changed

+79
-347
lines changed

8 files changed

+79
-347
lines changed

src/AbstractOperators.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ __precompile__()
22

33
module AbstractOperators
44

5-
const RealOrComplex{T<:Real} = Union{T, Complex{T}}
5+
using BlockArrays
66

77
abstract type AbstractOperator end
88

@@ -17,7 +17,6 @@ export LinearOperator,
1717

1818
# Block stuff
1919

20-
include("utilities/block.jl")
2120
include("utilities/deep.jl") # TODO: remove this eventually
2221

2322
# Predicates and properties
@@ -68,7 +67,4 @@ include("nonlinearoperators/SoftMax.jl")
6867
# Syntax
6968
include("syntax.jl")
7069

71-
72-
73-
7470
end

src/linearoperators/BlkDiagLBFGS.jl

Lines changed: 0 additions & 128 deletions
This file was deleted.

src/linearoperators/LBFGS.jl

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,122 +1,101 @@
11
export LBFGS, update!
22

33
"""
4-
`LBFGS(T::Type, dim::Tuple, Memory::Int)`
4+
`LBFGS(x::Tuple, M::Integer)`
55
6-
`LBFGS{N}(T::NTuple{N,Type}, dim::NTuple{N,Tuple}, M::Int)`
7-
8-
`LBFGS(x::AbstractArray, Memory::Int)`
6+
`LBFGS(x::AbstractArray, M::Integer)`
97
108
Construct a Limited-Memory BFGS `LinearOperator` with memory `M`. The memory of `LBFGS` can be updated using the function `update!`, where the current iteration variable and gradient (`x`, `grad`) and the previous ones (`x_prev` and `grad_prev`) are needed:
119
1210
```
1311
julia> L = LBFGS(Float64,(4,),5)
1412
LBFGS ℝ^4 -> ℝ^4
1513
16-
julia> update!(L,x,x_prev,grad,grad_prev); #update memory
14+
julia> update!(L,x,x_prev,grad,grad_prev); # update memory
1715
18-
julia> d = L*x; #compute new direction
16+
julia> d = L*grad; # compute new direction
1917
2018
```
2119
"""
2220

2321
mutable struct LBFGS{R, T <: BlockArray, M} <: LinearOperator
24-
currmem::Int
25-
curridx::Int
22+
currmem::Integer
23+
curridx::Integer
2624
s::T
2725
y::T
28-
s_m::NTuple{M, T}
29-
y_m::NTuple{M, T}
30-
ys_m::Array{R, 1}
26+
s_M::Array{T, 1}
27+
y_M::Array{T, 1}
28+
ys_M::Array{R, 1}
3129
alphas::Array{R, 1}
3230
H::R
3331
end
3432

3533
# Constructors
3634

37-
function LBFGS(T::Type, dim::NTuple{N,Int}, M::Int) where {N}
38-
s_m = tuple([deepzeros(T,dim) for i = 1:M]...)
39-
y_m = tuple([deepzeros(T,dim) for i = 1:M]...)
40-
s = deepzeros(T,dim)
41-
y = deepzeros(T,dim)
42-
R = real(T)
43-
ys_m = zeros(R, M)
44-
alphas = zeros(R, M)
45-
LBFGS{M,N,R,T,typeof(s)}(0, 0, s, y, s_m, y_m, ys_m, alphas, one(R))
46-
end
47-
48-
function LBFGS(x::T, M::Int)
49-
35+
function LBFGS(x::T, M::Integer) where {R, T <: BlockArray{R}}
36+
s_M = [blockzeros(x) for i = 1:M]
37+
y_M = [blockzeros(x) for i = 1:M]
38+
s = blockzeros(x)
39+
y = blockzeros(x)
40+
ys_M = zeros(M)
41+
alphas = zeros(M)
42+
LBFGS{R, T, M}(0, 0, s, y, s_M, y_M, ys_M, alphas, one(R))
5043
end
5144

5245
"""
5346
`update!(L::LBFGS, x, x_prex, grad, grad_prev)`
5447
55-
See `LBFGS` documentation.
56-
48+
See the documentation for `LBFGS`.
5749
"""
5850

59-
function update!(L::LBFGS{M,N,R,T,A},
60-
x::A,
61-
x_prev::A,
62-
gradx::A,
63-
gradx_prev::A) where {M,N,R,T,A}
64-
65-
ys = update_s_y(L,x,x_prev,gradx,gradx_prev)
66-
51+
function update!(L::LBFGS{R, T, M}, x::T, x_prev::T, gradx::T, gradx_prev::T) where {R, T, M}
52+
L.s .= x .- x_prev
53+
L.y .= gradx .- gradx_prev
54+
ys = real(blockvecdot(L.s, L.y))
6755
if ys > 0
6856
L.curridx += 1
6957
if L.curridx > M L.curridx = 1 end
7058
L.currmem += 1
7159
if L.currmem > M L.currmem = M end
72-
73-
74-
yty = update_s_m_y_m(L,L.curridx)
75-
L.ys_m[L.curridx] = ys
60+
L.ys_M[L.curridx] = ys
61+
blockcopy!(L.s_M[L.curridx], L.s)
62+
blockcopy!(L.y_M[L.curridx], L.y)
63+
yty = real(vecdot(L.y, L.y))
7664
L.H = ys/yty
7765
end
7866
return L
7967
end
8068

81-
function update_s_y(L::LBFGS{M,N,R,T,A}, x::A, x_prev::A, gradx::A, gradx_prev::A) where {M,N,R,T,A}
82-
L.s .= (-).(x, x_prev)
83-
L.y .= (-).(gradx, gradx_prev)
84-
ys = real(vecdot(L.s,L.y))
85-
return ys
86-
end
69+
# LBFGS operators are symmetric
8770

88-
function update_s_m_y_m(L::LBFGS{M,N,R,T,A}, curridx::Int) where {M,N,R,T,A}
89-
L.s_m[curridx] .= L.s
90-
L.y_m[curridx] .= L.y
71+
Ac_mul_B!(x::T, L::LBFGS{R, T, M}, y::T) where {R, T, M} = A_mul_B!(x, L, y)
9172

92-
yty = real(vecdot(L.y,L.y))
93-
return yty
94-
end
73+
# Two-loop recursion
9574

96-
function A_mul_B!(d::A, L::LBFGS{M,N,R,T,A}, gradx::A) where {M,N,R,T,A}
97-
d .= (-).(gradx)
75+
function A_mul_B!(d::T, L::LBFGS{R, T, M}, gradx::T) where {R, T, M}
76+
d .= gradx
9877
idx = loop1!(d,L)
9978
d .= (*).(L.H, d)
10079
d = loop2!(d,idx,L)
10180
end
10281

103-
function loop1!(d::A, L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A}
82+
function loop1!(d::T, L::LBFGS{R, T, M}) where {R, T, M}
10483
idx = L.curridx
105-
for i=1:L.currmem
106-
L.alphas[idx] = real(vecdot(L.s_m[idx], d))/L.ys_m[idx]
107-
d .-= L.alphas[idx].*L.y_m[idx]
84+
for i = 1:L.currmem
85+
L.alphas[idx] = real(vecdot(L.s_M[idx], d))/L.ys_M[idx]
86+
d .-= L.alphas[idx] .* L.y_M[idx]
10887
idx -= 1
10988
if idx == 0 idx = M end
11089
end
11190
return idx
11291
end
11392

114-
function loop2!(d::A, idx::Int, L::LBFGS{M,N,R,T,A}) where {M,N,R,T,A}
115-
for i=1:L.currmem
93+
function loop2!(d::T, idx::Int, L::LBFGS{R, T, M}) where {R, T, M}
94+
for i = 1:L.currmem
11695
idx += 1
11796
if idx > M idx = 1 end
118-
beta = real(vecdot(L.y_m[idx], d))/L.ys_m[idx]
119-
d .+= (L.alphas[idx].-beta).*L.s_m[idx]
97+
beta = real(vecdot(L.y_M[idx], d))/L.ys_M[idx]
98+
d .+= (L.alphas[idx] - beta) .* L.s_M[idx]
12099
end
121100
return d
122101
end
@@ -125,6 +104,6 @@ end
125104
domainType(L::LBFGS{R, T, M}) where {R, T, M} = T
126105
codomainType(L::LBFGS{R, T, M}) where {R, T, M} = T
127106

128-
size(A::LBFGS) = (size(A.s), size(A.s))
107+
size(A::LBFGS) = (blocksize(A.s), blocksize(A.s))
129108

130109
fun_name(A::LBFGS) = "LBFGS"

0 commit comments

Comments
 (0)