ext/LinearSolveCUDAExt.jl (41 changes: 23 additions, 18 deletions)
@@ -120,36 +120,41 @@ end
 function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        cacheval = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
-        # Convert to Float32 for factorization
-        A_f32 = Float32.(cache.A)
-        fact = lu(CUDA.CuArray(A_f32))
-        cache.cacheval = fact
+        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        # Convert to Float32 for factorization using cached type
+        A_f32 = T32.(cache.A)
+        copyto!(A_gpu_f32, A_f32)
+        fact = lu(A_gpu_f32)
+        cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
         cache.isfresh = false
     end
-    fact = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
     # Convert b to Float32, solve, then convert back to original precision
-    b_f32 = Float32.(cache.b)
-    u_f32 = CUDA.CuArray(b_f32)
-    y_f32 = ldiv!(u_f32, fact, CUDA.CuArray(b_f32))
+    b_f32 = T32.(cache.b)
+    copyto!(b_gpu_f32, b_f32)
+    ldiv!(u_gpu_f32, fact, b_gpu_f32)
     # Convert back to original precision
-    y = Array(y_f32)
-    T = eltype(cache.u)
-    cache.u .= T.(y)
+    y = Array(u_gpu_f32)
+    cache.u .= Torig.(y)
     SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
 end

 function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Bool,
         assumptions::OperatorAssumptions)
-    # Pre-allocate with Float32 arrays
-    A_f32 = Float32.(A)
-    T = eltype(A_f32)
-    noUnitT = typeof(zero(T))
+    # Pre-allocate with Float32 arrays and cache types
+    m, n = size(A)
+    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
+    Torig = eltype(u)
+    noUnitT = typeof(zero(T32))
     luT = LinearAlgebra.lutype(noUnitT)
-    ipiv = CuVector{Int32}(undef, 0)
+    ipiv = CuVector{Int32}(undef, min(m, n))
     info = zero(LinearAlgebra.BlasInt)
-    return LU{luT}(CuMatrix{Float32}(undef, 0, 0), ipiv, info)
+    fact = LU{luT}(CuMatrix{T32}(undef, m, n), ipiv, info)
+    A_gpu_f32 = CuMatrix{T32}(undef, m, n)
+    b_gpu_f32 = CuVector{T32}(undef, size(b, 1))
+    u_gpu_f32 = CuVector{T32}(undef, size(u, 1))
+    return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
 end

 end
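
The kernel of this change is the classic mixed-precision offload pattern: demote A and b to 32-bit once, run the O(n^3) factorization and the triangular solves in Float32, and promote only the solution back to the caller's precision. A minimal CPU-only sketch of that pattern, with illustrative names that are not part of the LinearSolve API:

```julia
using LinearAlgebra

# Sketch: factor in Float32, solve in Float32, cast the result back.
function mixed_precision_solve(A::Matrix{Float64}, b::Vector{Float64})
    F = lu(Float32.(A))        # the expensive O(n^3) step runs in 32-bit
    u32 = F \ Float32.(b)      # triangular solves, also in 32-bit
    return Float64.(u32)       # promote back to the original precision
end

A = rand(100, 100) + 10I       # well-conditioned test matrix
b = rand(100)
u = mixed_precision_solve(A, b)
norm(A * u - b)                # residual accuracy is limited by Float32
```

What the PR adds on top of this pattern is buffer reuse: the Float32 copies and the GPU arrays live in the cacheval tuple, so repeated solves with the same cache allocate nothing.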
ext/LinearSolveMetalExt.jl (44 changes: 27 additions, 17 deletions)
@@ -36,37 +36,47 @@ default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
 function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Bool,
         assumptions::OperatorAssumptions)
-    # Pre-allocate with Float32 arrays
-    A_f32 = Float32.(convert(AbstractMatrix, A))
-    ArrayInterface.lu_instance(A_f32)
+    # Pre-allocate with Float32 arrays and cache types
+    m, n = size(A)
+    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
+    Torig = eltype(u)
+    A_f32 = similar(A, T32)
+    b_f32 = similar(b, T32)
+    u_f32 = similar(u, T32)
+    luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
+    # Pre-allocate Metal arrays
+    A_mtl = MtlArray{T32}(undef, m, n)
+    b_mtl = MtlVector{T32}(undef, size(b, 1))
+    u_mtl = MtlVector{T32}(undef, size(u, 1))
+    return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
 end

 function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
         kwargs...)
     A = cache.A
     A = convert(AbstractMatrix, A)
     if cache.isfresh
-        cacheval = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
-        # Convert to Float32 for factorization
-        A_f32 = Float32.(A)
-        res = lu(MtlArray(A_f32))
-        # Store factorization on CPU with converted types
-        cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
+        luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
+        # Convert to appropriate 32-bit type for factorization using cached type
+        A_f32 .= T32.(A)
+        copyto!(A_mtl, A_f32)
+        res = lu(A_mtl)
+        # Store factorization and pre-allocated arrays
+        fact = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
+        cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
         cache.isfresh = false
     end

-    fact = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
-    # Convert b to Float32 for solving
-    b_f32 = Float32.(cache.b)
-    u_f32 = similar(b_f32)
+    fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
+    # Convert b to 32-bit for solving using cached type
+    b_f32 .= T32.(cache.b)

     # Create a temporary Float32 LU factorization for solving
-    fact_f32 = LU(Float32.(fact.factors), fact.ipiv, fact.info)
+    fact_f32 = LU(T32.(fact.factors), fact.ipiv, fact.info)
     ldiv!(u_f32, fact_f32, b_f32)

-    # Convert back to original precision
-    T = eltype(cache.u)
-    cache.u .= T.(u_f32)
+    # Convert back to original precision using cached type
+    cache.u .= Torig.(u_f32)
     SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
 end
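
Assuming usage mirrors the docstring examples further down in this PR, a sketch of exercising the Metal path end to end (requires macOS with an Apple GPU; `MetalOffload32MixedLUFactorization` comes from this extension):

```julia
using LinearSolve, Metal, LinearAlgebra

A = rand(Float64, 256, 256) + 10I
b = rand(Float64, 256)
prob = LinearProblem(A, b)

# The factorization runs on the GPU in Float32; the solution is returned
# in the problem's original Float64 precision.
sol = solve(prob, MetalOffload32MixedLUFactorization())
```

Note the design choice visible in the diff: the factors are pulled back to the CPU (`Array(res.factors)`) and the triangular solves run there, so only the factorization itself is offloaded.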
ext/LinearSolveRecursiveFactorizationExt.jl (61 changes: 25 additions, 36 deletions)
@@ -45,14 +45,16 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u,
         maxiters::Int, abstol, reltol, verbose::Bool,
         assumptions::LinearSolve.OperatorAssumptions) where {P, T}
     # Pre-allocate appropriate 32-bit arrays based on input type
-    if eltype(A) <: Complex
-        A_32 = rand(ComplexF32, 0, 0)
-    else
-        A_32 = rand(Float32, 0, 0)
-    end
-    luinst = ArrayInterface.lu_instance(A_32)
-    ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
-    (luinst, ipiv)
+    m, n = size(A)
+    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
+    Torig = eltype(u)
+    A_32 = similar(A, T32)
+    b_32 = similar(b, T32)
+    u_32 = similar(u, T32)
+    luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
+    ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n))
+    # Return tuple with pre-allocated arrays and cached types
+    (luinst, ipiv, A_32, b_32, u_32, T32, Torig)
 end

 function SciMLBase.solve!(
@@ -61,25 +63,19 @@ function SciMLBase.solve!(
     A = cache.A
     A = convert(AbstractMatrix, A)

-    # Check if we have complex numbers
-    iscomplex = eltype(A) <: Complex
-
-    fact, ipiv = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
     if cache.isfresh
-        # Convert to appropriate 32-bit type for factorization
-        if iscomplex
-            A_f32 = ComplexF32.(A)
-        else
-            A_f32 = Float32.(A)
-        end
+        # Get pre-allocated arrays from cacheval
+        luinst, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
+        # Copy A to pre-allocated 32-bit array using cached type
+        A_32 .= T32.(A)

         # Ensure ipiv is the right size
-        if length(ipiv) != min(size(A_f32)...)
-            ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A_f32)...))
+        if length(ipiv) != min(size(A_32)...)
+            resize!(ipiv, min(size(A_32)...))
         end

-        fact = RecursiveFactorization.lu!(A_f32, ipiv, Val(P), Val(T), check = false)
-        cache.cacheval = (fact, ipiv)
+        fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false)
+        cache.cacheval = (fact, ipiv, A_32, b_32, u_32, T32, Torig)

         if !LinearAlgebra.issuccess(fact)
             return SciMLBase.build_linear_solution(
@@ -89,24 +85,17 @@ function SciMLBase.solve!(
         cache.isfresh = false
     end

-    # Get the factorization from the cache
-    fact_cached = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)[1]
+    # Get the factorization and pre-allocated arrays from the cache
+    fact_cached, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)

-    # Convert b to appropriate 32-bit type for solving
-    if iscomplex
-        b_f32 = ComplexF32.(cache.b)
-        u_f32 = similar(b_f32)
-    else
-        b_f32 = Float32.(cache.b)
-        u_f32 = similar(b_f32)
-    end
+    # Copy b to pre-allocated 32-bit array using cached type
+    b_32 .= T32.(cache.b)

     # Solve in 32-bit precision
-    ldiv!(u_f32, fact_cached, b_f32)
+    ldiv!(u_32, fact_cached, b_32)

-    # Convert back to original precision
-    T_orig = eltype(cache.u)
-    cache.u .= T_orig.(u_f32)
+    # Convert back to original precision using cached type
+    cache.u .= Torig.(u_32)

     SciMLBase.build_linear_solution(
         alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
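
The payoff of carrying `A_32`, `b_32`, and `u_32` in the cacheval is that repeated solves against the same matrix stop allocating. A hedged sketch using LinearSolve's caching interface (assuming the default `RF32MixedLUFactorization()` constructor):

```julia
using LinearSolve, RecursiveFactorization, LinearAlgebra

A = rand(64, 64) + 10I
prob = LinearProblem(A, rand(64))
cache = init(prob, RF32MixedLUFactorization())

sol1 = solve!(cache)   # first solve: copies A into A_32 and factorizes
cache.b = rand(64)     # new right-hand side, same matrix
sol2 = solve!(cache)   # reuses the factorization and the 32-bit buffers
```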
src/appleaccelerate.jl (61 changes: 25 additions, 36 deletions)
@@ -298,13 +298,15 @@ function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A
         maxiters::Int, abstol, reltol, verbose::Bool,
         assumptions::OperatorAssumptions)
     # Pre-allocate appropriate 32-bit arrays based on input type
-    if eltype(A) <: Complex
-        A_32 = rand(ComplexF32, 0, 0)
-    else
-        A_32 = rand(Float32, 0, 0)
-    end
-    luinst = ArrayInterface.lu_instance(A_32)
-    LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}()
+    m, n = size(A)
+    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
+    Torig = eltype(u)
+    A_32 = similar(A, T32)
+    b_32 = similar(b, T32)
+    u_32 = similar(u, T32)
+    luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
+    # Return tuple with pre-allocated arrays and cached types
+    (LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32, T32, Torig)
 end

 function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
@@ -314,19 +316,13 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
     A = cache.A
     A = convert(AbstractMatrix, A)

-    # Check if we have complex numbers
-    iscomplex = eltype(A) <: Complex
-
     if cache.isfresh
-        cacheval = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
-        # Convert to appropriate 32-bit type for factorization
-        if iscomplex
-            A_f32 = ComplexF32.(A)
-        else
-            A_f32 = Float32.(A)
-        end
-        res = aa_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2])
-        fact = LU(res[1:3]...), res[4]
+        # Get pre-allocated arrays from cacheval
+        luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
+        # Copy A to pre-allocated 32-bit array using cached type
+        A_32 .= T32.(A)
+        res = aa_getrf!(A_32; ipiv = luinst.ipiv, info = info)
+        fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
         cache.cacheval = fact

         if !LinearAlgebra.issuccess(fact[1])
@@ -336,29 +332,22 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
         cache.isfresh = false
     end

-    A_lu, info = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
+    A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
     require_one_based_indexing(cache.u, cache.b)
     m, n = size(A_lu, 1), size(A_lu, 2)

-    # Convert b to appropriate 32-bit type for solving
-    if iscomplex
-        b_f32 = ComplexF32.(cache.b)
-    else
-        b_f32 = Float32.(cache.b)
-    end
+    # Copy b to pre-allocated 32-bit array using cached type
+    b_32 .= T32.(cache.b)

     if m > n
-        Bc = copy(b_f32)
-        aa_getrs!('N', A_lu.factors, A_lu.ipiv, Bc; info)
-        # Convert back to original precision
-        T = eltype(cache.u)
-        cache.u .= T.(Bc[1:n])
+        aa_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
+        # Convert back to original precision using cached type
+        cache.u[1:n] .= Torig.(b_32[1:n])
     else
-        u_f32 = copy(b_f32)
-        aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_f32; info)
-        # Convert back to original precision
-        T = eltype(cache.u)
-        cache.u .= T.(u_f32)
+        copyto!(u_32, b_32)
+        aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
+        # Convert back to original precision using cached type
+        cache.u .= Torig.(u_32)
     end

     SciMLBase.build_linear_solution(
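
Mixed precision trades roughly half the significant digits for speed. Where full Float64 accuracy is needed, one step of iterative refinement against the Float32 factorization usually recovers it. This is not part of the PR; it is a sketch of the standard companion technique:

```julia
using LinearAlgebra

# One refinement step on top of a Float32 factorization.
function refined_solve(A::Matrix{Float64}, b::Vector{Float64})
    F = lu(Float32.(A))                    # 32-bit factorization
    x = Float64.(F \ Float32.(b))          # initial mixed-precision solve
    r = b - A * x                          # residual computed in full precision
    return x .+ Float64.(F \ Float32.(r))  # corrected solution
end
```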
src/extension_algs.jl (6 changes: 3 additions, 3 deletions)
@@ -809,7 +809,7 @@ alg = MKL32MixedLUFactorization()
 sol = solve(prob, alg)
 ```
 """
-struct MKL32MixedLUFactorization <: AbstractFactorization end
+struct MKL32MixedLUFactorization <: AbstractDenseFactorization end

 """
     AppleAccelerate32MixedLUFactorization()
@@ -833,7 +833,7 @@ alg = AppleAccelerate32MixedLUFactorization()
 sol = solve(prob, alg)
 ```
 """
-struct AppleAccelerate32MixedLUFactorization <: AbstractFactorization end
+struct AppleAccelerate32MixedLUFactorization <: AbstractDenseFactorization end

 """
     OpenBLAS32MixedLUFactorization()
@@ -857,7 +857,7 @@ alg = OpenBLAS32MixedLUFactorization()
 sol = solve(prob, alg)
 ```
 """
-struct OpenBLAS32MixedLUFactorization <: AbstractFactorization end
+struct OpenBLAS32MixedLUFactorization <: AbstractDenseFactorization end

 """
     RF32MixedLUFactorization{P, T}(; pivot = Val(true), thread = Val(true))
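
The only change in this file is the supertype: the mixed-precision LU algorithms shown here move from `AbstractFactorization` to `AbstractDenseFactorization`, grouping them with the other dense LU methods. A quick check of the reclassification (assuming the names are exported as the docstrings suggest):

```julia
using LinearSolve

MKL32MixedLUFactorization() isa LinearSolve.AbstractDenseFactorization  # true
```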