Skip to content

Commit 901d096

Browse files
author
Pawel Latawiec
committed
Clean D and E computations
1 parent 34792c3 commit 901d096

File tree

1 file changed

+27
-55
lines changed

1 file changed

+27
-55
lines changed

src/lal.jl

Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
using Printf
2-
import BlockDiagonals: BlockDiagonal
3-
import BlockDiagonals
42
import Base: iterate
53
import LinearAlgebra: UpperTriangular, UpperHessenberg
64

@@ -73,11 +71,11 @@ mutable struct LookAheadLanczosDecomp{OpT, OptT, VecT, MatT, ElT, ElRT}
7371
γ::Vector{ElRT}
7472

7573
# Eq. 2.13
76-
D::BlockDiagonal{ElT, Matrix{ElT}}
74+
D::Matrix{ElT}
7775
# Eq. 3.14
78-
E::BlockDiagonal{ElT, Matrix{ElT}}
76+
E::Matrix{ElT}
7977
# Defined after Eq. 5.1
80-
F::BlockDiagonal{ElT, Matrix{ElT}}
78+
F::Matrix{ElT}
8179
F̃lastcol::Vector{ElT}
8280
# Eq. 5.1
8381
G::Vector{ElT}
@@ -177,12 +175,12 @@ function LookAheadLanczosDecomp(
177175
γ = Vector{real(elT)}(undef, 1)
178176
γ[1] = 1.0
179177

180-
D = BlockDiagonal{elT, Matrix{elT}}(Vector{Matrix{elT}}())
181-
E = BlockDiagonal{elT, Matrix{elT}}(Vector{Matrix{elT}}())
178+
D = Matrix{elT}(undef, 0, 0)
179+
E = Matrix{elT}(undef, 0, 0)
182180
G = Vector{elT}()
183181
H = Vector{elT}()
184182

185-
F = BlockDiagonal{elT, Matrix{elT}}(Vector{Matrix{elT}}())
183+
F = Matrix{elT}(undef, 0, 0)
186184
F̃lastcol = Vector{elT}()
187185

188186
U = UpperTriangular(Matrix{elT}(undef, 0, 0))
@@ -245,34 +243,6 @@ _VW_block_size(ld) = ld.n+1 - ld.nl[ld.l]
245243
_VW_prev_block_size(ld) = ld.nl[ld.l] - ld.nl[max(1, ld.l-1)]
246244
_is_block_small(ld, n) = n < ld.opts.max_block_size
247245

248-
"""
249-
_grow_last_block!(A, Bcol, Brow, Bcorner)
250-
251-
Grows the last block in-place in `A` by appending the column `Bcol`, the row `Brow`, and the corner element `Bcorner`. `Bcol` and `Brow` are automatically truncated to match the size of the grown block
252-
"""
253-
function _grow_last_block!(A::BlockDiagonal{T, TM}, Bcol, Brow, Bcorner) where {T, TM}
254-
n = BlockDiagonals.nblocks(A)
255-
b = BlockDiagonals.blocks(A)
256-
s = size(last(b), 1)
257-
b[n] = TM([
258-
b[n] Bcol[end-s+1:end]
259-
Brow[:, end-s+1:end] Bcorner
260-
])
261-
return A
262-
end
263-
264-
"""
265-
_start_new_block!(A, B)
266-
267-
Appends a new block to the end of `A` with `B`
268-
"""
269-
function _start_new_block!(A::BlockDiagonal{T, TM}, B) where {T, TM}
270-
push!(BlockDiagonals.blocks(A), TM(fill(only(B), 1, 1)))
271-
return A
272-
end
273-
274-
Base.size(B::BlockDiagonals.BlockDiagonal) = sum(firstsize, BlockDiagonals.blocks(B), init=0), sum(lastsize, BlockDiagonals.blocks(B), init=0)
275-
276246
start(::LookAheadLanczosDecomp) = 1
277247
done(ld::LookAheadLanczosDecomp, iteration::Int) = iteration ld.opts.max_iter
278248
function iterate(ld::LookAheadLanczosDecomp, n::Int=start(ld))
@@ -429,16 +399,20 @@ function _update_D!(ld)
429399
# Alg. 5.2.1
430400
# Eq. 5.2:
431401
# F[n-1] = Wt[n-1]V[n]L[n-1] = D[n-1]L[1:n-1, 1:n-1] + l[n, n-1]D[1:n-1, n][0 ... 0 1]
432-
# => D[1:end-1, end] = (F[:, end] - (D_prev L[1:end-1, end])) / ρ
402+
# => D[1:end-1, end] = (F[:, end] - (D_prev L[1:end-1, 1:end]))[:, end] / ρ
433403
# Eq. 3.15, (D Γ)ᵀ = (D Γ)
434404
# D[n, n] = wtv
435405

436-
if isone(ld.n) || _VW_block_size(ld) == 1
437-
_start_new_block!(ld.D, ld.wtv)
406+
# TODO: closed block
407+
if isone(ld.n)
408+
ld.D = fill(ld.wtv, 1, 1)
438409
else
439410
D_lastcol = (ld.F[:, end] - (ld.D * ld.L[1:end-1, end])) / ld.ρ
440411
D_lastrow = transpose(D_lastcol * ld.γ[end] ./ ld.γ[1:end-1])
441-
_grow_last_block!(ld.D, D_lastcol, D_lastrow, ld.wtv)
412+
ld.D = [
413+
ld.D D_lastcol
414+
D_lastrow ld.wtv
415+
]
442416
end
443417
return ld
444418
end
@@ -461,17 +435,14 @@ function _update_Flastrow!(ld)
461435
# Eq. 5.2 (w/ indices advanced):
462436
# F_{n} = D_{n}L[1:n, 1:n] + l[n+1, n]D_{n}[1:n, n+1][0 ... 0 1]
463437
# TODO: block
464-
if isone(ld.n)
465-
_start_new_block!(ld.F, 0.0)
466-
else
438+
if !isone(ld.n) # We only need to do this if we are constructing a block
467439
Flastrow = reshape(ld.D[end:end, :] * ld.L, :)
468440
ld.F̃lastcol = Flastrow .* ld.γ[1:end-1] ./ ld.γ[end]
469441
# we are not able to fill in the last column yet, so we fill with zero
470-
if _VW_block_size(ld) == 1
471-
_grow_last_block!(ld.F, fill(0.0, size(ld.F, 1)), transpose(Flastrow), 0.0)
472-
else
473-
_grow_last_block!(ld.F, fill(0.0, size(ld.F, 1)), transpose(Flastrow), 0.0)
474-
end
442+
ld.F = [
443+
ld.F fill(0.0, size(ld.F, 1))
444+
transpose(Flastrow) 0.0
445+
]
475446
end
476447
end
477448

@@ -484,7 +455,6 @@ function _update_U!(ld, innerp)
484455
idx_offset = 0
485456
# TODO
486457
# we only store the entries from mk[kstar] to n-1
487-
488458
ld.U = UpperTriangular(
489459
[
490460
ld.U fill(0.0, n-1, 1)
@@ -612,13 +582,16 @@ function _update_E!(ld)
612582
# 5.2.14
613583
n = ld.n
614584

615-
if isone(ld.n) || _PQ_block_size(ld) == 1
616-
_start_new_block!(ld.E, ld.qtAp)
585+
if isone(ld.n)
586+
ld.E = fill(ld.qtAp, 1, 1)
617587
else
618588
ΓUtinvΓ = ld.γ .* transpose(ld.U) ./ transpose(ld.γ)
619589
Elastrow = (ΓUtinvΓ[end, end] \ ld.F[n:n, 1:n-1] - ΓUtinvΓ[end:end, 1:end-1]*ld.E)
620-
Elastcol = (Elastrow .* ld.γ[1:n-1] ./ ld.γ[n])
621-
_grow_last_block!(ld.E, Elastcol, Elastrow, ld.qtAp)
590+
Elastcol = (transpose(Elastrow) .* ld.γ[1:n-1] ./ ld.γ[n])
591+
ld.E = [
592+
ld.E Elastcol
593+
Elastrow ld.qtAp
594+
]
622595
end
623596
return ld
624597
end
@@ -637,7 +610,7 @@ function _update_Flastcol!(ld)
637610
ΓUtinvΓ = ld.γ .* transpose(ld.U) ./ transpose(ld.γ)
638611
# length n, ld.F_lastrow of length n-1
639612
if isone(n)
640-
ld.F[1, 1] = ΓUtinvΓ[end, end] * ld.E[end, end]
613+
ld.F = fill(ΓUtinvΓ[end, end] * ld.E[end, end], 1, 1)
641614
else
642615
ld.F[:, end] .= ΓUtinvΓ * ld.E[:, end]
643616
end
@@ -654,7 +627,6 @@ function _update_L!(ld, innerv)
654627
Llastcol[block_start:block_end] .= ld.D[block_start:block_end, block_start:block_end] \ ld.F[block_start:block_end, end]
655628
end
656629
if !innerv
657-
@show ld.D
658630
Llastcol[nl[l]:end] .= ld.D[nl[l]:end, nl[l]:end] \ ld.F[nl[l]:end, end]
659631
end
660632
if isone(n)

0 commit comments

Comments
 (0)