From 3e780a01bb6a51b47cc93778d111c047b29281a6 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Mon, 17 Nov 2025 03:09:10 +0100 Subject: [PATCH 1/7] Force fields to have same eltype --- Project.toml | 3 -- ext/DeviceSparseArraysReactantExt.jl | 12 -------- src/helpers.jl | 10 +----- src/matrix_coo/matrix_coo.jl | 45 ++++++--------------------- src/matrix_csc/matrix_csc.jl | 46 ++++++++++------------------ src/matrix_csr/matrix_csr.jl | 45 ++++++++++----------------- src/vector/vector.jl | 15 +++------ 7 files changed, 47 insertions(+), 129 deletions(-) delete mode 100644 ext/DeviceSparseArraysReactantExt.jl diff --git a/Project.toml b/Project.toml index 8296a16..e0f891e 100644 --- a/Project.toml +++ b/Project.toml @@ -13,11 +13,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [extensions] DeviceSparseArraysJLArraysExt = "JLArrays" -DeviceSparseArraysReactantExt = "Reactant" [compat] AcceleratedKernels = "0.4" @@ -26,6 +24,5 @@ ArrayInterface = "7" JLArrays = "0.3" KernelAbstractions = "0.9" LinearAlgebra = "1" -Reactant = "0.2.164" SparseArrays = "1" julia = "1.10" diff --git a/ext/DeviceSparseArraysReactantExt.jl b/ext/DeviceSparseArraysReactantExt.jl deleted file mode 100644 index 443ba00..0000000 --- a/ext/DeviceSparseArraysReactantExt.jl +++ /dev/null @@ -1,12 +0,0 @@ -module DeviceSparseArraysReactantExt - -import DeviceSparseArrays -import Reactant - -DeviceSparseArrays._check_type( - ::Type{T}, - ::Reactant.RArray{Reactant.TracedRNumber{T}}, -) where {T} = true -DeviceSparseArrays._get_eltype(::Reactant.RArray{Reactant.TracedRNumber{T}}) where {T} = T - -end diff --git a/src/helpers.jl b/src/helpers.jl index 91aea6f..025b005 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -1,11 +1,3 @@ -#= -A method to check that an AbstractArray is of a given element type. -This is needed because we can implement new methods for different arrays (e.g., Reactant.jl) -=# -_check_type(::Type{T}, v::AbstractArray{T}) where {T} = true -_check_type(::Type{T}, v::AbstractArray) where {T} = false - -_get_eltype(::AbstractArray{T}) where {T} = T - +# Helper functions to call AcceleratedKernels methods _sortperm_AK(x) = AcceleratedKernels.sortperm(x) _cumsum_AK(x) = AcceleratedKernels.cumsum(x) diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 82f22a8..b8d790d 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -1,7 +1,7 @@ # DeviceSparseMatrixCOO implementation """ - DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector,ColIndT<:AbstractVector,NzValT<:AbstractVector} <: AbstractDeviceSparseMatrix{Tv,Ti} + DeviceSparseMatrixCOO{Tv,Ti,RowIndT<:AbstractVector{Ti},ColIndT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractDeviceSparseMatrix{Tv,Ti} Coordinate (COO) sparse matrix with generic storage vectors for row indices, column indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array @@ -17,16 +17,17 @@ types) enable dispatch on device characteristics. struct DeviceSparseMatrixCOO{ Tv, Ti<:Integer, - RowIndT<:AbstractVector, - ColIndT<:AbstractVector, - NzValT<:AbstractVector, + RowIndT<:AbstractVector{Ti}, + ColIndT<:AbstractVector{Ti}, + NzValT<:AbstractVector{Tv}, } <: AbstractDeviceSparseMatrix{Tv,Ti} m::Int n::Int rowind::RowIndT colind::ColIndT nzval::NzValT - function DeviceSparseMatrixCOO{Tv,Ti,RowIndT,ColIndT,NzValT}( + + function DeviceSparseMatrixCOO( m::Integer, n::Integer, rowind::RowIndT, @@ -35,9 +36,9 @@ struct DeviceSparseMatrixCOO{ ) where { Tv, Ti<:Integer, - RowIndT<:AbstractVector, - ColIndT<:AbstractVector, - NzValT<:AbstractVector, + RowIndT<:AbstractVector{Ti}, + ColIndT<:AbstractVector{Ti}, + NzValT<:AbstractVector{Tv}, } get_backend(rowind) == get_backend(colind) == get_backend(nzval) || throw(ArgumentError("All storage vectors must be on the same device/backend.")) @@ -46,39 +47,13 @@ struct DeviceSparseMatrixCOO{ n >= 0 || throw(ArgumentError("n must be non-negative")) SparseArrays.sparse_check_Ti(m, n, Ti) - _check_type(Ti, rowind) || throw(ArgumentError("rowind must be of type $Ti")) - _check_type(Ti, colind) || throw(ArgumentError("colind must be of type $Ti")) - _check_type(Tv, nzval) || throw(ArgumentError("nzval must be of type $Tv")) - length(rowind) == length(colind) == length(nzval) || throw(ArgumentError("rowind, colind, and nzval must have same length")) - return new(Int(m), Int(n), rowind, colind, nzval) + return new{Tv,Ti,RowIndT,ColIndT,NzValT}(Int(m), Int(n), rowind, colind, nzval) end end -function DeviceSparseMatrixCOO( - m::Integer, - n::Integer, - rowind::RowIndT, - colind::ColIndT, - nzval::NzValT, -) where { - RowIndT<:AbstractVector{Ti}, - ColIndT<:AbstractVector{Ti}, - NzValT<:AbstractVector{Tv}, -} where {Ti<:Integer,Tv} - Ti2 = _get_eltype(rowind) - Tv2 = _get_eltype(nzval) - DeviceSparseMatrixCOO{Tv2,Ti2,RowIndT,ColIndT,NzValT}( - m, - n, - copy(rowind), - copy(colind), - copy(nzval), - ) -end - # Conversion from SparseMatrixCSC to COO function DeviceSparseMatrixCOO(A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} m, n = size(A) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 68091a0..aae304f 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -1,7 +1,7 @@ # DeviceSparseMatrixCSC implementation """ - DeviceSparseMatrixCSC{Tv,Ti,ColPtrT= 0 || throw(ArgumentError("The number of elements must be non-negative.")) length(nzind) == length(nzval) || throw(ArgumentError("index and value vectors must be the same length")) - return new(Int(n), copy(nzind), copy(nzval)) - end -end -# Param inference constructor -function DeviceSparseVector( - n::Integer, - nzind::IndT, - nzval::ValT, -) where {IndT<:AbstractVector{Ti},ValT<:AbstractVector{Tv}} where {Ti<:Integer,Tv} - DeviceSparseVector{Tv,Ti,IndT,ValT}(n, nzind, nzval) + return new{Tv,Ti,IndT,ValT}(Int(n), copy(nzind), copy(nzval)) + end end # Conversions From b330e668782d4c0cd9719b49ef555bcefb866036 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 18 Nov 2025 00:11:53 +0100 Subject: [PATCH 2/7] Relax types --- src/matrix_coo/matrix_coo.jl | 4 ++-- src/matrix_csc/matrix_csc.jl | 6 +++--- src/matrix_csr/matrix_csr.jl | 4 ++-- src/vector/vector.jl | 8 ++------ 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index b8d790d..20d6acf 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -16,7 +16,7 @@ types) enable dispatch on device characteristics. """ struct DeviceSparseMatrixCOO{ Tv, - Ti<:Integer, + Ti, RowIndT<:AbstractVector{Ti}, ColIndT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, @@ -35,7 +35,7 @@ struct DeviceSparseMatrixCOO{ nzval::NzValT, ) where { Tv, - Ti<:Integer, + Ti, RowIndT<:AbstractVector{Ti}, ColIndT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index aae304f..b617e0d 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -16,7 +16,7 @@ types) enable dispatch on device characteristics. """ struct DeviceSparseMatrixCSC{ Tv, - Ti<:Integer, + Ti, ColPtrT<:AbstractVector{Ti}, RowValT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, @@ -35,7 +35,7 @@ struct DeviceSparseMatrixCSC{ nzval::NzValT, ) where { Tv, - Ti<:Integer, + Ti, ColPtrT<:AbstractVector{Ti}, RowValT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, @@ -149,7 +149,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSC) kernel = kernel_tr(backend) kernel(res, getcolptr(A), getrowval(A), getnzval(A); ndrange = (n,)) - return allowed_getindex(res, 1) + return res #allowed_getindex(res, 1) end # Matrix-Vector and Matrix-Matrix multiplication diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index 56f859e..34c6f80 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -16,7 +16,7 @@ types) enable dispatch on device characteristics. """ struct DeviceSparseMatrixCSR{ Tv, - Ti<:Integer, + Ti, RowPtrT<:AbstractVector{Ti}, ColValT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, @@ -35,7 +35,7 @@ struct DeviceSparseMatrixCSR{ nzval::NzValT, ) where { Tv, - Ti<:Integer, + Ti, RowPtrT<:AbstractVector{Ti}, ColValT<:AbstractVector{Ti}, NzValT<:AbstractVector{Tv}, diff --git a/src/vector/vector.jl b/src/vector/vector.jl index 9ff230b..f0e6682 100644 --- a/src/vector/vector.jl +++ b/src/vector/vector.jl @@ -13,12 +13,8 @@ on different devices. The logical length is stored along with index/value buffer Constructors validate that the index and value vectors have matching length. """ -struct DeviceSparseVector{ - Tv, - Ti<:Integer, - IndT<:AbstractVector{Ti}, - ValT<:AbstractVector{Tv}, -} <: AbstractDeviceSparseVector{Tv,Ti} +struct DeviceSparseVector{Tv,Ti,IndT<:AbstractVector{Ti},ValT<:AbstractVector{Tv}} <: + AbstractDeviceSparseVector{Tv,Ti} n::Int nzind::IndT nzval::ValT From 38ca6c4df265e4ce26d306ced67f63c651b756e4 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio <61953577+albertomercurio@users.noreply.github.com> Date: Tue, 18 Nov 2025 00:13:57 +0100 Subject: [PATCH 3/7] Update src/matrix_csc/matrix_csc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/matrix_csc/matrix_csc.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index b617e0d..a56f414 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -1,7 +1,8 @@ # DeviceSparseMatrixCSC implementation """ - DeviceSparseMatrixCSC{Tv,Ti,ColPtrT<:AbstractVector{Ti},RowValT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractDeviceSparseMatrix{Tv,Ti} +""" + DeviceSparseMatrixCSC{Tv,Ti,ColPtrT,RowValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti} Compressed Sparse Column (CSC) matrix with generic storage vectors for column pointer, row indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array From 1ee3e1690d35c562fb1c8cbbb41a7a07033e7e06 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio <61953577+albertomercurio@users.noreply.github.com> Date: Tue, 18 Nov 2025 00:14:41 +0100 Subject: [PATCH 4/7] Update src/matrix_csr/matrix_csr.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/matrix_csr/matrix_csr.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index 34c6f80..908962f 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -1,7 +1,8 @@ # DeviceSparseMatrixCSR implementation """ - DeviceSparseMatrixCSR{Tv,Ti,RowPtrT<:AbstractVector{Ti},ColValT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractDeviceSparseMatrix{Tv,Ti} +""" + DeviceSparseMatrixCSR{Tv,Ti,RowPtrT,ColValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti} Compressed Sparse Row (CSR) matrix with generic storage vectors for row pointer, column indices, and nonzero values. Buffer types (e.g. `Vector`, GPU array From 47826bea0a6b461a2ba6537844aea5ea04de4125 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 18 Nov 2025 00:15:27 +0100 Subject: [PATCH 5/7] Copy arrays in COO format --- src/matrix_coo/matrix_coo.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/matrix_coo/matrix_coo.jl b/src/matrix_coo/matrix_coo.jl index 20d6acf..3c6240b 100644 --- a/src/matrix_coo/matrix_coo.jl +++ b/src/matrix_coo/matrix_coo.jl @@ -50,7 +50,13 @@ struct DeviceSparseMatrixCOO{ length(rowind) == length(colind) == length(nzval) || throw(ArgumentError("rowind, colind, and nzval must have same length")) - return new{Tv,Ti,RowIndT,ColIndT,NzValT}(Int(m), Int(n), rowind, colind, nzval) + return new{Tv,Ti,RowIndT,ColIndT,NzValT}( + Int(m), + Int(n), + copy(rowind), + copy(colind), + copy(nzval), + ) end end From 7d7cf75a272635b327aac7d288f18312ffb9322a Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 18 Nov 2025 00:25:20 +0100 Subject: [PATCH 6/7] Fix errors in docstring --- src/matrix_csc/matrix_csc.jl | 1 - src/matrix_csr/matrix_csr.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index a56f414..8ff4dbb 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -1,6 +1,5 @@ # DeviceSparseMatrixCSC implementation -""" """ DeviceSparseMatrixCSC{Tv,Ti,ColPtrT,RowValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti} diff --git a/src/matrix_csr/matrix_csr.jl b/src/matrix_csr/matrix_csr.jl index 908962f..394e0f5 100644 --- a/src/matrix_csr/matrix_csr.jl +++ b/src/matrix_csr/matrix_csr.jl @@ -1,6 +1,5 @@ # DeviceSparseMatrixCSR implementation -""" """ DeviceSparseMatrixCSR{Tv,Ti,RowPtrT,ColValT,NzValT} <: AbstractDeviceSparseMatrix{Tv,Ti} From a8875c4ceadb829e3ce793f55c82ef7212b80942 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 18 Nov 2025 01:01:47 +0100 Subject: [PATCH 7/7] Fix error on `tr` function --- src/matrix_csc/matrix_csc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matrix_csc/matrix_csc.jl b/src/matrix_csc/matrix_csc.jl index 8ff4dbb..a399b11 100644 --- a/src/matrix_csc/matrix_csc.jl +++ b/src/matrix_csc/matrix_csc.jl @@ -149,7 +149,7 @@ function LinearAlgebra.tr(A::DeviceSparseMatrixCSC) kernel = kernel_tr(backend) kernel(res, getcolptr(A), getrowval(A), getnzval(A); ndrange = (n,)) - return res #allowed_getindex(res, 1) + return allowed_getindex(res, 1) end # Matrix-Vector and Matrix-Matrix multiplication