# Conversions between CSC, CSR, and COO sparse matrix formats
- # All conversions operate entirely on-device without CPU transfers
+ # All conversions operate on-device, with CPU fallback only for JLBackend
3+
# Predicate: `true` when `backend` is a JLBackend (JLArrays' emulation backend),
# which does not support AcceleratedKernels' device algorithms.
#
# Compare the unqualified type name via `nameof` rather than an exact string
# match: `string(typeof(backend))` may render a module-qualified name such as
# "JLArrays.JLBackend", and the previous literal (" JLBackend", with a stray
# leading space) could never match at all.
_is_jlbackend(backend) = nameof(typeof(backend)) == :JLBackend
36
47# ============================================================================
58# CSC ↔ COO Conversions
@@ -71,29 +74,22 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
7174 backend = get_backend(A. nzval)
7275
7376 # Create keys for sorting: column first, then row
74- # We use n * rowind + colind to create a unique sortable key
7577 keys = similar(A. rowind, Ti, nnz_count)
7678
7779 # Create keys on device
78- @kernel inbounds= true function make_keys!(
79- keys,
80- @Const(rowind),
81- @Const(colind),
82- @Const(n)
83- )
84- i = @index(Global)
85- keys[i] = colind[i] * n + rowind[i]
86- end
87-
88- kernel! = make_keys!(backend)
80+ kernel! = kernel_make_csc_keys!(backend)
8981 kernel!(keys, A. rowind, A. colind, n; ndrange = (nnz_count,))
9082
91- # Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
92- # doesn't work reliably on all backends (e.g., JLBackend)
93- keys_cpu = collect(keys)
94- perm_cpu = sortperm(keys_cpu)
95- # Adapt back to the original backend
96- perm = Adapt. adapt_structure(backend, perm_cpu)
83+ # Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
84+ if _is_jlbackend(backend)
85+ # JLBackend doesn't support AcceleratedKernels - use CPU fallback
86+ keys_cpu = collect(keys)
87+ perm_cpu = sortperm(keys_cpu)
88+ perm = Adapt. adapt_structure(backend, perm_cpu)
89+ else
90+ # Use AcceleratedKernels for GPU and standard CPU backends
91+ perm = AcceleratedKernels. sortperm(keys)
92+ end
9793
9894 # Apply permutation to get sorted arrays
9995 rowind_sorted = A. rowind[perm]
@@ -105,22 +101,23 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
105101 fill!(colptr, zero(Ti))
106102
107103 # Count entries per column
108- @kernel inbounds= true function count_per_col!(colptr, @Const(colind_sorted))
109- i = @index(Global)
110- col = colind_sorted[i]
111- @atomic colptr[col+ 1 ] += 1
112- end
113-
114- kernel! = count_per_col!(backend)
104+ kernel! = kernel_count_per_col!(backend)
115105 kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
116106
117- # Build cumulative sum on CPU (collect, compute, adapt back)
118- colptr_cpu = collect(colptr)
119- colptr_cpu[1 ] = 1
120- for i = 2 : (n+ 1 )
121- colptr_cpu[i] += colptr_cpu[i- 1 ]
107+ # Compute cumulative sum - use CPU fallback for JLBackend
108+ if _is_jlbackend(backend) || backend isa KernelAbstractions. CPU
109+ # For CPU-like backends, use CPU cumsum
110+ colptr_cpu = collect(colptr)
111+ colptr_cpu[1 ] = 1
112+ for i = 2 : (n + 1 )
113+ colptr_cpu[i] += colptr_cpu[i - 1 ]
114+ end
115+ colptr = Adapt. adapt_structure(backend, colptr_cpu)
116+ else
117+ # For GPU backends, use AcceleratedKernels scan
118+ colptr[1 ] = 1
119+ colptr[2 : end ] .= AcceleratedKernels. cumsum(colptr[2 : end ]) .+ 1
122120 end
123- colptr = Adapt. adapt_structure(backend, colptr_cpu)
124121
125122 return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
126123end
@@ -195,29 +192,22 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
195192 backend = get_backend(A. nzval)
196193
197194 # Create keys for sorting: row first, then column
198- # We use m * colind + rowind to create a unique sortable key
199195 keys = similar(A. rowind, Ti, nnz_count)
200196
201197 # Create keys on device
202- @kernel inbounds= true function make_keys!(
203- keys,
204- @Const(rowind),
205- @Const(colind),
206- @Const(m)
207- )
208- i = @index(Global)
209- keys[i] = rowind[i] * m + colind[i]
210- end
211-
212- kernel! = make_keys!(backend)
198+ kernel! = kernel_make_csr_keys!(backend)
213199 kernel!(keys, A. rowind, A. colind, m; ndrange = (nnz_count,))
214200
215- # Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
216- # doesn't work reliably on all backends (e.g., JLBackend)
217- keys_cpu = collect(keys)
218- perm_cpu = sortperm(keys_cpu)
219- # Adapt back to the original backend
220- perm = Adapt. adapt_structure(backend, perm_cpu)
201+ # Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
202+ if _is_jlbackend(backend)
203+ # JLBackend doesn't support AcceleratedKernels - use CPU fallback
204+ keys_cpu = collect(keys)
205+ perm_cpu = sortperm(keys_cpu)
206+ perm = Adapt. adapt_structure(backend, perm_cpu)
207+ else
208+ # Use AcceleratedKernels for GPU and standard CPU backends
209+ perm = AcceleratedKernels. sortperm(keys)
210+ end
221211
222212 # Apply permutation to get sorted arrays
223213 rowind_sorted = A. rowind[perm]
@@ -229,22 +219,23 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
229219 fill!(rowptr, zero(Ti))
230220
231221 # Count entries per row
232- @kernel inbounds= true function count_per_row!(rowptr, @Const(rowind_sorted))
233- i = @index(Global)
234- row = rowind_sorted[i]
235- @atomic rowptr[row+ 1 ] += 1
236- end
237-
238- kernel! = count_per_row!(backend)
222+ kernel! = kernel_count_per_row!(backend)
239223 kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
240224
241- # Build cumulative sum on CPU (collect, compute, adapt back)
242- rowptr_cpu = collect(rowptr)
243- rowptr_cpu[1 ] = 1
244- for i = 2 : (m+ 1 )
245- rowptr_cpu[i] += rowptr_cpu[i- 1 ]
225+ # Compute cumulative sum - use CPU fallback for JLBackend
226+ if _is_jlbackend(backend) || backend isa KernelAbstractions. CPU
227+ # For CPU-like backends, use CPU cumsum
228+ rowptr_cpu = collect(rowptr)
229+ rowptr_cpu[1 ] = 1
230+ for i = 2 : (m + 1 )
231+ rowptr_cpu[i] += rowptr_cpu[i - 1 ]
232+ end
233+ rowptr = Adapt. adapt_structure(backend, rowptr_cpu)
234+ else
235+ # For GPU backends, use AcceleratedKernels scan
236+ rowptr[1 ] = 1
237+ rowptr[2 : end ] .= AcceleratedKernels. cumsum(rowptr[2 : end ]) .+ 1
246238 end
247- rowptr = Adapt. adapt_structure(backend, rowptr_cpu)
248239
249240 return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
250241end
0 commit comments