Skip to content

Commit 18de6ff

Browse files
Refactor to replace allowed_getindex with @allowscalar for improved scalar handling in matrix operations
1 parent a8875c4 commit 18de6ff

File tree

7 files changed

+10
-10
lines changed

7 files changed

+10
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
name = "DeviceSparseArrays"
22
uuid = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
3-
version = "0.1.0"
43
authors = ["Alberto Mercurio <alberto.mercurio96@gmail.com> and contributors"]
4+
version = "0.1.0"
55

66
[deps]
77
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
88
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9-
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1010
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -20,7 +20,7 @@ DeviceSparseArraysJLArraysExt = "JLArrays"
2020
[compat]
2121
AcceleratedKernels = "0.4"
2222
Adapt = "4"
23-
ArrayInterface = "7"
23+
GPUArraysCore = "0.2.0"
2424
JLArrays = "0.3"
2525
KernelAbstractions = "0.9"
2626
LinearAlgebra = "1"

src/DeviceSparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import SparseArrays: SparseVector, SparseMatrixCSC
77
import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds
88
import SparseArrays: _show_with_braille_patterns
99

10-
import ArrayInterface: allowed_getindex, allowed_setindex!
10+
import GPUArraysCore: @allowscalar
1111

1212
import KernelAbstractions
1313
import KernelAbstractions:

src/conversions/conversions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
174174
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
175175

176176
# Compute cumulative sum
177-
allowed_setindex!(colptr, 1, 1) # TODO: Is there a better way to do this?
177+
@allowscalar colptr[1] = 1 # TODO: Is there a better way to do this?
178178
colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
179179

180180
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
@@ -232,7 +232,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
232232
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
233233

234234
# Compute cumulative sum
235-
allowed_setindex!(rowptr, 1, 1) # TODO: Is there a better way to do this?
235+
@allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this?
236236
rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
237237

238238
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)

src/matrix_coo/matrix_coo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCOO)
160160
kernel = kernel_tr(backend)
161161
kernel(res, getrowind(A), getcolind(A), nonzeros(A); ndrange = (length(nonzeros(A)),))
162162

163-
return allowed_getindex(res, 1)
163+
return @allowscalar res[1]
164164
end
165165

166166
# Matrix-Vector and Matrix-Matrix multiplication

src/matrix_csc/matrix_csc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSC)
149149
kernel = kernel_tr(backend)
150150
kernel(res, getcolptr(A), getrowval(A), getnzval(A); ndrange = (n,))
151151

152-
return allowed_getindex(res, 1)
152+
return @allowscalar res[1]
153153
end
154154

155155
# Matrix-Vector and Matrix-Matrix multiplication

src/matrix_csr/matrix_csr.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSR)
150150
kernel = kernel_tr(backend)
151151
kernel(res, getrowptr(A), getcolval(A), nonzeros(A); ndrange = (m,))
152152

153-
return allowed_getindex(res, 1)
153+
return @allowscalar res[1]
154154
end
155155

156156
# Matrix-Vector and Matrix-Matrix multiplication

src/vector/vector.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function LinearAlgebra.dot(x::DeviceSparseVector, y::DenseVector)
135135
kernel = kernel_dot(backend)
136136
kernel(res, nzval, nzind, y; ndrange = (m,))
137137

138-
return allowed_getindex(res, 1)
138+
return @allowscalar res[1]
139139
end
140140
LinearAlgebra.dot(x::DenseVector{T1}, y::DeviceSparseVector{Tv}) where {T1<:Real,Tv<:Real} =
141141
dot(y, x)

0 commit comments

Comments
 (0)