Skip to content

Commit d500b2c

Browse files
committed
support scimloperators in linearsolvecuda
1 parent 63396af commit d500b2c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

lib/LinearSolveCUDA/src/LinearSolveCUDA.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module LinearSolveCUDA
22

33
using CUDA, LinearAlgebra, LinearSolve, SciMLBase
4+
using SciMLBase: AbstractSciMLOperator
45

56
struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end
67

@@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
1718
end
1819

1920
function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
20-
A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
21+
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
2122
error("LU is not defined for $(typeof(A))")
2223

23-
if A isa SciMLBase.AbstractDiffEqOperator
24+
if A isa Union{MatrixOperator, DiffEqArrayOperator}
2425
A = A.A
2526
end
27+
2628
fact = qr(CUDA.CuArray(A))
2729
return fact
2830
end

0 commit comments

Comments
 (0)