Skip to content

Commit ee0ecf1

Browse files
try simpler gpu code
1 parent 9ea818d commit ee0ecf1

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@ using SciMLBase: AbstractSciMLOperator
66
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
77
kwargs...)
88
if cache.isfresh
9-
fact = LinearSolve.do_factorization(alg, CUDA.CuArray(cache.A), cache.b, cache.u)
9+
fact = lu(CUDA.CuArray(cache.A))
1010
cache.cacheval = fact
1111
cache.isfresh = false
1212
end
13-
14-
copyto!(cache.u, cache.b)
15-
y = Array(ldiv!(cache.cacheval, CUDA.CuArray(cache.u)))
13+
y = Array(ldiv!(cache.u, cache.cacheval, CUDA.CuArray(cache.u)))
1614
SciMLBase.build_linear_solution(alg, y, nothing, cache)
1715
end
1816

19-
function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
20-
fact = lu(CUDA.CuArray(A))
21-
return fact
17+
function init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
18+
maxiters::Int, abstol, reltol, verbose::Bool,
19+
assumptions::OperatorAssumptions)
20+
ArrayInterface.lu_instance(CUDA.CuArray(A))
2221
end
2322

2423
end

0 commit comments

Comments
 (0)