Skip to content

Commit 0a35cc7

Browse files
committed
upload lanczos with full reorthogonalization
1 parent a34fd9f commit 0a35cc7

File tree

1 file changed

+80
-4
lines changed

1 file changed

+80
-4
lines changed

perf/lanczos/main.jl

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
11
using Reactant
2+
using Reactant: Ops, TracedRNumber
23
using LinearAlgebra
34
using Random
5+
using Statistics
6+
using BenchmarkTools
47

8+
# setup
59
Random.seed!(0)
610

7-
A = rand(ComplexF64, 512, 512)
11+
A = rand(Float64, 512, 512)
812
A = A' * A # make it hermitian
913
@assert ishermitian(A)
1014

11-
b = normalize!(rand(ComplexF64, 512))
15+
b = normalize!(rand(Float64, 512))
1216

13-
Are = Reactant.to_rarray(A)
14-
bre = Reactant.to_rarray(b)
17+
A_re = Reactant.to_rarray(A)
18+
b_re = Reactant.to_rarray(b)
19+
20+
# fixes
21+
# TODO move to Reactant
22+
function Base.zeros(::Type{TracedRNumber{T}}, dims::NTuple{N,<:Integer}) where {T,N}
23+
_zero = Ops.constant(zero(T))
24+
return Ops.broadcast_in_dim(_zero, Int[], collect(dims))
25+
end
26+
27+
# algorithm
28+
# - A: matrix to (partially) decompose. lanczos requires it to be symmetric/hermitian.
29+
# - v0: initial vector, should be normalized.
30+
# - m: decomposition rank
31+
function lanczos(A, v0, m)
32+
n = size(A, 1)
33+
V = zeros(eltype(A), n, m + 1)
34+
T = zeros(eltype(A), m, m)
35+
36+
v = v0 / norm(v0)
37+
V[:, 1] = v
38+
beta = 0.0
39+
w = similar(v)
40+
41+
@allowscalar for j in 1:m
42+
w .= A * v
43+
if j > 1
44+
w .-= beta * V[:, j - 1]
45+
end
46+
alpha = dot(w, v)
47+
w .-= alpha * v
48+
beta = norm(w)
49+
50+
T[j, j] = alpha
51+
if j < m
52+
T[j, j + 1] = beta
53+
T[j + 1, j] = beta
54+
end
55+
56+
# early termination if Krylov subspace is reached
57+
# TODO Reactant.@trace doesn't support return statements yet
58+
# @trace if beta < eps(eltype(A))
59+
# return V[:, 1:j], T[1:j, 1:j]
60+
# end
61+
62+
# full reorthogonalization via modified Gram-Schmidt to avoid spurious duplicate eigenvalues
63+
# TODO implicitly restarted Lanczos? available at KrylovKit
64+
for k in 1:j
65+
w .-= dot(V[:, k], w) * V[:, k]
66+
end
67+
68+
v = w / beta
69+
V[:, j + 1] = v
70+
end
71+
72+
return V, T
73+
end
74+
75+
V, T = lanczos(A, b, 512)
76+
eigvals(T)
77+
78+
l1_error = sum(abs.(eigvals(A) - eigvals(T)))
79+
l2_error = sqrt(sum(abs2.(eigvals(A) - eigvals(T))))
80+
linf_error = maximum(abs.(eigvals(A) - eigvals(T)))
81+
@info "Error" l1 = l1_error l2 = l2_error linf = linf_error
82+
83+
# benchmarking
84+
85+
# @benchmark lanczos($A, $b, 16) setup = (GC.gc())
86+
@benchmark lanczos($A, $b, 16)
87+
88+
# compile with Reactant
89+
f = @compile sync = true lanczos(A_re, b_re, 16)
90+
@benchmark $f($A_re, $b_re, 16)

0 commit comments

Comments
 (0)