Skip to content

Commit 6d40625

Browse files
Fix Metal errors
1 parent 4b2ef8a commit 6d40625

File tree

4 files changed

+19
-34
lines changed

4 files changed

+19
-34
lines changed

ext/DeviceSparseArraysJLArraysExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,7 @@ import DeviceSparseArrays
55

66
DeviceSparseArrays._sortperm_AK(x::JLArray) = JLArray(sortperm(collect(x)))
77
DeviceSparseArrays._cumsum_AK(x::JLArray) = JLArray(cumsum(collect(x)))
8+
DeviceSparseArrays._searchsortedfirst_AK(v::JLArray, x::JLArray) =
9+
JLArray(searchsortedfirst.(Ref(collect(v)), collect(x)))
810

911
end

src/conversions/conversion_kernels.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,3 @@ end
5555
i = @index(Global)
5656
keys[i] = rowind[i] * n + colind[i]
5757
end
58-
59-
# Kernel for counting entries per column (for COO → CSC)
60-
@kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted))
61-
i = @index(Global)
62-
col = colind_sorted[i]
63-
@atomic colptr[col+1] += 1
64-
end
65-
66-
# Kernel for counting entries per row (for COO → CSR)
67-
@kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted))
68-
i = @index(Global)
69-
row = rowind_sorted[i]
70-
@atomic rowptr[row+1] += 1
71-
end

src/conversions/conversions.jl

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,15 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
165165
colind_sorted = A.colind[perm]
166166
nzval_sorted = A.nzval[perm]
167167

168-
# Build colptr on device using a histogram approach
169-
colptr = similar(A.colind, Ti, n + 1)
170-
fill!(colptr, zero(Ti))
171-
172-
# Count entries per column
173-
kernel! = kernel_count_per_col!(backend)
174-
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
168+
# Build colptr on device using searchsortedfirst approach
169+
# Since colind_sorted is sorted, find where each column starts
170+
col_indices = similar(A.colind, Ti, n)
171+
col_indices .= Ti(1):Ti(n)
175172

176-
# Compute cumulative sum
177-
@allowscalar colptr[1] = 1 # TODO: Is there a better way to do this?
178-
colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
173+
# Find start positions for each column
174+
colptr = similar(A.colind, Ti, n + 1)
175+
colptr[1:n] .= _searchsortedfirst_AK(colind_sorted, col_indices)
176+
@allowscalar colptr[n+1] = Ti(nnz_count + 1)
179177

180178
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
181179
end
@@ -242,17 +240,15 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
242240
colind_sorted = A.colind[perm]
243241
nzval_sorted = A.nzval[perm]
244242

245-
# Build rowptr on device using a histogram approach
246-
rowptr = similar(A.rowind, Ti, m + 1)
247-
fill!(rowptr, zero(Ti))
248-
249-
# Count entries per row
250-
kernel! = kernel_count_per_row!(backend)
251-
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
243+
# Build rowptr on device using searchsortedfirst approach
244+
# Since rowind_sorted is sorted, find where each row starts
245+
row_indices = similar(A.rowind, Ti, m)
246+
row_indices .= Ti(1):Ti(m)
252247

253-
# Compute cumulative sum
254-
@allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this?
255-
rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
248+
# Find start positions for each row
249+
rowptr = similar(A.rowind, Ti, m + 1)
250+
rowptr[1:m] .= _searchsortedfirst_AK(rowind_sorted, row_indices)
251+
@allowscalar rowptr[m+1] = Ti(nnz_count + 1)
256252

257253
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
258254
end

src/helpers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# Helper functions to call AcceleratedKernels methods
22
_sortperm_AK(x) = AcceleratedKernels.sortperm(x)
33
_cumsum_AK(x) = AcceleratedKernels.cumsum(x)
4+
_searchsortedfirst_AK(v, x) = AcceleratedKernels.searchsortedfirst(v, x)

0 commit comments

Comments
 (0)