Skip to content

Commit 68cf40f

Browse files
Fix conversions to work with all backends by using CPU sorting and cumsum
Co-authored-by: albertomercurio <[email protected]>
1 parent c2bad25 commit 68cf40f

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

src/conversions/conversions.jl

Lines changed: 24 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -88,8 +88,12 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
8888
kernel! = make_keys!(backend)
8989
kernel!(keys, A.rowind, A.colind, n; ndrange = (nnz_count,))
9090

91-
# Sort on device using AcceleratedKernels
92-
perm = AcceleratedKernels.sortperm(keys)
91+
# Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
92+
# doesn't work reliably on all backends (e.g., JLBackend)
93+
keys_cpu = collect(keys)
94+
perm_cpu = sortperm(keys_cpu)
95+
# Adapt back to the original backend
96+
perm = Adapt.adapt_structure(backend, perm_cpu)
9397

9498
# Apply permutation to get sorted arrays
9599
rowind_sorted = A.rowind[perm]
@@ -110,18 +114,13 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
110114
kernel! = count_per_col!(backend)
111115
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
112116

113-
# Set colptr[1] = 1
114-
if backend isa KernelAbstractions.CPU
115-
colptr[1] = 1
116-
# Compute cumulative sum
117-
for i = 2:(n+1)
118-
colptr[i] += colptr[i-1]
119-
end
120-
else
121-
# For non-CPU backends, use AcceleratedKernels scan
122-
colptr[1] = 1
123-
colptr[2:end] .= AcceleratedKernels.cumsum(colptr[2:end]) .+ 1
117+
# Build cumulative sum on CPU (collect, compute, adapt back)
118+
colptr_cpu = collect(colptr)
119+
colptr_cpu[1] = 1
120+
for i = 2:(n+1)
121+
colptr_cpu[i] += colptr_cpu[i-1]
124122
end
123+
colptr = Adapt.adapt_structure(backend, colptr_cpu)
125124

126125
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
127126
end
@@ -213,8 +212,12 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
213212
kernel! = make_keys!(backend)
214213
kernel!(keys, A.rowind, A.colind, m; ndrange = (nnz_count,))
215214

216-
# Sort on device using AcceleratedKernels
217-
perm = AcceleratedKernels.sortperm(keys)
215+
# Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
216+
# doesn't work reliably on all backends (e.g., JLBackend)
217+
keys_cpu = collect(keys)
218+
perm_cpu = sortperm(keys_cpu)
219+
# Adapt back to the original backend
220+
perm = Adapt.adapt_structure(backend, perm_cpu)
218221

219222
# Apply permutation to get sorted arrays
220223
rowind_sorted = A.rowind[perm]
@@ -235,18 +238,13 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
235238
kernel! = count_per_row!(backend)
236239
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
237240

238-
# Set rowptr[1] = 1
239-
if backend isa KernelAbstractions.CPU
240-
rowptr[1] = 1
241-
# Compute cumulative sum
242-
for i = 2:(m+1)
243-
rowptr[i] += rowptr[i-1]
244-
end
245-
else
246-
# For non-CPU backends, use AcceleratedKernels scan
247-
rowptr[1] = 1
248-
rowptr[2:end] .= AcceleratedKernels.cumsum(rowptr[2:end]) .+ 1
241+
# Build cumulative sum on CPU (collect, compute, adapt back)
242+
rowptr_cpu = collect(rowptr)
243+
rowptr_cpu[1] = 1
244+
for i = 2:(m+1)
245+
rowptr_cpu[i] += rowptr_cpu[i-1]
249246
end
247+
rowptr = Adapt.adapt_structure(backend, rowptr_cpu)
250248

251249
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
252250
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4+
DeviceSparseArrays = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
45
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
56
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

0 commit comments

Comments
 (0)