|
1 | | - |
2 | | -mutable struct BlkDiagLBFGS{M, N, R<:Real, A <:NTuple{N,AbstractArray}, B <:NTuple{M,A}} <: LinearOperator |
3 | | - currmem::Int |
4 | | - curridx::Int |
5 | | - s::A |
6 | | - y::A |
7 | | - s_m::B |
8 | | - y_m::B |
9 | | - ys_m::Array{Float64,1} |
10 | | - alphas::Array{Float64,1} |
11 | | - H::R |
12 | | -end |
13 | | - |
14 | | -#constructors |
15 | | -#default |
16 | | -function LBFGS(T::NTuple{N,Type}, dim::NTuple{N,NTuple}, M::Int) where {N} |
17 | | - s_m = tuple([deepzeros(T,dim) for i = 1:M]...) |
18 | | - y_m = tuple([deepzeros(T,dim) for i = 1:M]...) |
19 | | - s = deepzeros(T,dim) |
20 | | - y = deepzeros(T,dim) |
21 | | - R = real(T[1]) |
22 | | - ys_m = zeros(R, M) |
23 | | - alphas = zeros(R, M) |
24 | | - BlkDiagLBFGS{M,N,R,typeof(s),typeof(s_m)}(0, 0, s, y, s_m, y_m, ys_m, alphas, 1.) |
25 | | -end |
26 | | - |
27 | | -LBFGS(x::NTuple{N,AbstractArray},M::Int64) where {N} = LBFGS(eltype.(x),size.(x),M) |
28 | | - |
29 | | -#mappings |
30 | | - |
31 | | -@generated function update!(L::BlkDiagLBFGS{M,N,R,A,B}, |
32 | | - x::A, |
33 | | - x_prev::A, |
34 | | - gradx::A, |
35 | | - gradx_prev::A) where {M,N,R,A,B} |
36 | | - |
37 | | - ex = :(ys = 0.) |
38 | | - for i = 1:N |
39 | | - ex = :($ex; ys += update_s_y(L.s[$i],L.y[$i],x[$i],x_prev[$i],gradx[$i],gradx_prev[$i])) |
40 | | - end |
41 | | - ex2 = :(yty = 0.) |
42 | | - for i = 1:N |
43 | | - ex2 = :($ex2; yty += update_s_m_y_m(L.s_m[L.curridx][$i],L.y_m[L.curridx][$i], |
44 | | - L.s[$i],L.y[$i] )) |
45 | | - end |
46 | | - |
47 | | - ex = quote |
48 | | - $ex |
49 | | - if ys > 0 |
50 | | - L.curridx += 1 |
51 | | - if L.curridx > M L.curridx = 1 end |
52 | | - L.currmem += 1 |
53 | | - if L.currmem > M L.currmem = M end |
54 | | - |
55 | | - $ex2 |
56 | | - L.ys_m[L.curridx] = ys |
57 | | - L.H = ys/sum(yty) |
58 | | - end |
59 | | - return L |
60 | | - end |
61 | | -end |
62 | | - |
63 | | -function update_s_y(s::A, y::A, x::A, x_prev::A, gradx::A, gradx_prev::A) where {A} |
64 | | - s .= (-).(x, x_prev) |
65 | | - y .= (-).(gradx, gradx_prev) |
66 | | - ys = real(vecdot(s,y)) |
67 | | - return ys |
68 | | -end |
69 | | - |
70 | | -function update_s_m_y_m(s_m::A,y_m::A,s::A,y::A) where {A} |
71 | | - s_m .= s |
72 | | - y_m .= y |
73 | | - |
74 | | - yty = real(vecdot(y,y)) |
75 | | - return yty |
76 | | -end |
77 | | - |
78 | | -function A_mul_B!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}, gradx::A) where {M, N, R, A, B} |
79 | | - for i = 1:N |
80 | | - d[i] .= (-).(gradx[i]) |
81 | | - end |
82 | | - idx = loop1!(d,L) |
83 | | - for i = 1:N |
84 | | - d[i] .= (*).(L.H, d[i]) |
85 | | - end |
86 | | - d = loop2!(d,idx,L) |
87 | | -end |
88 | | - |
89 | | -function loop1!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B} |
90 | | - idx = L.curridx |
91 | | - for i=1:L.currmem |
92 | | - L.alphas[idx] = sum(real.(vecdot.(L.s_m[idx], d))) |
93 | | - |
94 | | - L.alphas[idx] /= L.ys_m[idx] |
95 | | - for ii = 1:N |
96 | | - d[ii] .-= L.alphas[idx].*L.y_m[idx][ii] |
97 | | - end |
98 | | - idx -= 1 |
99 | | - if idx == 0 idx = M end |
100 | | - end |
101 | | - return idx |
102 | | -end |
103 | | - |
104 | | -function loop2!(d::A, idx::Int, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B} |
105 | | - for i=1:L.currmem |
106 | | - idx += 1 |
107 | | - if idx > M idx = 1 end |
108 | | - beta = sum(real.(vecdot.(L.y_m[idx], d))) |
109 | | - beta /= L.ys_m[idx] |
110 | | - for ii = 1:N |
111 | | - d[ii] .-= (beta-L.alphas[idx]).*L.s_m[idx][ii] |
112 | | - end |
113 | | - end |
114 | | - return d |
115 | | -end |
116 | | - |
117 | | -function reset(L::BlkDiagLBFGS) |
118 | | - L.currmem = 0 |
119 | | - L.curridx = 0 |
120 | | -end |
121 | | - |
122 | | -# Properties |
123 | | - |
124 | | -domainType(L::BlkDiagLBFGS) = eltype.(L.s) |
125 | | -codomainType(L::BlkDiagLBFGS) = eltype.(L.s) |
126 | | - |
127 | | -fun_name(A::BlkDiagLBFGS) = "LBFGS" |
128 | | -size(A::BlkDiagLBFGS) = (size.(A.s), size.(A.s)) |
| 1 | +# |
| 2 | +# mutable struct BlkDiagLBFGS{M, N, R<:Real, A <:NTuple{N,AbstractArray}, B <:NTuple{M,A}} <: LinearOperator |
| 3 | +# currmem::Int |
| 4 | +# curridx::Int |
| 5 | +# s::A |
| 6 | +# y::A |
| 7 | +# s_m::B |
| 8 | +# y_m::B |
| 9 | +# ys_m::Array{Float64,1} |
| 10 | +# alphas::Array{Float64,1} |
| 11 | +# H::R |
| 12 | +# end |
| 13 | +# |
| 14 | +# #constructors |
| 15 | +# #default |
| 16 | +# function LBFGS(T::NTuple{N,Type}, dim::NTuple{N,NTuple}, M::Int) where {N} |
| 17 | +# s_m = tuple([deepzeros(T,dim) for i = 1:M]...) |
| 18 | +# y_m = tuple([deepzeros(T,dim) for i = 1:M]...) |
| 19 | +# s = deepzeros(T,dim) |
| 20 | +# y = deepzeros(T,dim) |
| 21 | +# R = real(T[1]) |
| 22 | +# ys_m = zeros(R, M) |
| 23 | +# alphas = zeros(R, M) |
| 24 | +# BlkDiagLBFGS{M,N,R,typeof(s),typeof(s_m)}(0, 0, s, y, s_m, y_m, ys_m, alphas, 1.) |
| 25 | +# end |
| 26 | +# |
| 27 | +# LBFGS(x::NTuple{N,AbstractArray},M::Int64) where {N} = LBFGS(eltype.(x),size.(x),M) |
| 28 | +# |
| 29 | +# #mappings |
| 30 | +# |
| 31 | +# @generated function update!(L::BlkDiagLBFGS{M,N,R,A,B}, |
| 32 | +# x::A, |
| 33 | +# x_prev::A, |
| 34 | +# gradx::A, |
| 35 | +# gradx_prev::A) where {M,N,R,A,B} |
| 36 | +# |
| 37 | +# ex = :(ys = 0.) |
| 38 | +# for i = 1:N |
| 39 | +# ex = :($ex; ys += update_s_y(L.s[$i],L.y[$i],x[$i],x_prev[$i],gradx[$i],gradx_prev[$i])) |
| 40 | +# end |
| 41 | +# ex2 = :(yty = 0.) |
| 42 | +# for i = 1:N |
| 43 | +# ex2 = :($ex2; yty += update_s_m_y_m(L.s_m[L.curridx][$i],L.y_m[L.curridx][$i], |
| 44 | +# L.s[$i],L.y[$i] )) |
| 45 | +# end |
| 46 | +# |
| 47 | +# ex = quote |
| 48 | +# $ex |
| 49 | +# if ys > 0 |
| 50 | +# L.curridx += 1 |
| 51 | +# if L.curridx > M L.curridx = 1 end |
| 52 | +# L.currmem += 1 |
| 53 | +# if L.currmem > M L.currmem = M end |
| 54 | +# |
| 55 | +# $ex2 |
| 56 | +# L.ys_m[L.curridx] = ys |
| 57 | +# L.H = ys/sum(yty) |
| 58 | +# end |
| 59 | +# return L |
| 60 | +# end |
| 61 | +# end |
| 62 | +# |
| 63 | +# function update_s_y(s::A, y::A, x::A, x_prev::A, gradx::A, gradx_prev::A) where {A} |
| 64 | +# s .= (-).(x, x_prev) |
| 65 | +# y .= (-).(gradx, gradx_prev) |
| 66 | +# ys = real(vecdot(s,y)) |
| 67 | +# return ys |
| 68 | +# end |
| 69 | +# |
| 70 | +# function update_s_m_y_m(s_m::A,y_m::A,s::A,y::A) where {A} |
| 71 | +# s_m .= s |
| 72 | +# y_m .= y |
| 73 | +# |
| 74 | +# yty = real(vecdot(y,y)) |
| 75 | +# return yty |
| 76 | +# end |
| 77 | +# |
| 78 | +# function A_mul_B!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}, gradx::A) where {M, N, R, A, B} |
| 79 | +# for i = 1:N |
| 80 | +# d[i] .= (-).(gradx[i]) |
| 81 | +# end |
| 82 | +# idx = loop1!(d,L) |
| 83 | +# for i = 1:N |
| 84 | +# d[i] .= (*).(L.H, d[i]) |
| 85 | +# end |
| 86 | +# d = loop2!(d,idx,L) |
| 87 | +# end |
| 88 | +# |
| 89 | +# function loop1!(d::A, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B} |
| 90 | +# idx = L.curridx |
| 91 | +# for i=1:L.currmem |
| 92 | +# L.alphas[idx] = sum(real.(vecdot.(L.s_m[idx], d))) |
| 93 | +# |
| 94 | +# L.alphas[idx] /= L.ys_m[idx] |
| 95 | +# for ii = 1:N |
| 96 | +# d[ii] .-= L.alphas[idx].*L.y_m[idx][ii] |
| 97 | +# end |
| 98 | +# idx -= 1 |
| 99 | +# if idx == 0 idx = M end |
| 100 | +# end |
| 101 | +# return idx |
| 102 | +# end |
| 103 | +# |
| 104 | +# function loop2!(d::A, idx::Int, L::BlkDiagLBFGS{M,N,R,A,B}) where {M, N, R, A, B} |
| 105 | +# for i=1:L.currmem |
| 106 | +# idx += 1 |
| 107 | +# if idx > M idx = 1 end |
| 108 | +# beta = sum(real.(vecdot.(L.y_m[idx], d))) |
| 109 | +# beta /= L.ys_m[idx] |
| 110 | +# for ii = 1:N |
| 111 | +# d[ii] .-= (beta-L.alphas[idx]).*L.s_m[idx][ii] |
| 112 | +# end |
| 113 | +# end |
| 114 | +# return d |
| 115 | +# end |
| 116 | +# |
| 117 | +# function reset(L::BlkDiagLBFGS) |
| 118 | +# L.currmem = 0 |
| 119 | +# L.curridx = 0 |
| 120 | +# end |
| 121 | +# |
| 122 | +# # Properties |
| 123 | +# |
| 124 | +# domainType(L::BlkDiagLBFGS) = eltype.(L.s) |
| 125 | +# codomainType(L::BlkDiagLBFGS) = eltype.(L.s) |
| 126 | +# |
| 127 | +# fun_name(A::BlkDiagLBFGS) = "LBFGS" |
| 128 | +# size(A::BlkDiagLBFGS) = (size.(A.s), size.(A.s)) |
0 commit comments