Skip to content

Commit 82c2074

Browse files
Add SparseMatricesCSR.jl extension (#2720)
1 parent b19e47b commit 82c2074

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ demumble_jll = "1e29f10c-031c-5a83-9565-69cddfc27673"
4141
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4242
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
4343
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
44+
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
4445

4546
[extensions]
4647
ChainRulesCoreExt = "ChainRulesCore"
4748
EnzymeCoreExt = "EnzymeCore"
4849
SpecialFunctionsExt = "SpecialFunctions"
50+
SparseMatricesCSRExt = "SparseMatricesCSR"
4951

5052
[compat]
5153
AbstractFFTs = "0.4, 0.5, 1.0"
@@ -80,6 +82,7 @@ RandomNumbers = "1.5.3"
8082
Reexport = "0.2, 1.0"
8183
Requires = "0.5, 1.0"
8284
SparseArrays = "1"
85+
SparseMatricesCSR = "0.6.9"
8386
SpecialFunctions = "1.3, 2"
8487
StaticArrays = "1"
8588
Statistics = "1"
@@ -90,3 +93,4 @@ julia = "1.10"
9093
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9194
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
9295
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
96+
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"

ext/SparseMatricesCSRExt.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module SparseMatricesCSRExt
2+
3+
using CUDA
4+
import CUDA.CUSPARSE:
5+
CuSparseMatrixCSR, CuSparseMatrixCSC, CuSparseMatrixCOO, CuSparseMatrixBSR,
6+
SparseMatrixCSC
7+
using SparseMatricesCSR
8+
import SparseMatricesCSR: SparseMatrixCSR
9+
import Adapt
10+
11+
# CPU → GPU
12+
CUSPARSE.CuSparseMatrixCSR{T}(Mat::SparseMatrixCSR) where {T} =
13+
CUSPARSE.CuSparseMatrixCSR{T}(
14+
CuVector{Cint}(Mat.rowptr), CuVector{Cint}(Mat.colval),
15+
CuVector{T}(Mat.nzval), size(Mat)
16+
)
17+
CUSPARSE.CuSparseMatrixCSC{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCSC(CuSparseMatrixCSR{T}(Mat))
18+
CUSPARSE.CuSparseMatrixCOO{T}(Mat::SparseMatrixCSR) where {T} = CuSparseMatrixCOO(CuSparseMatrixCSR{T}(Mat))
19+
CUSPARSE.CuSparseMatrixBSR{T}(Mat::SparseMatrixCSR, blockdim) where {T} = CuSparseMatrixBSR(CuSparseMatrixCSR{T}(Mat), blockdim)
20+
21+
# GPU → CPU
22+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR{1}(size(A)..., Array(A.rowPtr), Array(A.colVal), Array(A.nzVal))
23+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCOO) = SparseMatrixCSR(CuSparseMatrixCSR(A))
24+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixCSC) = SparseMatrixCSR(CuSparseMatrixCSR(A))
25+
SparseMatricesCSR.SparseMatrixCSR(A::CUSPARSE.CuSparseMatrixBSR) = SparseMatrixCSR(CuSparseMatrixCSR(A))
26+
27+
# Adapt
28+
Adapt.adapt_storage(::Type{CuArray}, xs::SparseMatrixCSR) = CUSPARSE.CuSparseMatrixCSR(xs)
29+
Adapt.adapt_storage(::Type{CuArray{T}}, xs::SparseMatrixCSR) where {T} = CUSPARSE.CuSparseMatrixCSR{T}(xs)
30+
Adapt.adapt_storage(::Type{Array}, mat::CUSPARSE.CuSparseMatrixCSR) = SparseMatrixCSR(mat)
31+
32+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2121
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2323
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
24+
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
2425
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using SparseMatricesCSR
2+
using SparseArrays
3+
using CUDA
4+
using CUDA.CUSPARSE
5+
using Test
6+
7+
@testset "SparseMatricesCSRExt" begin
8+
9+
for (n, bd, p) in [(100, 5, 0.02), (5, 1, 0.8), (4, 2, 0.5)]
10+
v"12.0" <= CUSPARSE.version() < v"12.1" && n == 4 && continue
11+
@testset "conversions between CuSparseMatrices (n, bd, p) = ($n, $bd, $p)" begin
12+
_A = sprand(n, n, p)
13+
A = SparseMatrixCSR(_A)
14+
blockdim = bd
15+
for CuSparseMatrixType1 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
16+
dA1 = CuSparseMatrixType1 == CuSparseMatrixBSR ? CuSparseMatrixType1(A, blockdim) : CuSparseMatrixType1(A)
17+
@testset "conversion $CuSparseMatrixType1 --> SparseMatrixCSR" begin
18+
@test SparseMatrixCSR(dA1) A
19+
end
20+
for CuSparseMatrixType2 in (CuSparseMatrixCSC, CuSparseMatrixCSR, CuSparseMatrixCOO, CuSparseMatrixBSR)
21+
CuSparseMatrixType1 == CuSparseMatrixType2 && continue
22+
dA2 = CuSparseMatrixType2 == CuSparseMatrixBSR ? CuSparseMatrixType2(dA1, blockdim) : CuSparseMatrixType2(dA1)
23+
@testset "conversion $CuSparseMatrixType1 --> $CuSparseMatrixType2" begin
24+
@test collect(dA1) collect(dA2)
25+
end
26+
end
27+
end
28+
end
29+
end
30+
end

0 commit comments

Comments
 (0)