@@ -165,17 +165,15 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
165165 colind_sorted = A. colind[perm]
166166 nzval_sorted = A. nzval[perm]
167167
168- # Build colptr on device using a histogram approach
169- colptr = similar(A. colind, Ti, n + 1 )
170- fill!(colptr, zero(Ti))
171-
172- # Count entries per column
173- kernel! = kernel_count_per_col!(backend)
174- kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
168+ # Build colptr on device using searchsortedfirst approach
169+ # Since colind_sorted is sorted, find where each column starts
170+ col_indices = similar(A. colind, Ti, n)
171+ col_indices .= Ti(1 ): Ti(n)
175172
176- # Compute cumulative sum
177- @allowscalar colptr[1 ] = 1 # TODO : Is there a better way to do this?
178- colptr[2 : end ] .= _cumsum_AK(colptr[2 : end ]) .+ 1
173+ # Find start positions for each column
174+ colptr = similar(A. colind, Ti, n + 1 )
175+ colptr[1 : n] .= _searchsortedfirst_AK(colind_sorted, col_indices)
176+ @allowscalar colptr[n+ 1 ] = Ti(nnz_count + 1 )
179177
180178 return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
181179end
@@ -242,17 +240,15 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
242240 colind_sorted = A. colind[perm]
243241 nzval_sorted = A. nzval[perm]
244242
245- # Build rowptr on device using a histogram approach
246- rowptr = similar(A. rowind, Ti, m + 1 )
247- fill!(rowptr, zero(Ti))
248-
249- # Count entries per row
250- kernel! = kernel_count_per_row!(backend)
251- kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
243+ # Build rowptr on device using searchsortedfirst approach
244+ # Since rowind_sorted is sorted, find where each row starts
245+ row_indices = similar(A. rowind, Ti, m)
246+ row_indices .= Ti(1 ): Ti(m)
252247
253- # Compute cumulative sum
254- @allowscalar rowptr[1 ] = 1 # TODO : Is there a better way to do this?
255- rowptr[2 : end ] .= _cumsum_AK(rowptr[2 : end ]) .+ 1
248+ # Find start positions for each row
249+ rowptr = similar(A. rowind, Ti, m + 1 )
250+ rowptr[1 : m] .= _searchsortedfirst_AK(rowind_sorted, row_indices)
251+ @allowscalar rowptr[m+ 1 ] = Ti(nnz_count + 1 )
256252
257253 return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
258254end
0 commit comments