Skip to content

Commit 59526b2

Browse files
Implement display and other format conversions (#26)
* Implement display and other format conversions * Fix benchmarks * Fix benchmarks typo * Comment out benchmark conversions for JLArray * Add JLArrays support and enable conversion benchmarks - Introduced JLArrays as a weak dependency in Project.toml. - Enabled benchmark conversions for JLArray in runbenchmarks.jl. - Added helper functions for sorting and cumulative sum using AcceleratedKernels in helpers.jl. - Updated conversion tests to include JLArray support in shared test conversions. - Created DeviceSparseArraysJLArraysExt.jl for JLArray specific operations. * Fix Metal errors
1 parent cc9807b commit 59526b2

File tree

17 files changed

+342
-153
lines changed

17 files changed

+342
-153
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313

1414
[weakdeps]
15+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1516
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1617

1718
[extensions]
19+
DeviceSparseArraysJLArraysExt = "JLArrays"
1820
DeviceSparseArraysReactantExt = "Reactant"
1921

2022
[compat]
2123
AcceleratedKernels = "0.4"
2224
Adapt = "4"
2325
ArrayInterface = "7"
26+
JLArrays = "0.3"
2427
KernelAbstractions = "0.9"
2528
LinearAlgebra = "1"
2629
Reactant = "0.2.164"

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ mul!(c_result, A_device, b)
5656
### GPU Backend Usage
5757

5858
```julia
59+
using Adapt
60+
5961
# For CUDA backend
6062
using CUDA
6163
A_cuda = adapt(CuArray, A_device)

benchmarks/benchmark_utils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ _synchronize_backend(arr) = nothing # Fallback: no-op for arrays without Kernel
3232
3333
Synchronize KernelAbstractions backend for DeviceSparseArray types.
3434
"""
35-
function _synchronize_backend(arr::AbstractDeviceSparseArray)
36-
backend = KernelAbstractions.get_backend(arr)
35+
_synchronize_backend(arr::AbstractDeviceSparseArray) = _synchronize_backend(nonzeros(arr))
36+
37+
function _synchronize_backend(x::AbstractArray)
38+
backend = KernelAbstractions.get_backend(x)
3739
KernelAbstractions.synchronize(backend)
3840
return nothing
3941
end
42+
_synchronize_backend(x::JLArray) = nothing # No-op for Julia Arrays

benchmarks/matrix_benchmarks.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -242,34 +242,28 @@ Benchmark Kronecker product (kron) for CSC, CSR, and COO formats.
242242
- `N`: Size of the matrices (default: 100)
243243
- `T`: Element type (default: Float64)
244244
"""
245-
function benchmark_kron!(
246-
SUITE,
247-
array_constructor,
248-
array_type_name;
249-
N = 100,
250-
T = Float64,
251-
)
245+
function benchmark_kron!(SUITE, array_constructor, array_type_name; N = 100, T = Float64)
252246
# Create sparse matrices with 1% density (smaller matrices since kron grows quadratically)
253247
sm_a_std = sprand(T, N, N, 0.01)
254248
sm_b_std = sprand(T, N, N, 0.01)
255249

256250
# Convert to different formats
257251
sm_a_csc = DeviceSparseMatrixCSC(sm_a_std)
258252
sm_b_csc = DeviceSparseMatrixCSC(sm_b_std)
259-
253+
260254
sm_a_csr = DeviceSparseMatrixCSR(sm_a_std)
261255
sm_b_csr = DeviceSparseMatrixCSR(sm_b_std)
262-
256+
263257
sm_a_coo = DeviceSparseMatrixCOO(sm_a_std)
264258
sm_b_coo = DeviceSparseMatrixCOO(sm_b_std)
265259

266260
# Adapt to device
267261
dsm_a_csc = adapt(array_constructor, sm_a_csc)
268262
dsm_b_csc = adapt(array_constructor, sm_b_csc)
269-
263+
270264
dsm_a_csr = adapt(array_constructor, sm_a_csr)
271265
dsm_b_csr = adapt(array_constructor, sm_b_csr)
272-
266+
273267
dsm_a_coo = adapt(array_constructor, sm_a_coo)
274268
dsm_b_coo = adapt(array_constructor, sm_b_coo)
275269

@@ -291,4 +285,3 @@ function benchmark_kron!(
291285

292286
return nothing
293287
end
294-
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module DeviceSparseArraysJLArraysExt
2+
3+
using JLArrays: JLArray
4+
import DeviceSparseArrays
5+
6+
DeviceSparseArrays._sortperm_AK(x::JLArray) = JLArray(sortperm(collect(x)))
7+
DeviceSparseArrays._cumsum_AK(x::JLArray) = JLArray(cumsum(collect(x)))
8+
9+
end

src/DeviceSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import LinearAlgebra: wrap, copymutable_oftype, __normalize!, kron
55
using SparseArrays
66
import SparseArrays: SparseVector, SparseMatrixCSC
77
import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds
8+
import SparseArrays: _show_with_braille_patterns
89

910
import ArrayInterface: allowed_getindex, allowed_setindex!
1011

src/conversions/conversions.jl

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,126 @@
11
# Conversions between CSC, CSR, and COO sparse matrix formats
2-
# All conversions operate on-device
2+
3+
# ============================================================================
4+
# SparseMatrixCSC ↔ DeviceSparseMatrix (CSC-CSR-COO) Conversions
5+
# ============================================================================
6+
7+
DeviceSparseMatrixCSC(A::SparseMatrixCSC) =
8+
DeviceSparseMatrixCSC(size(A, 1), size(A, 2), A.colptr, A.rowval, A.nzval)
9+
10+
SparseMatrixCSC(A::DeviceSparseMatrixCSC) = SparseMatrixCSC(
11+
size(A, 1),
12+
size(A, 2),
13+
collect(A.colptr),
14+
collect(A.rowval),
15+
collect(A.nzval),
16+
)
17+
function SparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCSC}) where {Tv}
18+
SparseMatrixCSC(DeviceSparseMatrixCSR(A))
19+
end
20+
function SparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCSC}) where {Tv}
21+
SparseMatrixCSC(DeviceSparseMatrixCSR(A))
22+
end
23+
24+
function DeviceSparseMatrixCSR(A::SparseMatrixCSC)
25+
# TODO: Implement a direct CSC to CSR conversion without going through transposition
26+
At = transpose(A)
27+
At_sparse = transpose(SparseMatrixCSC(At))
28+
return DeviceSparseMatrixCSR(At_sparse)
29+
end
30+
31+
function SparseMatrixCSC(A::DeviceSparseMatrixCSR)
32+
# Convert CSR to CSC by creating transposed CSC and then transposing back
33+
At_csc =
34+
SparseMatrixCSC(A.n, A.m, collect(A.rowptr), collect(A.colval), collect(A.nzval))
35+
return SparseMatrixCSC(transpose(At_csc))
36+
end
37+
function SparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCSR}) where {Tv}
38+
At = A.parent
39+
SparseMatrixCSC(At.n, At.m, collect(At.rowptr), collect(At.colval), collect(At.nzval))
40+
end
41+
function SparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCSR}) where {Tv}
42+
At = A.parent
43+
SparseMatrixCSC(
44+
size(A, 1),
45+
size(A, 2),
46+
collect(At.rowptr),
47+
collect(At.colval),
48+
collect(conj.(At.nzval)),
49+
)
50+
end
51+
52+
function DeviceSparseMatrixCOO(A::SparseMatrixCSC)
53+
m, n = size(A)
54+
rows, cols, vals = findnz(A)
55+
return DeviceSparseMatrixCOO(m, n, rows, cols, vals)
56+
end
57+
58+
function SparseMatrixCSC(A::DeviceSparseMatrixCOO)
59+
m, n = size(A)
60+
rowind = collect(A.rowind)
61+
colind = collect(A.colind)
62+
nzval = collect(A.nzval)
63+
64+
return sparse(rowind, colind, nzval, m, n)
65+
end
66+
SparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCOO}) where {Tv} = SparseMatrixCSC(
67+
size(A, 1),
68+
size(A, 2),
69+
collect(A.parent.colind),
70+
collect(A.parent.rowind),
71+
collect(A.parent.nzval),
72+
)
73+
SparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCOO}) where {Tv} = SparseMatrixCSC(
74+
size(A, 1),
75+
size(A, 2),
76+
collect(A.parent.colind),
77+
collect(A.parent.rowind),
78+
collect(conj.(A.parent.nzval)),
79+
)
80+
81+
# ============================================================================
82+
# CSC ↔ CSR Conversions
83+
# ============================================================================
84+
85+
DeviceSparseMatrixCSC(A::DeviceSparseMatrixCSR) =
86+
DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(A))
87+
DeviceSparseMatrixCSC(A::Transpose{Tv,<:DeviceSparseMatrixCSR}) where {Tv} =
88+
DeviceSparseMatrixCSC(
89+
size(A, 1),
90+
size(A, 2),
91+
A.parent.rowptr,
92+
A.parent.colval,
93+
A.parent.nzval,
94+
)
95+
DeviceSparseMatrixCSC(A::Adjoint{Tv,<:DeviceSparseMatrixCSR}) where {Tv} =
96+
DeviceSparseMatrixCSC(
97+
size(A, 1),
98+
size(A, 2),
99+
A.parent.rowptr,
100+
A.parent.colval,
101+
conj.(A.parent.nzval),
102+
)
103+
104+
DeviceSparseMatrixCSR(A::DeviceSparseMatrixCSC) =
105+
DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(A))
106+
function DeviceSparseMatrixCSR(
107+
A::Transpose{Tv,<:Union{<:SparseMatrixCSC,<:DeviceSparseMatrixCSC}},
108+
) where {Tv}
109+
At = A.parent
110+
DeviceSparseMatrixCSR(size(A, 1), size(A, 2), At.colptr, rowvals(At), nonzeros(At))
111+
end
112+
function DeviceSparseMatrixCSR(
113+
A::Adjoint{Tv,<:Union{<:SparseMatrixCSC,<:DeviceSparseMatrixCSC}},
114+
) where {Tv}
115+
At = A.parent
116+
DeviceSparseMatrixCSR(
117+
size(A, 1),
118+
size(A, 2),
119+
At.colptr,
120+
rowvals(At),
121+
conj.(nonzeros(At)),
122+
)
123+
end
3124

4125
# ============================================================================
5126
# CSC ↔ COO Conversions
@@ -36,7 +157,8 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
36157
kernel! = kernel_make_csc_keys!(backend)
37158
kernel!(keys, A.rowind, A.colind, m; ndrange = (nnz_count,))
38159

39-
perm = AcceleratedKernels.sortperm(keys)
160+
# Sort - use AcceleratedKernels
161+
perm = _sortperm_AK(keys)
40162

41163
# Apply permutation to get sorted arrays
42164
rowind_sorted = A.rowind[perm]
@@ -53,7 +175,7 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
53175

54176
# Compute cumulative sum
55177
allowed_setindex!(colptr, 1, 1) # TODO: Is there a better way to do this?
56-
colptr[2:end] .= AcceleratedKernels.cumsum(colptr[2:end]) .+ 1
178+
colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
57179

58180
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
59181
end
@@ -94,7 +216,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
94216
kernel!(keys, A.rowind, A.colind, n; ndrange = (nnz_count,))
95217

96218
# Sort - use AcceleratedKernels
97-
perm = AcceleratedKernels.sortperm(keys)
219+
perm = _sortperm_AK(keys)
98220

99221
# Apply permutation to get sorted arrays
100222
rowind_sorted = A.rowind[perm]
@@ -111,7 +233,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
111233

112234
# Compute cumulative sum
113235
allowed_setindex!(rowptr, 1, 1) # TODO: Is there a better way to do this?
114-
rowptr[2:end] .= AcceleratedKernels.cumsum(rowptr[2:end]) .+ 1
236+
rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
115237

116238
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
117239
end

src/core.jl

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ const AbstractDeviceSparseMatrix{Tv,Ti} = AbstractDeviceSparseArray{Tv,Ti,2}
1515
const AbstractDeviceSparseVecOrMat{Tv,Ti} =
1616
Union{AbstractDeviceSparseVector{Tv,Ti},AbstractDeviceSparseMatrix{Tv,Ti}}
1717

18+
const AbstractDeviceSparseMatrixInclAdjointAndTranspose = Union{
19+
AbstractDeviceSparseMatrix,
20+
Adjoint{<:Any,<:AbstractDeviceSparseMatrix},
21+
Transpose{<:Any,<:AbstractDeviceSparseMatrix},
22+
}
23+
1824
Base.sum(A::AbstractDeviceSparseArray) = sum(nonzeros(A))
1925

2026
function LinearAlgebra.rmul!(A::AbstractDeviceSparseArray, x::Number)
@@ -43,43 +49,17 @@ end
4349

4450
KernelAbstractions.get_backend(A::AbstractDeviceSparseArray) = get_backend(nonzeros(A))
4551

46-
trans_adj_wrappers(fmt) = (
47-
(T -> :($fmt{$T}), false, false, identity, T -> :($T)),
48-
(T -> :(Transpose{$T,<:$fmt{$T}}), true, false, A -> :(parent($A)), T -> :($T<:Real)),
49-
(
50-
T -> :(Transpose{$T,<:$fmt{$T}}),
51-
true,
52-
false,
53-
A -> :(parent($A)),
54-
T -> :($T<:Complex),
55-
),
56-
(T -> :(Adjoint{$T,<:$fmt{$T}}), true, true, A -> :(parent($A)), T -> :($T)),
57-
)
52+
# called by `show(io, MIME("text/plain"), ::AbstractDeviceSparseMatrixInclAdjointAndTranspose)`
53+
function Base.print_array(io::IO, A::AbstractDeviceSparseMatrixInclAdjointAndTranspose)
54+
S = SparseMatrixCSC(A)
55+
if max(size(S)...) < 16
56+
Base.print_matrix(io, S)
57+
else
58+
_show_with_braille_patterns(io, S)
59+
end
60+
end
5861

5962
# Generic addition between AbstractDeviceSparseMatrix and DenseMatrix
60-
"""
61-
+(A::AbstractDeviceSparseMatrix, B::DenseMatrix)
62-
63-
Add a sparse matrix `A` to a dense matrix `B`, returning a dense matrix.
64-
All backends must be compatible.
65-
66-
# Examples
67-
```jldoctest
68-
julia> using DeviceSparseArrays, SparseArrays
69-
70-
julia> A = DeviceSparseMatrixCSC(sparse([1, 2], [1, 2], [1.0, 2.0], 3, 3));
71-
72-
julia> B = ones(3, 3);
73-
74-
julia> C = A + B;
75-
76-
julia> collect(C)
77-
3×3 Matrix{Float64}:
78-
2.0 1.0 1.0
79-
1.0 3.0 1.0
80-
1.0 1.0 1.0
81-
```
82-
"""
8363
function Base.:+(A::AbstractDeviceSparseMatrix, B::DenseMatrix)
8464
size(A) == size(B) || throw(
8565
DimensionMismatch(
@@ -101,10 +81,18 @@ function Base.:+(A::AbstractDeviceSparseMatrix, B::DenseMatrix)
10181
return C
10282
end
10383

104-
"""
105-
+(B::DenseMatrix, A::AbstractDeviceSparseMatrix)
106-
107-
Add a dense matrix `B` to a sparse matrix `A`, returning a dense matrix.
108-
This is the commutative version of `A + B`.
109-
"""
11084
Base.:+(B::DenseMatrix, A::AbstractDeviceSparseMatrix) = A + B
85+
86+
# Keep this at the end of the file
87+
trans_adj_wrappers(fmt) = (
88+
(T -> :($fmt{$T}), false, false, identity, T -> :($T)),
89+
(T -> :(Transpose{$T,<:$fmt{$T}}), true, false, A -> :(parent($A)), T -> :($T<:Real)),
90+
(
91+
T -> :(Transpose{$T,<:$fmt{$T}}),
92+
true,
93+
false,
94+
A -> :(parent($A)),
95+
T -> :($T<:Complex),
96+
),
97+
(T -> :(Adjoint{$T,<:$fmt{$T}}), true, true, A -> :(parent($A)), T -> :($T)),
98+
)

src/helpers.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ _check_type(::Type{T}, v::AbstractArray{T}) where {T} = true
66
_check_type(::Type{T}, v::AbstractArray) where {T} = false
77

88
_get_eltype(::AbstractArray{T}) where {T} = T
9+
10+
_sortperm_AK(x) = AcceleratedKernels.sortperm(x)
11+
_cumsum_AK(x) = AcceleratedKernels.cumsum(x)

src/matrix_coo/matrix_coo.jl

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,13 @@ function DeviceSparseMatrixCOO(
7070
} where {Ti<:Integer,Tv}
7171
Ti2 = _get_eltype(rowind)
7272
Tv2 = _get_eltype(nzval)
73-
DeviceSparseMatrixCOO{Tv2,Ti2,RowIndT,ColIndT,NzValT}(m, n, rowind, colind, nzval)
73+
DeviceSparseMatrixCOO{Tv2,Ti2,RowIndT,ColIndT,NzValT}(
74+
m,
75+
n,
76+
copy(rowind),
77+
copy(colind),
78+
copy(nzval),
79+
)
7480
end
7581

7682
# Conversion from SparseMatrixCSC to COO
@@ -95,16 +101,6 @@ function DeviceSparseMatrixCOO(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
95101
return DeviceSparseMatrixCOO(m, n, rowind, colind, nzval)
96102
end
97103

98-
# Conversion from COO to SparseMatrixCSC
99-
function SparseMatrixCSC(A::DeviceSparseMatrixCOO)
100-
m, n = size(A)
101-
rowind = collect(A.rowind)
102-
colind = collect(A.colind)
103-
nzval = collect(A.nzval)
104-
105-
return sparse(rowind, colind, nzval, m, n)
106-
end
107-
108104
Adapt.adapt_structure(to, A::DeviceSparseMatrixCOO) = DeviceSparseMatrixCOO(
109105
A.m,
110106
A.n,

0 commit comments

Comments (0)