@@ -6,6 +6,7 @@ using CUDA, CUDA.CUSPARSE, CUDA.CUFFT
66using Oceananigans. Utils: linear_expand, __linear_ndrange, MappedCompilerMetadata
77using KernelAbstractions: __dynamic_checkbounds, __iterspace
88using KernelAbstractions
9+ using SparseArrays
910
1011import Oceananigans. Architectures as AC
1112import Oceananigans. BoundaryConditions as BC
@@ -73,6 +74,16 @@ AC.on_architecture(::CUDAGPU, a::StepRangeLen) = a
7374AC. on_architecture (arch:: Distributed , a:: CuArray ) = AC. on_architecture (AC. child_architecture (arch), a)
7475AC. on_architecture (arch:: Distributed , a:: SubArray{<:Any, <:Any, <:CuArray} ) = AC. on_architecture (child_architecture (arch), a)
7576
77+ @inline AC. sparse_matrix_constructors (:: AC.GPU{CUDABackend} , A:: SparseMatrixCSC ) = (CuArray (A. colptr), CuArray (A. rowval), CuArray (A. nzval), (A. m, A. n))
78+ @inline AC. sparse_matrix_constructors (:: AC.CPU , A:: CuSparseMatrixCSC ) = (A. dims[1 ], A. dims[2 ], Int64 .(Array (A. colPtr)), Int64 .(Array (A. rowVal)), Array (A. nzVal))
79+ @inline AC. sparse_matrix_constructors (:: AC.GPU{CUDABackend} , A:: CuSparseMatrixCSC ) = (A. colPtr, A. rowVal, A. nzVal, A. dims)
80+
81+ @inline AC. sparse_matrix (:: AC.GPU{CUDABackend} , constr:: Tuple ) = CuSparseMatrixCSC (constr... )
82+
83+ @inline AC. on_architecture (:: AC.CPU , A:: CuSparseMatrixCSC ) = SparseMatrixCSC (AC. sparse_matrix_constructors (AC. CPU (), A)... )
84+ @inline AC. on_architecture (:: AC.GPU{CUDABackend} , A:: SparseMatrixCSC ) = CuSparseMatrixCSC (AC. sparse_matrix_constructors (AC. GPU (), A)... )
85+ @inline AC. on_architecture (:: AC.GPU{CUDABackend} , A:: CuSparseMatrixCSC ) = A
86+
7687# cu alters the type of `a`, so we convert it back to the correct type
7788AC. unified_array (:: CUDAGPU , a:: AbstractArray ) = map (eltype (a), cu (a; unified = true ))
7889
0 commit comments