diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl index 80d559cb3..77796409b 100644 --- a/ext/LinearSolveCUDAExt.jl +++ b/ext/LinearSolveCUDAExt.jl @@ -120,36 +120,41 @@ end function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization; kwargs...) if cache.isfresh - cacheval = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) - # Convert to Float32 for factorization - A_f32 = Float32.(cache.A) - fact = lu(CUDA.CuArray(A_f32)) - cache.cacheval = fact + fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) + # Convert to Float32 for factorization using cached type + A_f32 = T32.(cache.A) + copyto!(A_gpu_f32, A_f32) + fact = lu(A_gpu_f32) + cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig) cache.isfresh = false end - fact = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) + fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) # Convert b to Float32, solve, then convert back to original precision - b_f32 = Float32.(cache.b) - u_f32 = CUDA.CuArray(b_f32) - y_f32 = ldiv!(u_f32, fact, CUDA.CuArray(b_f32)) + b_f32 = T32.(cache.b) + copyto!(b_gpu_f32, b_f32) + ldiv!(u_gpu_f32, fact, b_gpu_f32) # Convert back to original precision - y = Array(y_f32) - T = eltype(cache.u) - cache.u .= T.(y) + y = Array(u_gpu_f32) + cache.u .= Torig.(y) SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - # Pre-allocate with Float32 arrays - A_f32 = Float32.(A) - T = eltype(A_f32) - noUnitT = typeof(zero(T)) + # Pre-allocate with Float32 arrays and cache types + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + noUnitT = typeof(zero(T32)) luT = LinearAlgebra.lutype(noUnitT) - ipiv = CuVector{Int32}(undef, 0) + ipiv = CuVector{Int32}(undef, min(m, n)) info = zero(LinearAlgebra.BlasInt) - return LU{luT}(CuMatrix{Float32}(undef, 0, 0), ipiv, info) + fact = LU{luT}(CuMatrix{T32}(undef, m, n), ipiv, info) + A_gpu_f32 = CuMatrix{T32}(undef, m, n) + b_gpu_f32 = CuVector{T32}(undef, size(b, 1)) + u_gpu_f32 = CuVector{T32}(undef, size(u, 1)) + return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig) end end diff --git a/ext/LinearSolveMetalExt.jl b/ext/LinearSolveMetalExt.jl index de0175b86..81f497725 100644 --- a/ext/LinearSolveMetalExt.jl +++ b/ext/LinearSolveMetalExt.jl @@ -36,9 +36,19 @@ default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - # Pre-allocate with Float32 arrays - A_f32 = Float32.(convert(AbstractMatrix, A)) - ArrayInterface.lu_instance(A_f32) + # Pre-allocate with Float32 arrays and cache types + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + A_f32 = similar(A, T32) + b_f32 = similar(b, T32) + u_f32 = similar(u, T32) + luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) + # Pre-allocate Metal arrays + A_mtl = MtlArray{T32}(undef, m, n) + b_mtl = MtlVector{T32}(undef, size(b, 1)) + u_mtl = MtlVector{T32}(undef, size(u, 1)) + return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig) end function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization; @@ -46,27 +56,27 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactoriz A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh - cacheval = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) - # Convert to Float32 for factorization - A_f32 = Float32.(A) - res = lu(MtlArray(A_f32)) - # Store factorization on CPU with converted types - cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info) + luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) + # Convert to appropriate 32-bit type for factorization using cached type + A_f32 .= T32.(A) + copyto!(A_mtl, A_f32) + res = lu(A_mtl) + # Store factorization and pre-allocated arrays + fact = LU(Array(res.factors), Array{Int}(res.ipiv), res.info) + cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig) cache.isfresh = false end - fact = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) - # Convert b to Float32 for solving - b_f32 = Float32.(cache.b) - u_f32 = similar(b_f32) + fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) + # Convert b to 32-bit for solving using cached type + b_f32 .= T32.(cache.b) # Create a temporary Float32 LU factorization for solving - fact_f32 = LU(Float32.(fact.factors), fact.ipiv, fact.info) + fact_f32 = LU(T32.(fact.factors), fact.ipiv, fact.info) ldiv!(u_f32, fact_f32, b_f32) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(u_f32) + # Convert back to original precision using cached type + cache.u .= Torig.(u_f32) SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end diff --git a/ext/LinearSolveRecursiveFactorizationExt.jl b/ext/LinearSolveRecursiveFactorizationExt.jl index 3e70807f7..340b53838 100644 --- a/ext/LinearSolveRecursiveFactorizationExt.jl +++ b/ext/LinearSolveRecursiveFactorizationExt.jl @@ -45,14 +45,16 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::LinearSolve.OperatorAssumptions) where {P, T} # Pre-allocate appropriate 32-bit arrays based on input type - if eltype(A) <: Complex - A_32 = rand(ComplexF32, 0, 0) - else - A_32 = rand(Float32, 0, 0) - end - luinst = ArrayInterface.lu_instance(A_32) - ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...)) - (luinst, ipiv) + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + A_32 = similar(A, T32) + b_32 = similar(b, T32) + u_32 = similar(u, T32) + luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) + ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n)) + # Return tuple with pre-allocated arrays and cached types + (luinst, ipiv, A_32, b_32, u_32, T32, Torig) end function SciMLBase.solve!( @@ -61,25 +63,19 @@ function SciMLBase.solve!( A = cache.A A = convert(AbstractMatrix, A) - # Check if we have complex numbers - iscomplex = eltype(A) <: Complex - - fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) if cache.isfresh - # Convert to appropriate 32-bit type for factorization - if iscomplex - A_f32 = ComplexF32.(A) - else - A_f32 = Float32.(A) - end + # Get pre-allocated arrays from cacheval + luinst, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) + # Copy A to pre-allocated 32-bit array using cached type + A_32 .= T32.(A) # Ensure ipiv is the right size - if length(ipiv) != min(size(A_f32)...) - ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...)) + if length(ipiv) != min(size(A_32)...) + resize!(ipiv, min(size(A_32)...)) end - fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false) - cache.cacheval = (fact, ipiv) + fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false) + cache.cacheval = (fact, ipiv, A_32, b_32, u_32, T32, Torig) if !LinearAlgebra.issuccess(fact) return SciMLBase.build_linear_solution( @@ -89,24 +85,17 @@ function SciMLBase.solve!( cache.isfresh = false end - # Get the factorization from the cache - fact_cached = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)[1] + # Get the factorization and pre-allocated arrays from the cache + fact_cached, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) - # Convert b to appropriate 32-bit type for solving - if iscomplex - b_f32 = ComplexF32.(cache.b) - u_f32 = similar(b_f32) - else - b_f32 = Float32.(cache.b) - u_f32 = similar(b_f32) - end + # Copy b to pre-allocated 32-bit array using cached type + b_32 .= T32.(cache.b) # Solve in 32-bit precision - ldiv!(u_f32, fact_cached, b_f32) + ldiv!(u_32, fact_cached, b_32) - # Convert back to original precision - T_orig = eltype(cache.u) - cache.u .= T_orig.(u_f32) + # Convert back to original precision using cached type + cache.u .= Torig.(u_32) SciMLBase.build_linear_solution( alg, cache.u, nothing, cache; retcode = ReturnCode.Success) diff --git a/src/appleaccelerate.jl b/src/appleaccelerate.jl index 2af9e63a6..6de1567ee 100644 --- a/src/appleaccelerate.jl +++ b/src/appleaccelerate.jl @@ -298,13 +298,15 @@ function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) # Pre-allocate appropriate 32-bit arrays based on input type - if eltype(A) <: Complex - A_32 = rand(ComplexF32, 0, 0) - else - A_32 = rand(Float32, 0, 0) - end - luinst = ArrayInterface.lu_instance(A_32) - LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}() + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + A_32 = similar(A, T32) + b_32 = similar(b, T32) + u_32 = similar(u, T32) + luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) + # Return tuple with pre-allocated arrays and cached types + (LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32, T32, Torig) end function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization; @@ -314,19 +316,13 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto A = cache.A A = convert(AbstractMatrix, A) - # Check if we have complex numbers - iscomplex = eltype(A) <: Complex - if cache.isfresh - cacheval = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) - # Convert to appropriate 32-bit type for factorization - if iscomplex - A_f32 = ComplexF32.(A) - else - A_f32 = Float32.(A) - end - res = aa_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2]) - fact = LU(res[1:3]...), res[4] + # Get pre-allocated arrays from cacheval + luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) + # Copy A to pre-allocated 32-bit array using cached type + A_32 .= T32.(A) + res = aa_getrf!(A_32; ipiv = luinst.ipiv, info = info) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -336,29 +332,22 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto cache.isfresh = false end - A_lu, info = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Convert b to appropriate 32-bit type for solving - if iscomplex - b_f32 = ComplexF32.(cache.b) - else - b_f32 = Float32.(cache.b) - end + # Copy b to pre-allocated 32-bit array using cached type + b_32 .= T32.(cache.b) if m > n - Bc = copy(b_f32) - aa_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(Bc[1:n]) + aa_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) + # Convert back to original precision using cached type + cache.u[1:n] .= Torig.(b_32[1:n]) else - u_f32 = copy(b_f32) - aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(u_f32) + copyto!(u_32, b_32) + aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) + # Convert back to original precision using cached type + cache.u .= Torig.(u_32) end SciMLBase.build_linear_solution( diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 5cc9ec28d..51cdb901f 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -809,7 +809,7 @@ alg = MKL32MixedLUFactorization() sol = solve(prob, alg) ``` """ -struct MKL32MixedLUFactorization <: AbstractFactorization end +struct MKL32MixedLUFactorization <: AbstractDenseFactorization end """ AppleAccelerate32MixedLUFactorization() @@ -833,7 +833,7 @@ alg = AppleAccelerate32MixedLUFactorization() sol = solve(prob, alg) ``` """ -struct AppleAccelerate32MixedLUFactorization <: AbstractFactorization end +struct AppleAccelerate32MixedLUFactorization <: AbstractDenseFactorization end """ OpenBLAS32MixedLUFactorization() @@ -857,7 +857,7 @@ alg = OpenBLAS32MixedLUFactorization() sol = solve(prob, alg) ``` """ -struct OpenBLAS32MixedLUFactorization <: AbstractFactorization end +struct OpenBLAS32MixedLUFactorization <: AbstractDenseFactorization end """ RF32MixedLUFactorization{P, T}(; pivot = Val(true), thread = Val(true)) diff --git a/src/mkl.jl b/src/mkl.jl index 11fa20f09..0eab16140 100644 --- a/src/mkl.jl +++ b/src/mkl.jl @@ -281,12 +281,15 @@ function LinearSolve.init_cacheval(alg::MKL32MixedLUFactorization, A, b, u, Pl, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) # Pre-allocate appropriate 32-bit arrays based on input type - if eltype(A) <: Complex - A_32 = rand(ComplexF32, 0, 0) - else - A_32 = rand(Float32, 0, 0) - end - ArrayInterface.lu_instance(A_32), Ref{BlasInt}() + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + A_32 = similar(A, T32) + b_32 = similar(b, T32) + u_32 = similar(u, T32) + luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) + # Return tuple with pre-allocated arrays and cached types + (luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig) end function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; @@ -296,19 +299,13 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; A = cache.A A = convert(AbstractMatrix, A) - # Check if we have complex numbers - iscomplex = eltype(A) <: Complex - if cache.isfresh - cacheval = @get_cacheval(cache, :MKL32MixedLUFactorization) - # Convert to appropriate 32-bit type for factorization - if iscomplex - A_f32 = ComplexF32.(A) - else - A_f32 = Float32.(A) - end - res = getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2]) - fact = LU(res[1:3]...), res[4] + # Get pre-allocated arrays from cacheval + luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization) + # Copy A to pre-allocated 32-bit array using cached type + A_32 .= T32.(A) + res = getrf!(A_32; ipiv = luinst.ipiv, info = info) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -318,29 +315,22 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; cache.isfresh = false end - A_lu, info = @get_cacheval(cache, :MKL32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Convert b to appropriate 32-bit type for solving - if iscomplex - b_f32 = ComplexF32.(cache.b) - else - b_f32 = Float32.(cache.b) - end + # Copy b to pre-allocated 32-bit array using cached type + b_32 .= T32.(cache.b) if m > n - Bc = copy(b_f32) - getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(Bc[1:n]) + getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) + # Convert back to original precision using cached type + cache.u[1:n] .= Torig.(b_32[1:n]) else - u_f32 = copy(b_f32) - getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(u_f32) + copyto!(u_32, b_32) + getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) + # Convert back to original precision using cached type + cache.u .= Torig.(u_32) end SciMLBase.build_linear_solution( diff --git a/src/openblas.jl b/src/openblas.jl index d5b6d353c..3830c9a39 100644 --- a/src/openblas.jl +++ b/src/openblas.jl @@ -306,12 +306,15 @@ function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) # Pre-allocate appropriate 32-bit arrays based on input type - if eltype(A) <: Complex - A_32 = rand(ComplexF32, 0, 0) - else - A_32 = rand(Float32, 0, 0) - end - ArrayInterface.lu_instance(A_32), Ref{BlasInt}() + m, n = size(A) + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(u) + A_32 = similar(A, T32) + b_32 = similar(b, T32) + u_32 = similar(u, T32) + luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) + # Return tuple with pre-allocated arrays and cached types + (luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig) end function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization; @@ -321,19 +324,13 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio A = cache.A A = convert(AbstractMatrix, A) - # Check if we have complex numbers - iscomplex = eltype(A) <: Complex - if cache.isfresh - cacheval = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) - # Convert to appropriate 32-bit type for factorization - if iscomplex - A_f32 = ComplexF32.(A) - else - A_f32 = Float32.(A) - end - res = openblas_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2]) - fact = LU(res[1:3]...), res[4] + # Get pre-allocated arrays from cacheval + luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) + # Copy A to pre-allocated 32-bit array using cached type + A_32 .= T32.(A) + res = openblas_getrf!(A_32; ipiv = luinst.ipiv, info = info) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -343,29 +340,22 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio cache.isfresh = false end - A_lu, info = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Convert b to appropriate 32-bit type for solving - if iscomplex - b_f32 = ComplexF32.(cache.b) - else - b_f32 = Float32.(cache.b) - end + # Copy b to pre-allocated 32-bit array using cached type + b_32 .= T32.(cache.b) if m > n - Bc = copy(b_f32) - openblas_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(Bc[1:n]) + openblas_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) + # Convert back to original precision using cached type + cache.u[1:n] .= Torig.(b_32[1:n]) else - u_f32 = copy(b_f32) - openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info) - # Convert back to original precision - T = eltype(cache.u) - cache.u .= T.(u_f32) + copyto!(u_32, b_32) + openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) + # Convert back to original precision using cached type + cache.u .= Torig.(u_32) end SciMLBase.build_linear_solution( diff --git a/test/nopre/caching_allocation_tests.jl b/test/nopre/caching_allocation_tests.jl index faede6dd4..fe529110b 100644 --- a/test/nopre/caching_allocation_tests.jl +++ b/test/nopre/caching_allocation_tests.jl @@ -1,6 +1,8 @@ using LinearSolve, LinearAlgebra, SparseArrays, Test, StableRNGs using AllocCheck -using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization +using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization, + MKL32MixedLUFactorization, OpenBLAS32MixedLUFactorization, + AppleAccelerate32MixedLUFactorization, RF32MixedLUFactorization using InteractiveUtils rng = StableRNG(123) @@ -15,7 +17,7 @@ rng = StableRNG(123) b3 = rand(rng, n) # Test major dense factorization algorithms - dense_algs = [ + dense_algs = Any[ LUFactorization(), QRFactorization(), CholeskyFactorization(), @@ -25,6 +27,23 @@ rng = StableRNG(123) DiagonalFactorization() ] + # Add mixed precision methods if available + if LinearSolve.usemkl + push!(dense_algs, MKL32MixedLUFactorization()) + end + if LinearSolve.useopenblas + push!(dense_algs, OpenBLAS32MixedLUFactorization()) + end + if Sys.isapple() && LinearSolve.appleaccelerate_isavailable() + push!(dense_algs, AppleAccelerate32MixedLUFactorization()) + end + # Test RF32Mixed only if RecursiveFactorization is available + try + using RecursiveFactorization + push!(dense_algs, RF32MixedLUFactorization()) + catch + end + for alg in dense_algs @testset "$(typeof(alg))" begin # Special matrix preparation for specific algorithms @@ -38,13 +57,20 @@ rng = StableRNG(123) A end + # Mixed precision methods need looser tolerance + is_mixed_precision = alg isa Union{MKL32MixedLUFactorization, + OpenBLAS32MixedLUFactorization, + AppleAccelerate32MixedLUFactorization, + RF32MixedLUFactorization} + tol = is_mixed_precision ? 1e-4 : 1e-10 + # Initialize the cache prob = LinearProblem(test_A, b1) cache = init(prob, alg) # First solve - this will create the factorization sol1 = solve!(cache) - @test norm(test_A * sol1.u - b1) < 1e-10 + @test norm(test_A * sol1.u - b1) < tol # Define the allocation-free solve function function solve_with_new_b!(cache, new_b) @@ -62,11 +88,11 @@ rng = StableRNG(123) # Run the allocation test try @test_nowarn solve_no_alloc!(cache, b2) - @test norm(test_A * cache.u - b2) < 1e-10 + @test norm(test_A * cache.u - b2) < tol # Test one more time with different b @test_nowarn solve_no_alloc!(cache, b3) - @test norm(test_A * cache.u - b3) < 1e-10 + @test norm(test_A * cache.u - b3) < tol catch e # Some algorithms might still allocate in certain Julia versions @test_broken false diff --git a/test_allocation_fix.jl b/test_allocation_fix.jl new file mode 100644 index 000000000..e23b73a8a --- /dev/null +++ b/test_allocation_fix.jl @@ -0,0 +1,105 @@ +using LinearSolve +using LinearAlgebra +using BenchmarkTools + +println("Testing allocation improvements for 32Mixed precision methods...") + +# Test size +n = 100 +A = rand(Float64, n, n) + 5.0I # Well-conditioned matrix +b = rand(Float64, n) + +# Test MKL32MixedLUFactorization if available +if LinearSolve.usemkl + println("\nTesting MKL32MixedLUFactorization:") + prob = LinearProblem(A, b) + + # Warm up + sol = solve(prob, MKL32MixedLUFactorization()) + + # Test allocations on subsequent solves + cache = init(prob, MKL32MixedLUFactorization()) + solve!(cache) # First solve (factorization) + + # Change b and solve again - this should have minimal allocations + cache.b .= rand(n) + alloc_bytes = @allocated solve!(cache) + println(" Allocations on second solve: $alloc_bytes bytes") + + # Benchmark + println(" Benchmark results:") + @btime solve!(cache) setup=(cache.b .= rand($n)) +end + +# Test OpenBLAS32MixedLUFactorization if available +if LinearSolve.useopenblas + println("\nTesting OpenBLAS32MixedLUFactorization:") + prob = LinearProblem(A, b) + + # Warm up + sol = solve(prob, OpenBLAS32MixedLUFactorization()) + + # Test allocations on subsequent solves + cache = init(prob, OpenBLAS32MixedLUFactorization()) + solve!(cache) # First solve (factorization) + + # Change b and solve again - this should have minimal allocations + cache.b .= rand(n) + alloc_bytes = @allocated solve!(cache) + println(" Allocations on second solve: $alloc_bytes bytes") + + # Benchmark + println(" Benchmark results:") + @btime solve!(cache) setup=(cache.b .= rand($n)) +end + +# Test AppleAccelerate32MixedLUFactorization if available +if Sys.isapple() && LinearSolve.appleaccelerate_isavailable() + println("\nTesting AppleAccelerate32MixedLUFactorization:") + prob = LinearProblem(A, b) + + # Warm up + sol = solve(prob, AppleAccelerate32MixedLUFactorization()) + + # Test allocations on subsequent solves + cache = init(prob, AppleAccelerate32MixedLUFactorization()) + solve!(cache) # First solve (factorization) + + # Change b and solve again - this should have minimal allocations + cache.b .= rand(n) + alloc_bytes = @allocated solve!(cache) + println(" Allocations on second solve: $alloc_bytes bytes") + + # Benchmark + println(" Benchmark results:") + @btime solve!(cache) setup=(cache.b .= rand($n)) +end + +# Test RF32MixedLUFactorization if RecursiveFactorization is available +try + using RecursiveFactorization + println("\nTesting RF32MixedLUFactorization:") + prob = LinearProblem(A, b) + + # Warm up + sol = solve(prob, RF32MixedLUFactorization()) + + # Test allocations on subsequent solves + cache = init(prob, RF32MixedLUFactorization()) + solve!(cache) # First solve (factorization) + + # Change b and solve again - this should have minimal allocations + cache.b .= rand(n) + alloc_bytes = @allocated solve!(cache) + println(" Allocations on second solve: $alloc_bytes bytes") + + # Benchmark + println(" Benchmark results:") + @btime solve!(cache) setup=(cache.b .= rand($n)) +catch e + println("\nRecursiveFactorization not available, skipping RF32MixedLUFactorization test") +end + +println("\n✅ Allocation test complete!") +println("Note: Ideally, the allocation count on the second solve should be minimal (< 1KB)") +println(" as all temporary arrays should be pre-allocated in init_cacheval.") \ No newline at end of file