11# Conversions between CSC, CSR, and COO sparse matrix formats
2- # All conversions operate on-device
2+
3+ # ============================================================================
4+ # SparseMatrixCSC ↔ DeviceSparseMatrix (CSC-CSR-COO) Conversions
5+ # ============================================================================
6+
7+ DeviceSparseMatrixCSC(A:: SparseMatrixCSC ) =
8+ DeviceSparseMatrixCSC(size(A, 1 ), size(A, 2 ), A. colptr, A. rowval, A. nzval)
9+
10+ SparseMatrixCSC(A:: DeviceSparseMatrixCSC ) = SparseMatrixCSC(
11+ size(A, 1 ),
12+ size(A, 2 ),
13+ collect(A. colptr),
14+ collect(A. rowval),
15+ collect(A. nzval),
16+ )
17+ function SparseMatrixCSC(A:: Transpose{Tv,<:DeviceSparseMatrixCSC} ) where {Tv}
18+ SparseMatrixCSC(DeviceSparseMatrixCSR(A))
19+ end
20+ function SparseMatrixCSC(A:: Adjoint{Tv,<:DeviceSparseMatrixCSC} ) where {Tv}
21+ SparseMatrixCSC(DeviceSparseMatrixCSR(A))
22+ end
23+
24+ function DeviceSparseMatrixCSR(A:: SparseMatrixCSC )
25+ # TODO : Implement a direct CSC to CSR conversion without going through transposition
26+ At = transpose(A)
27+ At_sparse = transpose(SparseMatrixCSC(At))
28+ return DeviceSparseMatrixCSR(At_sparse)
29+ end
30+
31+ function SparseMatrixCSC(A:: DeviceSparseMatrixCSR )
32+ # Convert CSR to CSC by creating transposed CSC and then transposing back
33+ At_csc =
34+ SparseMatrixCSC(A. n, A. m, collect(A. rowptr), collect(A. colval), collect(A. nzval))
35+ return SparseMatrixCSC(transpose(At_csc))
36+ end
37+ function SparseMatrixCSC(A:: Transpose{Tv,<:DeviceSparseMatrixCSR} ) where {Tv}
38+ At = A. parent
39+ SparseMatrixCSC(At. n, At. m, collect(At. rowptr), collect(At. colval), collect(At. nzval))
40+ end
41+ function SparseMatrixCSC(A:: Adjoint{Tv,<:DeviceSparseMatrixCSR} ) where {Tv}
42+ At = A. parent
43+ SparseMatrixCSC(
44+ size(A, 1 ),
45+ size(A, 2 ),
46+ collect(At. rowptr),
47+ collect(At. colval),
48+ collect(conj.(At. nzval)),
49+ )
50+ end
51+
52+ function DeviceSparseMatrixCOO(A:: SparseMatrixCSC )
53+ m, n = size(A)
54+ rows, cols, vals = findnz(A)
55+ return DeviceSparseMatrixCOO(m, n, rows, cols, vals)
56+ end
57+
58+ function SparseMatrixCSC(A:: DeviceSparseMatrixCOO )
59+ m, n = size(A)
60+ rowind = collect(A. rowind)
61+ colind = collect(A. colind)
62+ nzval = collect(A. nzval)
63+
64+ return sparse(rowind, colind, nzval, m, n)
65+ end
66+ SparseMatrixCSC(A:: Transpose{Tv,<:DeviceSparseMatrixCOO} ) where {Tv} = SparseMatrixCSC(
67+ size(A, 1 ),
68+ size(A, 2 ),
69+ collect(A. parent. colind),
70+ collect(A. parent. rowind),
71+ collect(A. parent. nzval),
72+ )
73+ SparseMatrixCSC(A:: Adjoint{Tv,<:DeviceSparseMatrixCOO} ) where {Tv} = SparseMatrixCSC(
74+ size(A, 1 ),
75+ size(A, 2 ),
76+ collect(A. parent. colind),
77+ collect(A. parent. rowind),
78+ collect(conj.(A. parent. nzval)),
79+ )
80+
81+ # ============================================================================
82+ # CSC ↔ CSR Conversions
83+ # ============================================================================
84+
85+ DeviceSparseMatrixCSC(A:: DeviceSparseMatrixCSR ) =
86+ DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(A))
87+ DeviceSparseMatrixCSC(A:: Transpose{Tv,<:DeviceSparseMatrixCSR} ) where {Tv} =
88+ DeviceSparseMatrixCSC(
89+ size(A, 1 ),
90+ size(A, 2 ),
91+ A. parent. rowptr,
92+ A. parent. colval,
93+ A. parent. nzval,
94+ )
95+ DeviceSparseMatrixCSC(A:: Adjoint{Tv,<:DeviceSparseMatrixCSR} ) where {Tv} =
96+ DeviceSparseMatrixCSC(
97+ size(A, 1 ),
98+ size(A, 2 ),
99+ A. parent. rowptr,
100+ A. parent. colval,
101+ conj.(A. parent. nzval),
102+ )
103+
104+ DeviceSparseMatrixCSR(A:: DeviceSparseMatrixCSC ) =
105+ DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(A))
106+ function DeviceSparseMatrixCSR(
107+ A:: Transpose{Tv,<:Union{<:SparseMatrixCSC,<:DeviceSparseMatrixCSC}} ,
108+ ) where {Tv}
109+ At = A. parent
110+ DeviceSparseMatrixCSR(size(A, 1 ), size(A, 2 ), At. colptr, rowvals(At), nonzeros(At))
111+ end
112+ function DeviceSparseMatrixCSR(
113+ A:: Adjoint{Tv,<:Union{<:SparseMatrixCSC,<:DeviceSparseMatrixCSC}} ,
114+ ) where {Tv}
115+ At = A. parent
116+ DeviceSparseMatrixCSR(
117+ size(A, 1 ),
118+ size(A, 2 ),
119+ At. colptr,
120+ rowvals(At),
121+ conj.(nonzeros(At)),
122+ )
123+ end
3124
4125# ============================================================================
5126# CSC ↔ COO Conversions
@@ -36,7 +157,8 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
36157 kernel! = kernel_make_csc_keys!(backend)
37158 kernel!(keys, A. rowind, A. colind, m; ndrange = (nnz_count,))
38159
39- perm = AcceleratedKernels. sortperm(keys)
160+ # Sort - use AcceleratedKernels
161+ perm = _sortperm_AK(keys)
40162
41163 # Apply permutation to get sorted arrays
42164 rowind_sorted = A. rowind[perm]
@@ -53,7 +175,7 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
53175
54176 # Compute cumulative sum
55177 allowed_setindex!(colptr, 1 , 1 ) # TODO : Is there a better way to do this?
56- colptr[2 : end ] .= AcceleratedKernels . cumsum (colptr[2 : end ]) .+ 1
178+ colptr[2 : end ] .= _cumsum_AK (colptr[2 : end ]) .+ 1
57179
58180 return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
59181end
@@ -94,7 +216,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
94216 kernel!(keys, A. rowind, A. colind, n; ndrange = (nnz_count,))
95217
96218 # Sort - use AcceleratedKernels
97- perm = AcceleratedKernels . sortperm (keys)
219+ perm = _sortperm_AK (keys)
98220
99221 # Apply permutation to get sorted arrays
100222 rowind_sorted = A. rowind[perm]
@@ -111,7 +233,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
111233
112234 # Compute cumulative sum
113235 allowed_setindex!(rowptr, 1 , 1 ) # TODO : Is there a better way to do this?
114- rowptr[2 : end ] .= AcceleratedKernels . cumsum (rowptr[2 : end ]) .+ 1
236+ rowptr[2 : end ] .= _cumsum_AK (rowptr[2 : end ]) .+ 1
115237
116238 return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
117239end
0 commit comments