Skip to content

Commit ebf8d88

Browse files
Move kernels outside functions and use AcceleratedKernels with JLBackend fallback
Co-authored-by: albertomercurio <[email protected]>
1 parent 68cf40f commit ebf8d88

File tree

2 files changed

+90
-63
lines changed

2 files changed

+90
-63
lines changed

src/conversions/conversion_kernels.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,39 @@ end
3333
nzval_out[j] = nzval_in[j]
3434
end
3535
end
36+
37+
# Build one sortable integer key per COO entry for the COO → CSC conversion.
# The entry at global index `i` receives `colind[i] * n + rowind[i]`, so the column
# index forms the major part of the key and the row index the minor part.
# NOTE(review): the keys order entries column-major only when `n` is at least the
# largest row index. The visible CSC call site passes `n`, which appears to be the
# column count — confirm the ordering for matrices with more rows than columns.
@kernel inbounds=true function kernel_make_csc_keys!(
    keys,
    @Const(rowind),
    @Const(colind),
    @Const(n),
)
    idx = @index(Global)
    col = colind[idx]
    row = rowind[idx]
    keys[idx] = col * n + row
end
47+
48+
# Build one sortable integer key per COO entry for the COO → CSR conversion.
# The entry at global index `i` receives `rowind[i] * m + colind[i]`, so the row
# index forms the major part of the key and the column index the minor part.
# NOTE(review): the keys order entries row-major only when `m` is at least the
# largest column index. The visible CSR call site passes `m`, which appears to be the
# row count — confirm the ordering for matrices with more columns than rows.
@kernel inbounds=true function kernel_make_csr_keys!(
    keys,
    @Const(rowind),
    @Const(colind),
    @Const(m),
)
    idx = @index(Global)
    row = rowind[idx]
    col = colind[idx]
    keys[idx] = row * m + col
end
58+
59+
# Counting pass for COO → CSC: the work-item at global index `i` atomically adds 1 to
# `colptr[colind_sorted[i] + 1]`, i.e. the slot one past the entry's column index.
# The increments are added to whatever `colptr` already holds, so the caller is
# expected to initialise it beforehand.
@kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted))
    idx = @index(Global)
    slot = colind_sorted[idx] + 1
    @atomic colptr[slot] += 1
end
65+
66+
# Counting pass for COO → CSR: the work-item at global index `i` atomically adds 1 to
# `rowptr[rowind_sorted[i] + 1]`, i.e. the slot one past the entry's row index.
# The increments are added to whatever `rowptr` already holds, so the caller is
# expected to initialise it beforehand.
@kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted))
    idx = @index(Global)
    slot = rowind_sorted[idx] + 1
    @atomic rowptr[slot] += 1
end

src/conversions/conversions.jl

Lines changed: 54 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Conversions between CSC, CSR, and COO sparse matrix formats
2-
# All conversions operate entirely on-device without CPU transfers
2+
# All conversions operate on-device, with CPU fallback only for JLBackend
3+
4+
# Helper to check whether `backend` is a `JLBackend` (which doesn't support
# AcceleratedKernels). Returns `true` when the backend's type is named `JLBackend`.
# The bare type name is compared via `nameof` rather than `string(typeof(backend))`:
# the printed form of a type is module-qualified (e.g. "JLArrays.JLBackend") whenever
# the type is not visible from `Main`, which made the string comparison return `false`
# depending on what the user happened to have imported.
_is_jlbackend(backend) = nameof(typeof(backend)) === :JLBackend
36

47
# ============================================================================
58
# CSC ↔ COO Conversions
@@ -71,29 +74,22 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
7174
backend = get_backend(A.nzval)
7275

7376
# Create keys for sorting: column first, then row
74-
# We use n * rowind + colind to create a unique sortable key
7577
keys = similar(A.rowind, Ti, nnz_count)
7678

7779
# Create keys on device
78-
@kernel inbounds=true function make_keys!(
79-
keys,
80-
@Const(rowind),
81-
@Const(colind),
82-
@Const(n)
83-
)
84-
i = @index(Global)
85-
keys[i] = colind[i] * n + rowind[i]
86-
end
87-
88-
kernel! = make_keys!(backend)
80+
kernel! = kernel_make_csc_keys!(backend)
8981
kernel!(keys, A.rowind, A.colind, n; ndrange = (nnz_count,))
9082

91-
# Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
92-
# doesn't work reliably on all backends (e.g., JLBackend)
93-
keys_cpu = collect(keys)
94-
perm_cpu = sortperm(keys_cpu)
95-
# Adapt back to the original backend
96-
perm = Adapt.adapt_structure(backend, perm_cpu)
83+
# Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
84+
if _is_jlbackend(backend)
85+
# JLBackend doesn't support AcceleratedKernels - use CPU fallback
86+
keys_cpu = collect(keys)
87+
perm_cpu = sortperm(keys_cpu)
88+
perm = Adapt.adapt_structure(backend, perm_cpu)
89+
else
90+
# Use AcceleratedKernels for GPU and standard CPU backends
91+
perm = AcceleratedKernels.sortperm(keys)
92+
end
9793

