|
1 | 1 | using Reactant
|
| 2 | +using Reactant: Ops, TracedRNumber |
2 | 3 | using LinearAlgebra
|
3 | 4 | using Random
|
| 5 | +using Statistics |
| 6 | +using BenchmarkTools |
4 | 7 |
|
# setup
Random.seed!(0)

# Real symmetric test matrix, built as the Gram matrix of a random 512×512 matrix.
A = let M = rand(Float64, 512, 512)
    M' * M # make it hermitian
end
@assert ishermitian(A)

# Random starting vector, scaled in place to unit norm.
b = normalize!(rand(Float64, 512))

# The same inputs converted with `Reactant.to_rarray`.
A_re = Reactant.to_rarray(A)
b_re = Reactant.to_rarray(b)
| 19 | + |
| 20 | +# fixes |
| 21 | +# TODO move to Reactant |
"""
    zeros(::Type{TracedRNumber{T}}, dims::NTuple{N,<:Integer})

Build a zero-filled array of extents `dims` for a traced element type by creating the
scalar constant `zero(T)` with `Ops.constant` and expanding it to `dims` with
`Ops.broadcast_in_dim`.
"""
function Base.zeros(::Type{TracedRNumber{T}}, dims::NTuple{N,<:Integer}) where {T,N}
    # The scalar has no dimensions of its own, hence the empty dimension mapping.
    return Ops.broadcast_in_dim(Ops.constant(zero(T)), Int[], collect(dims))
end
| 26 | + |
# algorithm
"""
    lanczos(A, v0, m)

Run `m` steps of the Lanczos iteration on `A` starting from `v0`, applying a full
reorthogonalization against all previously generated vectors at every step.

# Arguments
- `A`: matrix to (partially) decompose. Lanczos requires it to be symmetric/hermitian;
  this is not checked here.
- `v0`: initial vector. It is divided by its norm before use, so it must be nonzero.
- `m`: decomposition rank (number of Lanczos steps).

# Returns
- `V`: `size(A, 1) × (m + 1)` matrix whose columns are the generated Lanczos vectors.
- `T`: `m × m` matrix holding the `alpha` coefficients on the diagonal and the `beta`
  coefficients on the first sub- and super-diagonal.

# Notes
There is no early termination when the Krylov subspace is exhausted: if `beta` becomes
zero, the next vector is computed as `w / beta` regardless.
"""
function lanczos(A, v0, m)
    n = size(A, 1)
    V = zeros(eltype(A), n, m + 1)
    T = zeros(eltype(A), m, m)

    v = v0 / norm(v0)
    V[:, 1] = v
    # Placeholder only: `beta` is read for `j > 1`, after the first step has set it.
    beta = 0.0
    w = similar(v)

    @allowscalar for j in 1:m
        # three-term recurrence
        w .= A * v
        if j > 1
            w .-= beta * V[:, j - 1]
        end
        # `dot` conjugates its first argument, so this is v' * w.
        alpha = dot(v, w)
        w .-= alpha * v

        # full reorthogonalization via modified Gram-Schmidt to avoid spurious duplicate eigenvalues
        # TODO implicitly restarted Lanczos? available at KrylovKit
        for k in 1:j
            vk = V[:, k]
            w .-= dot(vk, w) * vk
        end

        # Take the norm only after reorthogonalization, so that `w / beta` below is
        # unit-norm and `T` matches the vector actually stored in `V`.
        beta = norm(w)

        T[j, j] = alpha
        if j < m
            T[j, j + 1] = beta
            T[j + 1, j] = beta
        end

        # early termination if Krylov subspace is reached
        # TODO Reactant.@trace doesn't support return statements yet
        # @trace if beta < eps(eltype(A))
        #     return V[:, 1:j], T[1:j, 1:j]
        # end

        v = w / beta
        V[:, j + 1] = v
    end

    return V, T
end
| 74 | + |
# Full-rank run: with m == size(A, 1) the spectrum of T should match that of A.
V, T = lanczos(A, b, 512)

# Compute each spectrum once. The `Hermitian` wrapper guarantees real eigenvalues in
# ascending order, which the element-wise comparison below relies on.
eigvals_T = eigvals(Hermitian(T))
eigvals_A = eigvals(Hermitian(A))
eigvals_diff = eigvals_A - eigvals_T

l1_error = sum(abs, eigvals_diff)
l2_error = sqrt(sum(abs2, eigvals_diff))
linf_error = maximum(abs, eigvals_diff)
@info "Error" l1 = l1_error l2 = l2_error linf = linf_error

# benchmarking

# @benchmark lanczos($A, $b, 16) setup = (GC.gc())
@benchmark lanczos($A, $b, 16)

# compile with Reactant
f = @compile sync = true lanczos(A_re, b_re, 16)
@benchmark $f($A_re, $b_re, 16)
0 commit comments