|
| 1 | +module LinearSolveHYPRE |
| 2 | + |
| 3 | +using HYPRE.LibHYPRE: HYPRE_Complex |
| 4 | +using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector |
| 5 | +using IterativeSolvers: Identity |
| 6 | +using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve, |
| 7 | + OperatorAssumptions, default_tol, init_cacheval, issquare, set_cacheval |
| 8 | +using SciMLBase: LinearProblem, SciMLBase |
| 9 | +using UnPack: @unpack |
| 10 | +using Setfield: @set! |
| 11 | + |
| 12 | +mutable struct HYPRECache |
| 13 | + solver::Union{HYPRE.HYPRESolver, Nothing} |
| 14 | + A::Union{HYPREMatrix, Nothing} |
| 15 | + b::Union{HYPREVector, Nothing} |
| 16 | + u::Union{HYPREVector, Nothing} |
| 17 | + isfresh_A::Bool |
| 18 | + isfresh_b::Bool |
| 19 | + isfresh_u::Bool |
| 20 | +end |
| 21 | + |
| 22 | +function LinearSolve.init_cacheval(alg::HYPREAlgorithm, A, b, u, Pl, Pr, maxiters::Int, |
| 23 | + abstol, reltol, |
| 24 | + verbose::Bool, assumptions::OperatorAssumptions) |
| 25 | + return HYPRECache(nothing, nothing, nothing, nothing, true, true, true) |
| 26 | +end |
| 27 | + |
| 28 | +# Overload set_(A|b|u) in order to keep track of "isfresh" for all of them |
| 29 | +const LinearCacheHYPRE = LinearCache{<:Any, <:Any, <:Any, <:Any, <:Any, HYPRECache} |
| 30 | +function LinearSolve.set_A(cache::LinearCacheHYPRE, A) |
| 31 | + @set! cache.A = A |
| 32 | + cache.cacheval.isfresh_A = true |
| 33 | + @set! cache.isfresh = true |
| 34 | + return cache |
| 35 | +end |
| 36 | +function LinearSolve.set_b(cache::LinearCacheHYPRE, b) |
| 37 | + @set! cache.b = b |
| 38 | + cache.cacheval.isfresh_b = true |
| 39 | + return cache |
| 40 | +end |
| 41 | +function LinearSolve.set_u(cache::LinearCacheHYPRE, u) |
| 42 | + @set! cache.u = u |
| 43 | + cache.cacheval.isfresh_u = true |
| 44 | + return cache |
| 45 | +end |
| 46 | + |
| 47 | +# Note: |
| 48 | +# SciMLBase.init is overloaded here instead of just LinearSolve.init_cacheval for two |
| 49 | +# reasons: |
| 50 | +# - HYPREArrays can't really be `deepcopy`d, so that is turned off by default |
| 51 | +# - The solution vector/initial guess u0 can't be created with |
| 52 | +# fill!(similar(b, size(A, 2)), false) since HYPREArrays are not AbstractArrays. |
| 53 | + |
| 54 | +function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, |
| 55 | + args...; |
| 56 | + alias_A = false, alias_b = false, |
| 57 | + # TODO: Implement eltype for HYPREMatrix in HYPRE.jl? Looks useful |
| 58 | + # even if it is not AbstractArray. |
| 59 | + abstol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex : |
| 60 | + eltype(prob.A)), |
| 61 | + reltol = default_tol(prob.A isa HYPREMatrix ? HYPRE_Complex : |
| 62 | + eltype(prob.A)), |
| 63 | + # TODO: Implement length() for HYPREVector in HYPRE.jl? |
| 64 | + maxiters::Int = prob.b isa HYPREVector ? 1000 : length(prob.b), |
| 65 | + verbose::Bool = false, |
| 66 | + Pl = Identity(), |
| 67 | + Pr = Identity(), |
| 68 | + assumptions = OperatorAssumptions(), |
| 69 | + kwargs...) |
| 70 | + @unpack A, b, u0, p = prob |
| 71 | + |
| 72 | + # Create solution vector/initial guess |
| 73 | + if u0 === nothing |
| 74 | + u0 = zero(b) |
| 75 | + end |
| 76 | + |
| 77 | + # Initialize internal alg cache |
| 78 | + cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose, |
| 79 | + assumptions) |
| 80 | + Tc = typeof(cacheval) |
| 81 | + isfresh = true |
| 82 | + |
| 83 | + cache = LinearCache{ |
| 84 | + typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, |
| 85 | + typeof(Pl), typeof(Pr), typeof(reltol), issquare(assumptions) |
| 86 | + }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, |
| 87 | + maxiters, |
| 88 | + verbose, assumptions) |
| 89 | + return cache |
| 90 | +end |
| 91 | + |
| 92 | +# Solvers whose constructor requires passing the MPI communicator |
| 93 | +const COMM_SOLVERS = Union{HYPRE.BiCGSTAB, HYPRE.FlexGMRES, HYPRE.GMRES, HYPRE.ParaSails, |
| 94 | + HYPRE.PCG} |
| 95 | +create_solver(::Type{S}, comm) where {S <: COMM_SOLVERS} = S(comm) |
| 96 | + |
| 97 | +# Solvers whose constructor should not be passed the MPI communicator |
| 98 | +const NO_COMM_SOLVERS = Union{HYPRE.BoomerAMG, HYPRE.Hybrid, HYPRE.ILU} |
| 99 | +create_solver(::Type{S}, comm) where {S <: NO_COMM_SOLVERS} = S() |
| 100 | + |
| 101 | +function create_solver(alg::HYPREAlgorithm, cache::LinearCache) |
| 102 | + # If the solver is already instantiated, return it directly |
| 103 | + if alg.solver isa HYPRE.HYPRESolver |
| 104 | + return alg.solver |
| 105 | + end |
| 106 | + |
| 107 | + # Otherwise instantiate |
| 108 | + if !(alg.solver <: Union{COMM_SOLVERS, NO_COMM_SOLVERS}) |
| 109 | + throw(ArgumentError("unknown or unsupported HYPRE solver: $(alg.solver)")) |
| 110 | + end |
| 111 | + comm = cache.cacheval.A.comm # communicator from the matrix |
| 112 | + solver = create_solver(alg.solver, comm) |
| 113 | + |
| 114 | + # Construct solver options |
| 115 | + solver_options = (; |
| 116 | + AbsoluteTol = cache.abstol, |
| 117 | + MaxIter = cache.maxiters, |
| 118 | + PrintLevel = Int(cache.verbose), |
| 119 | + Tol = cache.reltol) |
| 120 | + |
| 121 | + # Preconditioner (uses Pl even though it might not be a *left* preconditioner just *a* |
| 122 | + # preconditioner) |
| 123 | + if !(cache.Pl isa Identity) |
| 124 | + precond = if cache.Pl isa HYPRESolver |
| 125 | + cache.Pl |
| 126 | + elseif cache.Pl <: HYPRESolver |
| 127 | + create_solver(cache.Pl, comm) |
| 128 | + else |
| 129 | + throw(ArgumentError("unknown HYPRE preconditioner $(cache.Pl)")) |
| 130 | + end |
| 131 | + solver_options = merge(solver_options, (; Precond = precond)) |
| 132 | + end |
| 133 | + |
| 134 | + # Filter out some options that are not supported for some solvers |
| 135 | + if solver isa HYPRE.Hybrid |
| 136 | + # Rename MaxIter to PCGMaxIter |
| 137 | + MaxIter = solver_options.MaxIter |
| 138 | + ks = filter(x -> x !== :MaxIter, keys(solver_options)) |
| 139 | + solver_options = NamedTuple{ks}(solver_options) |
| 140 | + solver_options = merge(solver_options, (; PCGMaxIter = MaxIter)) |
| 141 | + elseif solver isa HYPRE.BoomerAMG || solver isa HYPRE.ILU |
| 142 | + # Remove AbsoluteTol, Precond |
| 143 | + ks = filter(x -> !in(x, (:AbsoluteTol, :Precond)), keys(solver_options)) |
| 144 | + solver_options = NamedTuple{ks}(solver_options) |
| 145 | + end |
| 146 | + |
| 147 | + # Set the options |
| 148 | + HYPRE.Internals.set_options(solver, pairs(solver_options)) |
| 149 | + |
| 150 | + return solver |
| 151 | +end |
| 152 | + |
| 153 | +# TODO: How are args... and kwargs... supposed to be used here? |
| 154 | +function SciMLBase.solve(cache::LinearCache, alg::HYPREAlgorithm, args...; kwargs...) |
| 155 | + # It is possible to reach here without HYPRE.Init() being called if HYPRE structures are |
| 156 | + # only to be created here internally (i.e. when cache.A::SparseMatrixCSC and not a |
| 157 | + # ::HYPREMatrix created externally by the user). Be nice to the user and call it :) |
| 158 | + if !(cache.A isa HYPREMatrix || cache.b isa HYPREVector || cache.u isa HYPREVector || |
| 159 | + alg.solver isa HYPRESolver) |
| 160 | + HYPRE.Init() |
| 161 | + end |
| 162 | + |
| 163 | + # Move matrix and vectors to HYPRE, if not already provided as HYPREArrays |
| 164 | + hcache = cache.cacheval |
| 165 | + if hcache.isfresh_A || hcache.A === nothing |
| 166 | + hcache.A = cache.A isa HYPREMatrix ? cache.A : HYPREMatrix(cache.A) |
| 167 | + hcache.isfresh_A = false |
| 168 | + end |
| 169 | + if hcache.isfresh_b || hcache.b === nothing |
| 170 | + hcache.b = cache.b isa HYPREVector ? cache.b : HYPREVector(cache.b) |
| 171 | + hcache.isfresh_b = false |
| 172 | + end |
| 173 | + if hcache.isfresh_u || hcache.u === nothing |
| 174 | + hcache.u = cache.u isa HYPREVector ? cache.u : HYPREVector(cache.u) |
| 175 | + hcache.isfresh_u = false |
| 176 | + end |
| 177 | + |
| 178 | + # Create the solver. |
| 179 | + if hcache.solver === nothing |
| 180 | + hcache.solver = create_solver(alg, cache) |
| 181 | + end |
| 182 | + |
| 183 | + # Done with cache updates; set it |
| 184 | + cache = set_cacheval(cache, hcache) |
| 185 | + |
| 186 | + # Solve! |
| 187 | + HYPRE.solve!(hcache.solver, hcache.u, hcache.A, hcache.b) |
| 188 | + |
| 189 | + # Copy back if the output is not HYPREVector |
| 190 | + if cache.u !== hcache.u |
| 191 | + @assert !(cache.u isa HYPREVector) |
| 192 | + copy!(cache.u, hcache.u) |
| 193 | + end |
| 194 | + |
| 195 | + # Note: Inlining SciMLBase.build_linear_solution(alg, u, resid, cache; retcode, iters) |
| 196 | + # since some of the functions used in there does not play well with HYPREVector. |
| 197 | + |
| 198 | + T = cache.u isa HYPREVector ? HYPRE_Complex : eltype(cache.u) # eltype(u) |
| 199 | + N = 1 # length((size(u)...,)) |
| 200 | + resid = nothing # TODO: Fetch from solver |
| 201 | + iters = 0 # TODO: Fetch from solver |
| 202 | + retc = SciMLBase.ReturnCode.Default # TODO: Fetch from solver |
| 203 | + |
| 204 | + ret = SciMLBase.LinearSolution{T, N, typeof(cache.u), typeof(resid), typeof(alg), |
| 205 | + typeof(cache)}(cache.u, resid, alg, retc, iters, cache) |
| 206 | + |
| 207 | + return ret |
| 208 | +end |
| 209 | + |
| 210 | +# HYPREArrays are not AbstractArrays so perform some type-piracy |
| 211 | +function SciMLBase.LinearProblem(A::HYPREMatrix, b::HYPREVector, |
| 212 | + p = SciMLBase.NullParameters(); |
| 213 | + u0::Union{HYPREVector, Nothing} = nothing, kwargs...) |
| 214 | + return LinearProblem{true}(A, b, p; u0 = u0, kwargs) |
| 215 | +end |
| 216 | + |
| 217 | +end # module LinearSolveHYPRE |
0 commit comments