9894
# Apply permutation to get sorted arrays
9995
rowind_sorted = A.rowind[perm]
@@ -105,22 +101,23 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
105101
fill!(colptr, zero(Ti))
106102

107103
# Count entries per column
108-
@kernel inbounds=true function count_per_col!(colptr, @Const(colind_sorted))
109-
i = @index(Global)
110-
col = colind_sorted[i]
111-
@atomic colptr[col+1] += 1
112-
end
113-
114-
kernel! = count_per_col!(backend)
104+
kernel! = kernel_count_per_col!(backend)
115105
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
116106

117-
# Build cumulative sum on CPU (collect, compute, adapt back)
118-
colptr_cpu = collect(colptr)
119-
colptr_cpu[1] = 1
120-
for i = 2:(n+1)
121-
colptr_cpu[i] += colptr_cpu[i-1]
107+
# Compute cumulative sum - host loop for JLBackend and KernelAbstractions.CPU
108+
if _is_jlbackend(backend) || backend isa KernelAbstractions.CPU
109+
# For CPU-like backends, use CPU cumsum
110+
colptr_cpu = collect(colptr)
111+
colptr_cpu[1] = 1
112+
for i = 2:(n + 1)
113+
colptr_cpu[i] += colptr_cpu[i - 1]
114+
end
115+
colptr = Adapt.adapt_structure(backend, colptr_cpu)
116+
else
117+
# For GPU backends, use AcceleratedKernels scan
118+
colptr[1] = 1
119+
colptr[2:end] .= AcceleratedKernels.cumsum(colptr[2:end]) .+ 1
122120
end
123-
colptr = Adapt.adapt_structure(backend, colptr_cpu)
124121

125122
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
126123
end
@@ -195,29 +192,22 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
195192
backend = get_backend(A.nzval)
196193

197194
# Create keys for sorting: row first, then column
198-
# We use m * colind + rowind to create a unique sortable key
199195
keys = similar(A.rowind, Ti, nnz_count)
200196

201197
# Create keys on device
202-
@kernel inbounds=true function make_keys!(
203-
keys,
204-
@Const(rowind),
205-
@Const(colind),
206-
@Const(m)
207-
)
208-
i = @index(Global)
209-
keys[i] = rowind[i] * m + colind[i]
210-
end
211-
212-
kernel! = make_keys!(backend)
198+
kernel! = kernel_make_csr_keys!(backend)
213199
kernel!(keys, A.rowind, A.colind, m; ndrange = (nnz_count,))
214200

215-
# Sort - collect to CPU and use Base.sortperm since AcceleratedKernels
216-
# doesn't work reliably on all backends (e.g., JLBackend)
217-
keys_cpu = collect(keys)
218-
perm_cpu = sortperm(keys_cpu)
219-
# Adapt back to the original backend
220-
perm = Adapt.adapt_structure(backend, perm_cpu)
201+
# Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
202+
if _is_jlbackend(backend)
203+
# JLBackend doesn't support AcceleratedKernels - use CPU fallback
204+
keys_cpu = collect(keys)
205+
perm_cpu = sortperm(keys_cpu)
206+
perm = Adapt.adapt_structure(backend, perm_cpu)
207+
else
208+
# Use AcceleratedKernels for GPU and standard CPU backends
209+
perm = AcceleratedKernels.sortperm(keys)
210+
end
221211

222212
# Apply permutation to get sorted arrays
223213
rowind_sorted = A.rowind[perm]
@@ -229,22 +219,23 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
229219
fill!(rowptr, zero(Ti))
230220

231221
# Count entries per row
232-
@kernel inbounds=true function count_per_row!(rowptr, @Const(rowind_sorted))
233-
i = @index(Global)
234-
row = rowind_sorted[i]
235-
@atomic rowptr[row+1] += 1
236-
end
237-
238-
kernel! = count_per_row!(backend)
222+
kernel! = kernel_count_per_row!(backend)
239223
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
240224

241-
# Build cumulative sum on CPU (collect, compute, adapt back)
242-
rowptr_cpu = collect(rowptr)
243-
rowptr_cpu[1] = 1
244-
for i = 2:(m+1)
245-
rowptr_cpu[i] += rowptr_cpu[i-1]
225+
# Compute cumulative sum - host loop for JLBackend and KernelAbstractions.CPU
226+
if _is_jlbackend(backend) || backend isa KernelAbstractions.CPU
227+
# For CPU-like backends, use CPU cumsum
228+
rowptr_cpu = collect(rowptr)
229+
rowptr_cpu[1] = 1
230+
for i = 2:(m + 1)
231+
rowptr_cpu[i] += rowptr_cpu[i - 1]
232+
end
233+
rowptr = Adapt.adapt_structure(backend, rowptr_cpu)
234+
else
235+
# For GPU backends, use AcceleratedKernels scan
236+
rowptr[1] = 1
237+
rowptr[2:end] .= AcceleratedKernels.cumsum(rowptr[2:end]) .+ 1
246238
end
247-
rowptr = Adapt.adapt_structure(backend, rowptr_cpu)
248239

249240
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
250241
end

0 commit comments

Comments
 (0)