Skip to content

Commit 98e4c42

Browse files
author
Pawel Latawiec
committed
Use BlockDiagonal for E, D
1 parent a156a15 commit 98e4c42

File tree

1 file changed

+48
-26
lines changed

1 file changed

+48
-26
lines changed

src/lal.jl

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using Printf
22
import Base: iterate
33
import LinearAlgebra: UpperTriangular, UpperHessenberg
4+
import BlockDiagonals: BlockDiagonal, blocks
5+
import BlockDiagonals
46

57
"""
68
LookAheadLanczosDecompOptions
@@ -71,9 +73,9 @@ mutable struct LookAheadLanczosDecomp{OpT, OptT, VecT, MatT, ElT, ElRT}
7173
γ::Vector{ElRT}
7274

7375
# Eq. 2.13
74-
D::Matrix{ElT}
76+
D::BlockDiagonal{ElT, Matrix{ElT}}
7577
# Eq. 3.14
76-
E::Matrix{ElT}
78+
E::BlockDiagonal{ElT, Matrix{ElT}}
7779
# Defined after Eq. 5.1
7880
F::Matrix{ElT}
7981
F̃lastcol::Vector{ElT}
@@ -177,8 +179,8 @@ function LookAheadLanczosDecomp(
177179
γ = Vector{real(elT)}(undef, 1)
178180
γ[1] = 1.0
179181

180-
D = Matrix{elT}(undef, 0, 0)
181-
E = Matrix{elT}(undef, 0, 0)
182+
D = BlockDiagonal{elT, Matrix{elT}}(Vector{Matrix{elT}}())
183+
E = BlockDiagonal{elT, Matrix{elT}}(Vector{Matrix{elT}}())
182184
G = Vector{elT}()
183185
H = Vector{elT}()
184186

@@ -245,6 +247,35 @@ _VW_block_size(ld) = ld.n+1 - ld.nl[ld.l]
245247
_VW_prev_block_size(ld) = ld.nl[ld.l] - ld.nl[max(1, ld.l-1)]
246248
_is_block_small(ld, n) = n < ld.opts.max_block_size
247249

250+
"""
251+
_grow_last_block!(A, Bcol, Brow, Bcorner)
252+
253+
Grows the last block in-place in `A` by appending the column `Bcol`, the row `Brow`, and the corner element `Bcorner`. `Bcol` and `Brow` are automatically truncated to match the size of the grown block
254+
"""
255+
function _grow_last_block!(A::BlockDiagonal{T, TM}, Bcol, Brow, Bcorner) where {T, TM}
256+
n = BlockDiagonals.nblocks(A)
257+
b = BlockDiagonals.blocks(A)
258+
s = size(last(b), 1)
259+
b[n] = TM([
260+
b[n] Bcol[end-s+1:end]
261+
Brow[1:1, end-s+1:end] Bcorner
262+
])
263+
return A
264+
end
265+
266+
"""
267+
_start_new_block!(A, B)
268+
269+
Appends a new block to the end of `A` with `B`
270+
"""
271+
function _start_new_block!(A::BlockDiagonal{T, TM}, B) where {T, TM}
272+
push!(BlockDiagonals.blocks(A), TM(fill(only(B), 1, 1)))
273+
return A
274+
end
275+
276+
Base.size(B::BlockDiagonals.BlockDiagonal) = sum(firstsize, BlockDiagonals.blocks(B), init=0), sum(lastsize, BlockDiagonals.blocks(B), init=0)
277+
278+
248279
start(::LookAheadLanczosDecomp) = 1
249280
done(ld::LookAheadLanczosDecomp, iteration::Int) = iteration ld.opts.max_iter
250281
function iterate(ld::LookAheadLanczosDecomp, n::Int=start(ld))
@@ -404,16 +435,12 @@ function _update_D!(ld)
404435
# D[n, n] = wtv
405436

406437
# TODO: closed block
407-
block_start = ld.nl[ld.lstar]
408-
if isone(ld.n)
409-
ld.D = fill(ld.wtv, 1, 1)
438+
if isone(ld.n) || _VW_block_size(ld) == 1
439+
_start_new_block!(ld.D, ld.wtv)
410440
else
411441
D_lastcol = (ld.F[:, end] - (ld.D * ld.L[1:end-1, end])) / ld.ρ
412442
D_lastrow = transpose(D_lastcol * ld.γ[end] ./ ld.γ[1:end-1])
413-
ld.D = [
414-
ld.D D_lastcol
415-
D_lastrow ld.wtv
416-
]
443+
_grow_last_block!(ld.D, D_lastcol, D_lastrow, ld.wtv)
417444
end
418445
return ld
419446
end
@@ -437,12 +464,12 @@ function _update_Flastrow!(ld)
437464
# F_{n} = D_{n}L[1:n, 1:n] + l[n+1, n]D_{n}[1:n, n+1][0 ... 0 1]
438465
# TODO: block
439466
if !isone(ld.n) # We only need to do this if we are constructing a block
440-
Flastrow = reshape(ld.D[end:end, :] * ld.L, :)
441-
ld.F̃lastcol = Flastrow .* ld.γ[1:end-1] ./ ld.γ[end]
467+
Flastrow = ld.D[end:end, :] * ld.L
468+
ld.F̃lastcol = reshape(Flastrow, :) .* ld.γ[1:end-1] ./ ld.γ[end]
442469
# we are not able to fill in the last column yet, so we fill with zero
443470
ld.F = [
444471
ld.F fill(0.0, size(ld.F, 1))
445-
transpose(Flastrow) 0.0
472+
Flastrow 0.0
446473
]
447474
end
448475
end
@@ -462,10 +489,10 @@ function _update_U!(ld, innerp)
462489
for i = kstar:k-1
463490
block_start = mk[i]
464491
block_end = mk[i+1]-1
465-
ld.U[block_start:block_end, end] .= ld.E[block_start:block_end, block_start:block_end] \ ld.F̃lastcol[block_start:block_end]
492+
ld.U[block_start:block_end, end] .= blocks(ld.E)[i] \ ld.F̃lastcol[block_start:block_end]
466493
end
467494
if !innerp && !isone(n)
468-
ld.U[mk[k]:end-1, end] .= ld.E[mk[k]:end, mk[k]:end] \ ld.F̃lastcol[mk[k]:end]
495+
ld.U[mk[k]:end-1, end] .= blocks(ld.E)[end] \ ld.F̃lastcol[mk[k]:end]
469496
end
470497
return ld
471498
end
@@ -572,18 +599,13 @@ function _update_E!(ld)
572599
# F = ΓUtinv(Γ)E
573600
# 5.2.14
574601
n = ld.n
575-
block_start = ld.mk[ld.kstar]
576-
577-
if isone(ld.n)
578-
ld.E = fill(ld.qtAp, 1, 1)
602+
if isone(ld.n) || (ld.n == ld.mk[end])
603+
_start_new_block!(ld.E, ld.qtAp)
579604
else
580605
ΓUtinvΓ = ld.γ .* transpose(ld.U) ./ transpose(ld.γ)
581606
Elastrow = (ΓUtinvΓ[end, end] \ ld.F[n:n, 1:n-1] - ΓUtinvΓ[end:end, 1:end-1]*ld.E)
582607
Elastcol = (transpose(Elastrow) .* ld.γ[1:n-1] ./ ld.γ[n])
583-
ld.E = [
584-
ld.E Elastcol
585-
Elastrow ld.qtAp
586-
]
608+
_grow_last_block!(ld.E, Elastcol, Elastrow, ld.qtAp)
587609
end
588610
return ld
589611
end
@@ -616,10 +638,10 @@ function _update_L!(ld, innerv)
616638
for i = lstar:l-1
617639
block_start = nl[i]
618640
block_end = nl[i+1]-1
619-
Llastcol[block_start:block_end] .= ld.D[block_start:block_end, block_start:block_end] \ ld.F[block_start:block_end, end]
641+
Llastcol[block_start:block_end] .= blocks(ld.D)[i] \ ld.F[block_start:block_end, end]
620642
end
621643
if !innerv
622-
Llastcol[nl[l]:end] .= ld.D[nl[l]:end, nl[l]:end] \ ld.F[nl[l]:end, end]
644+
Llastcol[nl[l]:end] .= blocks(ld.D)[end] \ ld.F[nl[l]:end, end]
623645
end
624646
if isone(n)
625647
ld.L = UpperHessenberg(

0 commit comments

Comments
 (0)