Skip to content

Commit 131b789

Browse files
Force fields to have same eltype (#28)
* Force fields to have same eltype
* Relax types
* Update src/matrix_csc/matrix_csc.jl (Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>)
* Update src/matrix_csr/matrix_csr.jl (Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>)
* Copy arrays in COO format
* Fix errors in docstring
* Fix error on `tr` function
* Refactor to replace `allowed_getindex` with `@allowscalar` for improved scalar handling in matrix operations

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 2f38d9f commit 131b789

File tree

9 files changed

+71
-151
lines changed

9 files changed

+71
-151
lines changed

Project.toml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,28 @@
11
name = "DeviceSparseArrays"
22
uuid = "da3fe0eb-88a8-4d14-ae1a-857c283e9c70"
3-
version = "0.1.0"
43
authors = ["Alberto Mercurio <alberto.mercurio96@gmail.com> and contributors"]
4+
version = "0.1.0"
55

66
[deps]
77
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
88
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9-
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1010
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313

1414
[weakdeps]
1515
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
16-
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1716

1817
[extensions]
1918
DeviceSparseArraysJLArraysExt = "JLArrays"
20-
DeviceSparseArraysReactantExt = "Reactant"
2119

2220
[compat]
2321
AcceleratedKernels = "0.4"
2422
Adapt = "4"
25-
ArrayInterface = "7"
23+
GPUArraysCore = "0.2.0"
2624
JLArrays = "0.3"
2725
KernelAbstractions = "0.9"
2826
LinearAlgebra = "1"
29-
Reactant = "0.2.164"
3027
SparseArrays = "1"
3128
julia = "1.10"

ext/DeviceSparseArraysReactantExt.jl

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/DeviceSparseArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import SparseArrays: SparseVector, SparseMatrixCSC
77
import SparseArrays: getcolptr, getrowval, getnzval, nonzeroinds
88
import SparseArrays: _show_with_braille_patterns
99

10-
import ArrayInterface: allowed_getindex, allowed_setindex!
10+
import GPUArraysCore: @allowscalar
1111

1212
import KernelAbstractions
1313
import KernelAbstractions:

src/conversions/conversions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ function DeviceSparseMatrixCSC(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
174174
kernel!(colptr, colind_sorted; ndrange = (nnz_count,))
175175

176176
# Compute cumulative sum
177-
allowed_setindex!(colptr, 1, 1) # TODO: Is there a better way to do this?
177+
@allowscalar colptr[1] = 1 # TODO: Is there a better way to do this?
178178
colptr[2:end] .= _cumsum_AK(colptr[2:end]) .+ 1
179179

180180
return DeviceSparseMatrixCSC(m, n, colptr, rowind_sorted, nzval_sorted)
@@ -232,7 +232,7 @@ function DeviceSparseMatrixCSR(A::DeviceSparseMatrixCOO{Tv,Ti}) where {Tv,Ti}
232232
kernel!(rowptr, rowind_sorted; ndrange = (nnz_count,))
233233

234234
# Compute cumulative sum
235-
allowed_setindex!(rowptr, 1, 1) # TODO: Is there a better way to do this?
235+
@allowscalar rowptr[1] = 1 # TODO: Is there a better way to do this?
236236
rowptr[2:end] .= _cumsum_AK(rowptr[2:end]) .+ 1
237237

238238
return DeviceSparseMatrixCSR(m, n, rowptr, colind_sorted, nzval_sorted)

src/helpers.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
#=
2-
A method to check that an AbstractArray is of a given element type.
3-
This is needed because we can implement new methods for different arrays (e.g., Reactant.jl)
4-
=#
5-
_check_type(::Type{T}, v::AbstractArray{T}) where {T} = true
6-
_check_type(::Type{T}, v::AbstractArray) where {T} = false
7-
8-
_get_eltype(::AbstractArray{T}) where {T} = T
9-
1+
# Helper functions to call AcceleratedKernels methods
102
_sortperm_AK(x) = AcceleratedKernels.sortperm(x)
113
_cumsum_AK(x) = AcceleratedKernels.cumsum(x)

src/matrix_coo/matrix_coo.jl

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# DeviceSparseMatrixCOO implementation
22

33
"""
4-
DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector,ColIndT<:AbstractVector,NzValT<:AbstractVector} <: AbstractDeviceSparseMatrix{Tv,Ti}
4+
DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector{Ti},ColIndT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractDeviceSparseMatrix{Tv,Ti}
55
66
Coordinate (COO) sparse matrix with generic storage vectors for row indices,
77
column indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array
@@ -16,28 +16,29 @@ types) enable dispatch on device characteristics.
1616
"""
1717
struct DeviceSparseMatrixCOO{
1818
Tv,
19-
Ti<:Integer,
20-
RowIndT<:AbstractVector,
21-
ColIndT<:AbstractVector,
22-
NzValT<:AbstractVector,
19+
Ti,
20+
RowIndT<:AbstractVector{Ti},
21+
ColIndT<:AbstractVector{Ti},
22+
NzValT<:AbstractVector{Tv},
2323
} <: AbstractDeviceSparseMatrix{Tv,Ti}
2424
m::Int
2525
n::Int
2626
rowind::RowIndT
2727
colind::ColIndT
2828
nzval::NzValT
29-
function DeviceSparseMatrixCOO{Tv,Ti,RowIndT,ColIndT,NzValT}(
29+
30+
function DeviceSparseMatrixCOO(
3031
m::Integer,
3132
n::Integer,
3233
rowind::RowIndT,
3334
colind::ColIndT,
3435
nzval::NzValT,
3536
) where {
3637
Tv,
37-
Ti<:Integer,
38-
RowIndT<:AbstractVector,
39-
ColIndT<:AbstractVector,
40-
NzValT<:AbstractVector,
38+
Ti,
39+
RowIndT<:AbstractVector{Ti},
40+
ColIndT<:AbstractVector{Ti},
41+
NzValT<:AbstractVector{Tv},
4142
}
4243
get_backend(rowind) == get_backend(colind) == get_backend(nzval) ||
4344
throw(ArgumentError("All storage vectors must be on the same device/backend."))
@@ -46,39 +47,19 @@ struct DeviceSparseMatrixCOO{
4647
n >= 0 || throw(ArgumentError("n must be non-negative"))
4748
SparseArrays.sparse_check_Ti(m, n, Ti)
4849

49-
_check_type(Ti, rowind) || throw(ArgumentError("rowind must be of type $Ti"))
50-
_check_type(Ti, colind) || throw(ArgumentError("colind must be of type $Ti"))
51-
_check_type(Tv, nzval) || throw(ArgumentError("nzval must be of type $Tv"))
52-
5350
length(rowind) == length(colind) == length(nzval) ||
5451
throw(ArgumentError("rowind, colind, and nzval must have same length"))
5552

56-
return new(Int(m), Int(n), rowind, colind, nzval)
53+
return new{Tv,Ti,RowIndT,ColIndT,NzValT}(
54+
Int(m),
55+
Int(n),
56+
copy(rowind),
57+
copy(colind),
58+
copy(nzval),
59+
)
5760
end
5861
end
5962

60-
function DeviceSparseMatrixCOO(
61-
m::Integer,
62-
n::Integer,
63-
rowind::RowIndT,
64-
colind::ColIndT,
65-
nzval::NzValT,
66-
) where {
67-
RowIndT<:AbstractVector{Ti},
68-
ColIndT<:AbstractVector{Ti},
69-
NzValT<:AbstractVector{Tv},
70-
} where {Ti<:Integer,Tv}
71-
Ti2 = _get_eltype(rowind)
72-
Tv2 = _get_eltype(nzval)
73-
DeviceSparseMatrixCOO{Tv2,Ti2,RowIndT,ColIndT,NzValT}(
74-
m,
75-
n,
76-
copy(rowind),
77-
copy(colind),
78-
copy(nzval),
79-
)
80-
end
81-
8263
# Conversion from SparseMatrixCSC to COO
8364
function DeviceSparseMatrixCOO(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
8465
m, n = size(A)
@@ -179,7 +160,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCOO)
179160
kernel = kernel_tr(backend)
180161
kernel(res, getrowind(A), getcolind(A), nonzeros(A); ndrange = (length(nonzeros(A)),))
181162

182-
return allowed_getindex(res, 1)
163+
return @allowscalar res[1]
183164
end
184165

185166
# Matrix-Vector and Matrix-Matrix multiplication

src/matrix_csc/matrix_csc.jl

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# DeviceSparseMatrixCSC implementation
22

33
"""
4-
DeviceSparseMatrixCSC{Tv,Ti,ColPtrT<RowValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti}
4+
DeviceSparseMatrixCSC{Tv,Ti,ColPtrT,RowValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti}
55
66
Compressed Sparse Column (CSC) matrix with generic storage vectors for column
77
pointer, row indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array
@@ -16,28 +16,29 @@ types) enable dispatch on device characteristics.
1616
"""
1717
struct DeviceSparseMatrixCSC{
1818
Tv,
19-
Ti<:Integer,
20-
ColPtrT<:AbstractVector,
21-
RowValT<:AbstractVector,
22-
NzValT<:AbstractVector,
19+
Ti,
20+
ColPtrT<:AbstractVector{Ti},
21+
RowValT<:AbstractVector{Ti},
22+
NzValT<:AbstractVector{Tv},
2323
} <: AbstractDeviceSparseMatrix{Tv,Ti}
2424
m::Int
2525
n::Int
2626
colptr::ColPtrT
2727
rowval::RowValT
2828
nzval::NzValT
29-
function DeviceSparseMatrixCSC{Tv,Ti,ColPtrT,RowValT,NzValT}(
29+
30+
function DeviceSparseMatrixCSC(
3031
m::Integer,
3132
n::Integer,
3233
colptr::ColPtrT,
3334
rowval::RowValT,
3435
nzval::NzValT,
3536
) where {
3637
Tv,
37-
Ti<:Integer,
38-
ColPtrT<:AbstractVector,
39-
RowValT<:AbstractVector,
40-
NzValT<:AbstractVector,
38+
Ti,
39+
ColPtrT<:AbstractVector{Ti},
40+
RowValT<:AbstractVector{Ti},
41+
NzValT<:AbstractVector{Tv},
4142
}
4243
get_backend(colptr) == get_backend(rowval) == get_backend(nzval) ||
4344
throw(ArgumentError("All storage vectors must be on the same device/backend."))
@@ -47,35 +48,20 @@ struct DeviceSparseMatrixCSC{
4748
SparseArrays.sparse_check_Ti(m, n, Ti)
4849
# SparseArrays.sparse_check(n, colptr, rowval, nzval) # TODO: this uses scalar indexing
4950

50-
_check_type(Ti, colptr) || throw(ArgumentError("colptr must be of type $Ti"))
51-
_check_type(Ti, rowval) || throw(ArgumentError("rowval must be of type $Ti"))
52-
_check_type(Tv, nzval) || throw(ArgumentError("nzval must be of type $Tv"))
53-
5451
length(colptr) == n + 1 || throw(ArgumentError("colptr length must be n+1"))
5552
length(rowval) == length(nzval) ||
5653
throw(ArgumentError("rowval and nzval must have same length"))
5754

58-
return new(Int(m), Int(n), copy(colptr), copy(rowval), copy(nzval))
55+
return new{Tv,Ti,ColPtrT,RowValT,NzValT}(
56+
Int(m),
57+
Int(n),
58+
copy(colptr),
59+
copy(rowval),
60+
copy(nzval),
61+
)
5962
end
6063
end
6164

62-
function DeviceSparseMatrixCSC(
63-
m::Integer,
64-
n::Integer,
65-
colptr::ColPtrT,
66-
rowval::RowValT,
67-
nzval::NzValT,
68-
) where {
69-
ColPtrT<:AbstractVector{Ti},
70-
RowValT<:AbstractVector{Ti},
71-
NzValT<:AbstractVector{Tv},
72-
} where {Ti<:Integer,Tv}
73-
Ti2 = _get_eltype(colptr)
74-
Tv2 = _get_eltype(nzval)
75-
DeviceSparseMatrixCSC{Tv2,Ti2,ColPtrT,RowValT,NzValT}(m, n, colptr, rowval, nzval)
76-
end
77-
78-
7965
Adapt.adapt_structure(to, A::DeviceSparseMatrixCSC) = DeviceSparseMatrixCSC(
8066
A.m,
8167
A.n,
@@ -163,7 +149,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSC)
163149
kernel = kernel_tr(backend)
164150
kernel(res, getcolptr(A), getrowval(A), getnzval(A); ndrange = (n,))
165151

166-
return allowed_getindex(res, 1)
152+
return @allowscalar res[1]
167153
end
168154

169155
# Matrix-Vector and Matrix-Matrix multiplication

src/matrix_csr/matrix_csr.jl

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# DeviceSparseMatrixCSR implementation
22

33
"""
4-
DeviceSparseMatrixCSR{Tv,Ti,RowPtrT<:ColValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti}
4+
DeviceSparseMatrixCSR{Tv,Ti,RowPtrT,ColValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti}
55
66
Compressed Sparse Row (CSR) matrix with generic storage vectors for row
77
pointer, column indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array
@@ -16,28 +16,29 @@ types) enable dispatch on device characteristics.
1616
"""
1717
struct DeviceSparseMatrixCSR{
1818
Tv,
19-
Ti<:Integer,
20-
RowPtrT<:AbstractVector,
21-
ColValT<:AbstractVector,
22-
NzValT<:AbstractVector,
19+
Ti,
20+
RowPtrT<:AbstractVector{Ti},
21+
ColValT<:AbstractVector{Ti},
22+
NzValT<:AbstractVector{Tv},
2323
} <: AbstractDeviceSparseMatrix{Tv,Ti}
2424
m::Int
2525
n::Int
2626
rowptr::RowPtrT
2727
colval::ColValT
2828
nzval::NzValT
29-
function DeviceSparseMatrixCSR{Tv,Ti,RowPtrT,ColValT,NzValT}(
29+
30+
function DeviceSparseMatrixCSR(
3031
m::Integer,
3132
n::Integer,
3233
rowptr::RowPtrT,
3334
colval::ColValT,
3435
nzval::NzValT,
3536
) where {
3637
Tv,
37-
Ti<:Integer,
38-
RowPtrT<:AbstractVector,
39-
ColValT<:AbstractVector,
40-
NzValT<:AbstractVector,
38+
Ti,
39+
RowPtrT<:AbstractVector{Ti},
40+
ColValT<:AbstractVector{Ti},
41+
NzValT<:AbstractVector{Tv},
4142
}
4243
get_backend(rowptr) == get_backend(colval) == get_backend(nzval) ||
4344
throw(ArgumentError("All storage vectors must be on the same device/backend."))
@@ -47,34 +48,20 @@ struct DeviceSparseMatrixCSR{
4748
SparseArrays.sparse_check_Ti(m, n, Ti)
4849
# SparseArrays.sparse_check(m, rowptr, colval, nzval) # TODO: this uses scalar indexing
4950

50-
_check_type(Ti, rowptr) || throw(ArgumentError("rowptr must be of type $Ti"))
51-
_check_type(Ti, colval) || throw(ArgumentError("colval must be of type $Ti"))
52-
_check_type(Tv, nzval) || throw(ArgumentError("nzval must be of type $Tv"))
53-
5451
length(rowptr) == m + 1 || throw(ArgumentError("rowptr length must be m+1"))
5552
length(colval) == length(nzval) ||
5653
throw(ArgumentError("colval and nzval must have same length"))
5754

58-
return new(Int(m), Int(n), copy(rowptr), copy(colval), copy(nzval))
55+
return new{Tv,Ti,RowPtrT,ColValT,NzValT}(
56+
Int(m),
57+
Int(n),
58+
copy(rowptr),
59+
copy(colval),
60+
copy(nzval),
61+
)
5962
end
6063
end
6164

62-
function DeviceSparseMatrixCSR(
63-
m::Integer,
64-
n::Integer,
65-
rowptr::RowPtrT,
66-
colval::ColValT,
67-
nzval::NzValT,
68-
) where {
69-
RowPtrT<:AbstractVector{Ti},
70-
ColValT<:AbstractVector{Ti},
71-
NzValT<:AbstractVector{Tv},
72-
} where {Ti<:Integer,Tv}
73-
Ti2 = _get_eltype(rowptr)
74-
Tv2 = _get_eltype(nzval)
75-
DeviceSparseMatrixCSR{Tv2,Ti2,RowPtrT,ColValT,NzValT}(m, n, rowptr, colval, nzval)
76-
end
77-
7865
Adapt.adapt_structure(to, A::DeviceSparseMatrixCSR) = DeviceSparseMatrixCSR(
7966
A.m,
8067
A.n,
@@ -163,7 +150,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSR)
163150
kernel = kernel_tr(backend)
164151
kernel(res, getrowptr(A), getcolval(A), nonzeros(A); ndrange = (m,))
165152

166-
return allowed_getindex(res, 1)
153+
return @allowscalar res[1]
167154
end
168155

169156
# Matrix-Vector and Matrix-Matrix multiplication

0 commit comments

Comments (0)