
Commit 41830c2

Remove JLArray dispatch and fix scalar indexing issue
1 parent ebf8d88 commit 41830c2

7 files changed: +128 −139 lines

src/conversions/conversion_kernels.jl

Lines changed: 2 additions & 2 deletions
@@ -60,12 +60,12 @@ end
 @kernel inbounds=true function kernel_count_per_col!(colptr, @Const(colind_sorted))
     i = @index(Global)
     col = colind_sorted[i]
-    @atomic colptr[col + 1] += 1
+    @atomic colptr[col+1] += 1
 end

 # Kernel for counting entries per row (for COO → CSR)
 @kernel inbounds=true function kernel_count_per_row!(rowptr, @Const(rowind_sorted))
     i = @index(Global)
     row = rowind_sorted[i]
-    @atomic rowptr[row + 1] += 1
+    @atomic rowptr[row+1] += 1
 end
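
These two kernels are per-entry histograms: each work-item reads one sorted COO index and atomically bumps the counter for its column (or row), offset by one slot so that a later cumulative sum can turn the counts into a pointer array. A serial Julia sketch of the same computation (illustration only, not code from this package):

function count_per_col_serial(colind_sorted::AbstractVector{<:Integer}, n::Integer)
    # counts[c + 1] holds the number of nonzeros in column c; counts[1] is
    # left free so a cumulative sum yields CSC column pointers directly.
    counts = zeros(Int, n + 1)
    for i in eachindex(colind_sorted)    # on device: i = @index(Global)
        col = colind_sorted[i]
        counts[col + 1] += 1             # on device: @atomic, since many
    end                                  # work-items may hit the same column
    return counts
end

count_per_col_serial([1, 1, 2, 3, 3], 3)   # => [0, 2, 1, 2]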

src/conversions/conversions.jl

Lines changed: 10 additions & 52 deletions
@@ -1,8 +1,5 @@
 # Conversions between CSC, CSR, and COO sparse matrix formats
-# All conversions operate on-device, with CPU fallback only for JLBackend
-
-# Helper function to check if backend is JLBackend (which doesn't support AcceleratedKernels)
-_is_jlbackend(backend) = string(typeof(backend)) == "JLBackend"
+# All conversions operate on-device

 # ============================================================================
 # CSC ↔ COO Conversions
@@ -80,16 +77,7 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     kernel! = kernel_make_csc_keys!(backend)
     kernel!(keys, A.rowind, A.colind, n; ndrange = (nnz_count,))

-    # Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
-    if _is_jlbackend(backend)
-        # JLBackend doesn't support AcceleratedKernels - use CPU fallback
-        keys_cpu = collect(keys)
-        perm_cpu = sortperm(keys_cpu)
-        perm = Adapt.adapt_structure(backend, perm_cpu)
-    else
-        # Use AcceleratedKernels for GPU and standard CPU backends
-        perm = AcceleratedKernels.sortperm(keys)
-    end
+    perm = AcceleratedKernels.sortperm(keys)

     # Apply permutation to get sorted arrays
     rowind_sorted = A.rowind[perm]
@@ -104,20 +92,9 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     kernel! = kernel_count_per_col!(backend)
     kernel!(colptr, colind_sorted; ndrange = (nnz_count,))

-    # Compute cumulative sum - use CPU fallback for JLBackend
-    if _is_jlbackend(backend) || backend isa KernelAbstractions.CPU
-        # For CPU-like backends, use CPU cumsum
-        colptr_cpu = collect(colptr)
-        colptr_cpu[1] = 1
-        for i = 2:(n + 1)
-            colptr_cpu[i] += colptr_cpu[i - 1]
-        end
-        colptr = Adapt.adapt_structure(backend, colptr_cpu)
-    else
-        # For GPU backends, use AcceleratedKernels scan
-        colptr[1] = 1
-        colptr[2:end] .= AcceleratedKernels.cumsum(colptr[2:end]) .+ 1
-    end
+    # Compute cumulative sum
+    allowed_setindex!(colptr, 1, 1) # TODO: Is there a better way to do this?
+    colptr[2:end] .= AcceleratedKernels.cumsum(colptr[2:end]) .+ 1

     return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
 end
@@ -198,16 +175,8 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     kernel! = kernel_make_csr_keys!(backend)
     kernel!(keys, A.rowind, A.colind, m; ndrange = (nnz_count,))

-    # Sort - use AcceleratedKernels for GPU, CPU fallback for JLBackend
-    if _is_jlbackend(backend)
-        # JLBackend doesn't support AcceleratedKernels - use CPU fallback
-        keys_cpu = collect(keys)
-        perm_cpu = sortperm(keys_cpu)
-        perm = Adapt.adapt_structure(backend, perm_cpu)
-    else
-        # Use AcceleratedKernels for GPU and standard CPU backends
-        perm = AcceleratedKernels.sortperm(keys)
-    end
+    # Sort - use AcceleratedKernels
+    perm = AcceleratedKernels.sortperm(keys)

     # Apply permutation to get sorted arrays
     rowind_sorted = A.rowind[perm]
@@ -222,20 +191,9 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
     kernel! = kernel_count_per_row!(backend)
     kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))

-    # Compute cumulative sum - use CPU fallback for JLBackend
-    if _is_jlbackend(backend) || backend isa KernelAbstractions.CPU
-        # For CPU-like backends, use CPU cumsum
-        rowptr_cpu = collect(rowptr)
-        rowptr_cpu[1] = 1
-        for i = 2:(m + 1)
-            rowptr_cpu[i] += rowptr_cpu[i - 1]
-        end
-        rowptr = Adapt.adapt_structure(backend, rowptr_cpu)
-    else
-        # For GPU backends, use AcceleratedKernels scan
-        rowptr[1] = 1
-        rowptr[2:end] .= AcceleratedKernels.cumsum(rowptr[2:end]) .+ 1
-    end
+    # Compute cumulative sum
+    allowed_setindex!(rowptr, 1, 1) # TODO: Is there a better way to do this?
+    rowptr[2:end] .= AcceleratedKernels.cumsum(rowptr[2:end]) .+ 1

     return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)
 end
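
The rewritten pointer construction is the standard counts-to-offsets step: after the counting kernel, colptr[k + 1] holds the number of entries in column k, and a cumulative sum converts those counts into CSC column offsets. Writing the leading 1 is a single scalar setindex!, which GPU array types reject by default (the scalar indexing issue named in the commit title), so the code routes that one write through allowed_setindex!. A worked example on a plain Vector, with hypothetical counts; on device the first assignment is the allowed_setindex! call and the broadcast uses AcceleratedKernels.cumsum:

# Suppose a 3-column matrix whose columns hold 2, 1, and 2 nonzeros;
# the counting kernel then leaves colptr = [0, 2, 1, 2].
colptr = [0, 2, 1, 2]
colptr[1] = 1                                # pointers are 1-based
colptr[2:end] .= cumsum(colptr[2:end]) .+ 1  # => [1, 3, 4, 6]
# Column j now occupies nzval[colptr[j]:colptr[j+1]-1]:
# column 1 -> 1:2, column 2 -> 3:3, column 3 -> 4:5.
@assert colptr == [1, 3, 4, 6]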

test/Project.toml

Lines changed: 0 additions & 1 deletion
@@ -1,7 +1,6 @@
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
-DeviceSparseArrays = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/cuda/cuda.jl

Lines changed: 7 additions & 0 deletions
@@ -27,4 +27,11 @@
         (Float32, Float64),
         (ComplexF32, ComplexF64),
     )
+    shared_test_conversions(
+        CuArray,
+        "CUDA",
+        (Int32, Int64),
+        (Float32, Float64),
+        (ComplexF32, ComplexF64),
+    )
 end

test/metal/metal.jl

Lines changed: 1 addition & 0 deletions
@@ -3,4 +3,5 @@
     shared_test_matrix_csc(MtlArray, "Metal", (Int32,), (Float32,), (ComplexF32,))
     shared_test_matrix_csr(MtlArray, "Metal", (Int32,), (Float32,), (ComplexF32,))
     shared_test_matrix_coo(MtlArray, "Metal", (Int32,), (Float32,), (ComplexF32,))
+    shared_test_conversions(MtlArray, "Metal", (Int32,), (Float32,), (ComplexF32,))
 end

test/reactant/reactant.jl

Lines changed: 7 additions & 0 deletions
@@ -27,4 +27,11 @@
         (Float32, Float64),
         (ComplexF32, ComplexF64),
     )
+    shared_test_conversions(
+        Reactant.ConcreteRArray,
+        "Reactant",
+        (Int32, Int64),
+        (Float32, Float64),
+        (ComplexF32, ComplexF64),
+    )
 end

test/shared/conversions.jl

Lines changed: 101 additions & 84 deletions
@@ -6,90 +6,107 @@ function shared_test_conversions(
     complex_types::Tuple,
 )
     @testset "Format Conversions $array_type" verbose=true begin
-        # Test CSC → COO → CSC round-trip
-        @testset "CSC ↔ COO" begin
-            A = sparse([1, 2, 3, 1, 2], [1, 2, 3, 2, 3], float_types[end][1.0, 2.0, 3.0, 4.0, 5.0], 3, 3)
-
-            # CSC → COO
-            A_csc = adapt(op, DeviceSparseMatrixCSC(A))
-            A_coo_from_csc = DeviceSparseMatrixCOO(A_csc)
-            @test collect(SparseMatrixCSC(A_coo_from_csc)) ≈ collect(A)
-
-            # COO → CSC
-            A_coo = adapt(op, DeviceSparseMatrixCOO(A))
-            A_csc_from_coo = DeviceSparseMatrixCSC(A_coo)
-            @test collect(SparseMatrixCSC(A_csc_from_coo)) ≈ collect(A)
-
-            # Round-trip
-            A_csc_roundtrip = DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(A_csc))
-            @test collect(SparseMatrixCSC(A_csc_roundtrip)) ≈ collect(A)
-        end
-
-        # Test CSR → COO → CSR round-trip
-        @testset "CSR ↔ COO" begin
-            A = sparse([1, 2, 3, 1, 2], [1, 2, 3, 2, 3], float_types[end][1.0, 2.0, 3.0, 4.0, 5.0], 3, 3)
-
-            # CSR → COO
-            A_csr = adapt(op, DeviceSparseMatrixCSR(A))
-            A_coo_from_csr = DeviceSparseMatrixCOO(A_csr)
-            @test collect(SparseMatrixCSC(A_coo_from_csr)) ≈ collect(A)
-
-            # COO → CSR
-            A_coo = adapt(op, DeviceSparseMatrixCOO(A))
-            A_csr_from_coo = DeviceSparseMatrixCSR(A_coo)
-            @test collect(SparseMatrixCSC(A_csr_from_coo)) ≈ collect(A)
-
-            # Round-trip
-            A_csr_roundtrip = DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(A_csr))
-            @test collect(SparseMatrixCSC(A_csr_roundtrip)) ≈ collect(A)
-        end
-
-        # Test with different data types
-        @testset "Different Types" begin
-            # Test with Float32
-            A_f32 = sparse([1, 2], [1, 2], float_types[1][1.0f0, 2.0f0], 2, 2)
-            A_csc_f32 = adapt(op, DeviceSparseMatrixCSC(A_f32))
-            A_coo_f32 = DeviceSparseMatrixCOO(A_csc_f32)
-            @test collect(SparseMatrixCSC(A_coo_f32)) ≈ collect(A_f32)
-
-            # Test with ComplexF64
-            A_c64 = sparse([1, 2], [1, 2], complex_types[end][1.0+im, 2.0-im], 2, 2)
-            A_csr_c64 = adapt(op, DeviceSparseMatrixCSR(A_c64))
-            A_coo_c64 = DeviceSparseMatrixCOO(A_csr_c64)
-            @test collect(SparseMatrixCSC(A_coo_c64)) ≈ collect(A_c64)
-        end
-
-        # Test with empty matrices
-        @testset "Edge Cases" begin
-            # Empty matrix
-            A_empty = spzeros(float_types[end], 3, 3)
-            A_csc_empty = adapt(op, DeviceSparseMatrixCSC(A_empty))
-            A_coo_empty = DeviceSparseMatrixCOO(A_csc_empty)
-            @test nnz(A_coo_empty) == 0
-            @test size(A_coo_empty) == (3, 3)
-
-            # Single element
-            A_single = sparse([1], [1], float_types[end][42.0], 1, 1)
-            A_csr_single = adapt(op, DeviceSparseMatrixCSR(A_single))
-            A_coo_single = DeviceSparseMatrixCOO(A_csr_single)
-            @test collect(SparseMatrixCSC(A_coo_single)) ≈ collect(A_single)
-        end
-
-        # Test large matrix conversion
-        @testset "Large Matrix" begin
-            A_large = sprand(float_types[end], 100, 100, 0.05)
-
-            # CSC → COO → CSC
-            A_csc_large = adapt(op, DeviceSparseMatrixCSC(A_large))
-            A_coo_large = DeviceSparseMatrixCOO(A_csc_large)
-            A_csc_back = DeviceSparseMatrixCSC(A_coo_large)
-            @test collect(SparseMatrixCSC(A_csc_back)) ≈ collect(A_large)
-
-            # CSR → COO → CSR
-            A_csr_large = adapt(op, DeviceSparseMatrixCSR(A_large))
-            A_coo_large2 = DeviceSparseMatrixCOO(A_csr_large)
-            A_csr_back = DeviceSparseMatrixCSR(A_coo_large2)
-            @test collect(SparseMatrixCSC(A_csr_back)) ≈ collect(A_large)
+        # Many conversion functions rely on AcceleratedKernels sortperm
+        # which is not supported on JLBackend. Therefore, we skip conversion
+        # tests for JLArray.
+        if array_type != "JLArray"
+            # Test CSC → COO → CSC round-trip
+            @testset "CSC ↔ COO" begin
+                A = sparse(
+                    [1, 2, 3, 1, 2],
+                    [1, 2, 3, 2, 3],
+                    float_types[end][1.0, 2.0, 3.0, 4.0, 5.0],
+                    3,
+                    3,
+                )
+
+                # CSC → COO
+                A_csc = adapt(op, DeviceSparseMatrixCSC(A))
+                A_coo_from_csc = DeviceSparseMatrixCOO(A_csc)
+                @test collect(SparseMatrixCSC(A_coo_from_csc)) ≈ collect(A)
+
+                # COO → CSC
+                A_coo = adapt(op, DeviceSparseMatrixCOO(A))
+                A_csc_from_coo = DeviceSparseMatrixCSC(A_coo)
+                @test collect(SparseMatrixCSC(A_csc_from_coo)) ≈ collect(A)
+
+                # Round-trip
+                A_csc_roundtrip = DeviceSparseMatrixCSC(DeviceSparseMatrixCOO(A_csc))
+                @test collect(SparseMatrixCSC(A_csc_roundtrip)) ≈ collect(A)
+            end
+
+            # Test CSR → COO → CSR round-trip
+            @testset "CSR ↔ COO" begin
+                A = sparse(
+                    [1, 2, 3, 1, 2],
+                    [1, 2, 3, 2, 3],
+                    float_types[end][1.0, 2.0, 3.0, 4.0, 5.0],
+                    3,
+                    3,
+                )
+
+                # CSR → COO
+                A_csr = adapt(op, DeviceSparseMatrixCSR(A))
+                A_coo_from_csr = DeviceSparseMatrixCOO(A_csr)
+                @test collect(SparseMatrixCSC(A_coo_from_csr)) ≈ collect(A)
+
+                # COO → CSR
+                A_coo = adapt(op, DeviceSparseMatrixCOO(A))
+                A_csr_from_coo = DeviceSparseMatrixCSR(A_coo)
+                @test collect(SparseMatrixCSC(A_csr_from_coo)) ≈ collect(A)
+
+                # Round-trip
+                A_csr_roundtrip = DeviceSparseMatrixCSR(DeviceSparseMatrixCOO(A_csr))
+                @test collect(SparseMatrixCSC(A_csr_roundtrip)) ≈ collect(A)
+            end
+
+            # Test with different data types
+            @testset "Different Types" begin
+                # Test with Float32
+                A_f32 = sparse([1, 2], [1, 2], float_types[1][1.0f0, 2.0f0], 2, 2)
+                A_csc_f32 = adapt(op, DeviceSparseMatrixCSC(A_f32))
+                A_coo_f32 = DeviceSparseMatrixCOO(A_csc_f32)
+                @test collect(SparseMatrixCSC(A_coo_f32)) ≈ collect(A_f32)
+
+                # Test with ComplexF64
+                A_c64 = sparse([1, 2], [1, 2], complex_types[end][1.0+im, 2.0-im], 2, 2)
+                A_csr_c64 = adapt(op, DeviceSparseMatrixCSR(A_c64))
+                A_coo_c64 = DeviceSparseMatrixCOO(A_csr_c64)
+                @test collect(SparseMatrixCSC(A_coo_c64)) ≈ collect(A_c64)
+            end
+
+            # Test with empty matrices
+            @testset "Edge Cases" begin
+                # Empty matrix
+                A_empty = spzeros(float_types[end], 3, 3)
+                A_csc_empty = adapt(op, DeviceSparseMatrixCSC(A_empty))
+                A_coo_empty = DeviceSparseMatrixCOO(A_csc_empty)
+                @test nnz(A_coo_empty) == 0
+                @test size(A_coo_empty) == (3, 3)
+
+                # Single element
+                A_single = sparse([1], [1], float_types[end][42.0], 1, 1)
+                A_csr_single = adapt(op, DeviceSparseMatrixCSR(A_single))
+                A_coo_single = DeviceSparseMatrixCOO(A_csr_single)
+                @test collect(SparseMatrixCSC(A_coo_single)) ≈ collect(A_single)
+            end
+
+            # Test large matrix conversion
+            @testset "Large Matrix" begin
+                A_large = sprand(float_types[end], 100, 100, 0.05)
+
+                # CSC → COO → CSC
+                A_csc_large = adapt(op, DeviceSparseMatrixCSC(A_large))
+                A_coo_large = DeviceSparseMatrixCOO(A_csc_large)
+                A_csc_back = DeviceSparseMatrixCSC(A_coo_large)
+                @test collect(SparseMatrixCSC(A_csc_back)) ≈ collect(A_large)
+
+                # CSR → COO → CSR
+                A_csr_large = adapt(op, DeviceSparseMatrixCSR(A_large))
+                A_coo_large2 = DeviceSparseMatrixCOO(A_csr_large)
+                A_csr_back = DeviceSparseMatrixCSR(A_coo_large2)
+                @test collect(SparseMatrixCSC(A_csr_back)) ≈ collect(A_large)
+            end
         end
     end
 end
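
Each backend test suite opts into the shared conversion tests with a single call, as the CUDA, Metal, and Reactant files above show. A new backend would be wired in the same way; a sketch patterned on test/metal/metal.jl, where oneArray and "oneAPI" are hypothetical placeholders and not part of this commit:

# Hypothetical backend test file; oneArray / "oneAPI" are illustrative only.
shared_test_matrix_csc(oneArray, "oneAPI", (Int32,), (Float32,), (ComplexF32,))
shared_test_matrix_csr(oneArray, "oneAPI", (Int32,), (Float32,), (ComplexF32,))
shared_test_matrix_coo(oneArray, "oneAPI", (Int32,), (Float32,), (ComplexF32,))
shared_test_conversions(oneArray, "oneAPI", (Int32,), (Float32,), (ComplexF32,))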
