Sparse GPU array and broadcasting support #628
base: master
Conversation
Force-pushed from 844f20c to 9a74b4d
I had to …
You mean GPUArrays.jl itself? I wouldn't expect those to be defined in KA.jl (maybe AK.jl, but with different signatures).
Now with all tests uncommented and testing …
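The formatting diff below touches, among other things, the scalar `getindex` for `JLSparseMatrixCSC`: column `j`'s stored entries live in the range `colPtr[j]:(colPtr[j+1]-1)`, and a binary search over `rowVal` within that range either finds row `i` or falls through to a structural zero. The same lookup can be sketched host-side using only the stdlib `SparseArrays` (the helper name `csc_getindex` is mine, for illustration; it is not part of this PR):

```julia
using SparseArrays

# Scalar lookup in a CSC matrix, mirroring the getindex logic in the diff:
# colptr[j]:(colptr[j+1]-1) bounds column j's stored entries, and
# searchsortedfirst locates row i among the sorted row indices; if the row
# is not stored, the entry is a structural zero.
function csc_getindex(A::SparseMatrixCSC{Tv}, i::Integer, j::Integer) where {Tv}
    @boundscheck checkbounds(A, i, j)
    r1 = Int(A.colptr[j])
    r2 = Int(A.colptr[j + 1] - 1)
    r1 > r2 && return zero(Tv)  # column j has no stored entries
    r1 = searchsortedfirst(view(A.rowval, r1:r2), i) + r1 - 1
    return (r1 > r2 || A.rowval[r1] != i) ? zero(Tv) : A.nzval[r1]
end

A = sparse([1, 3], [1, 2], [10.0, 20.0], 3, 2)
csc_getindex(A, 3, 2)  # 20.0 (stored entry)
csc_getindex(A, 2, 2)  # 0.0 (structural zero)
```

The CSR `getindex` in the diff is the same search with the roles of rows and columns swapped, which is why the PR round-trips CSR matrices through `SparseMatrixCSC(transpose(...))` on the host.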
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:

diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl
index 097f015..4ffe1d4 100644
--- a/lib/JLArrays/src/JLArrays.jl
+++ b/lib/JLArrays/src/JLArrays.jl
@@ -24,7 +24,7 @@ end
module AS
-const Generic = 0
+ const Generic = 0
end
@@ -129,28 +129,32 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
len::Int
nnz::Ti
- function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
- len::Integer) where {Tv, Ti <: Integer}
- new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
+ function JLSparseVector{Tv, Ti}(
+ iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
+ len::Integer
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
end
end
-SparseArrays.nnz(x::JLSparseVector) = x.nnz
-SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
-SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
+SparseArrays.nnz(x::JLSparseVector) = x.nnz
+SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
+SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
colPtr::JLArray{Ti, 1}
rowVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
- function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
- nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
- new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
+ function JLSparseMatrixCSC{Tv, Ti}(
+ colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
+ nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
end
end
-function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
end
SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
@@ -160,28 +164,30 @@ JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
@boundscheck checkbounds(A, i, j)
r1 = Int(@inbounds A.colPtr[j])
- r2 = Int(@inbounds A.colPtr[j+1]-1)
+ r2 = Int(@inbounds A.colPtr[j + 1] - 1)
(r1 > r2) && return zero(Tv)
r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
- ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
+ return ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
end
mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
rowPtr::JLArray{Ti, 1}
colVal::JLArray{Ti, 1}
nzVal::JLArray{Tv, 1}
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
- function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
- nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
- new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
+ function JLSparseMatrixCSR{Tv, Ti}(
+ rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
+ nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+ ) where {Tv, Ti <: Integer}
+ return new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
end
end
-function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
end
-function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
+function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
return SparseMatrixCSC(transpose(x_transpose))
end
@@ -199,22 +205,26 @@ function Base.size(g::JLSparseMatrixCSR, d::Integer)
end
JLSparseMatrixCSR(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {Tv} =
- JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
- JLVector(parent(Mat).nzval), size(Mat))
+ JLSparseMatrixCSR(
+ JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
+ JLVector(parent(Mat).nzval), size(Mat)
+)
JLSparseMatrixCSR(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {Tv} =
- JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
- JLVector(conj.(parent(Mat).nzval)), size(Mat))
+ JLSparseMatrixCSR(
+ JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
+ JLVector(conj.(parent(Mat).nzval)), size(Mat)
+)
JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A
function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) where {Tv, Ti}
@boundscheck checkbounds(A, i0, i1)
c1 = Int(A.rowPtr[i0])
- c2 = Int(A.rowPtr[i0+1]-1)
+ c2 = Int(A.rowPtr[i0 + 1] - 1)
(c1 > c2) && return zero(Tv)
c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
(c1 > c2 || A.colVal[c1] != i1) && return zero(Tv)
- nonzeros(A)[c1]
+ return nonzeros(A)[c1]
end
GPUArrays.storage(a::JLArray) = a.data
@@ -230,12 +240,12 @@ GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector
GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
-GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
-GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
-GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
-GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
-GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
-GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
@@ -243,19 +253,19 @@ GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
Base.similar(Mat::JLSparseMatrixCSR) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat), T), size(Mat))
-Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
-Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
+Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
+Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
-Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
-Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
+Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
+Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
-Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
-Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
+Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
+Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
-Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
-Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)
+Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
+Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)
-JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
+JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x)))
JLArray(x::JLSparseMatrixCSR) = JLArray(collect(SparseMatrixCSC(x)))
@@ -360,11 +370,11 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
- iPtr = JLVector{Ti}(undef, length(xs.nzind))
+ iPtr = JLVector{Ti}(undef, length(xs.nzind))
nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
- return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),)
+ return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
end
Base.length(x::JLSparseVector) = x.len
Base.size(x::JLSparseVector) = (x.len,)
@@ -372,13 +382,13 @@ Base.size(x::JLSparseVector) = (x.len,)
function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
colPtr = JLVector{Ti}(undef, length(xs.colptr))
rowVal = JLVector{Ti}(undef, length(xs.rowval))
- nzVal = JLVector{Tv}(undef, length(xs.nzval))
+ nzVal = JLVector{Tv}(undef, length(xs.nzval))
copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
- copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
+ copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
end
-JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
+JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
Base.size(x::JLSparseMatrixCSC) = x.dims
@@ -386,10 +396,10 @@ function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
csr_xs = SparseMatrixCSC(transpose(xs))
rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
- nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
+ nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
- copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
+ copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
end
JLSparseMatrixCSR(xs::SparseVector{Tv, Ti}) where {Ti, Tv} = JLSparseMatrixCSR(SparseMatrixCSC(xs))
@@ -410,26 +420,26 @@ function Base.copyto!(dst::JLSparseMatrixCSR, src::JLSparseMatrixCSR)
copyto!(dst.colVal, src.colVal)
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
dst.nnz = src.nnz
- dst
+ return dst
end
Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
Base.size(x::JLSparseMatrixCSR) = x.dims
function GPUArrays._spadjoint(A::JLSparseMatrixCSR)
Aᴴ = JLSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
- JLSparseMatrixCSR(Aᴴ)
+ return JLSparseMatrixCSR(Aᴴ)
end
function GPUArrays._sptranspose(A::JLSparseMatrixCSR)
Aᵀ = JLSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
- JLSparseMatrixCSR(Aᵀ)
+ return JLSparseMatrixCSR(Aᵀ)
end
function _spadjoint(A::JLSparseMatrixCSC)
Aᴴ = JLSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
- JLSparseMatrixCSC(Aᴴ)
+ return JLSparseMatrixCSC(Aᴴ)
end
function _sptranspose(A::JLSparseMatrixCSC)
Aᵀ = JLSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
- JLSparseMatrixCSC(Aᵀ)
+ return JLSparseMatrixCSC(Aᵀ)
end
# idempotency
@@ -573,17 +583,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
R
end
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceMatrixCSC{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceMatrixCSR{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv, Ti}) where {Tv, Ti} =
+ GPUSparseDeviceVector{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
## KernelAbstractions interface
KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
-KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
+KernelAbstractions.get_backend(a::JLA) where {JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector}} = JLBackend()
function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
diff --git a/src/device/sparse.jl b/src/device/sparse.jl
index b8346ea..22b4752 100644
--- a/src/device/sparse.jl
+++ b/src/device/sparse.jl
@@ -10,12 +10,12 @@ using SparseArrays
# core types
export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
- GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
+ GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end
-struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv, A} <: AbstractSparseVector{Tv,Ti}
+struct GPUSparseDeviceVector{Tv, Ti, Vi, Vv, A} <: AbstractSparseVector{Tv, Ti}
iPtr::Vi
nzVal::Vv
len::Int
@@ -28,11 +28,11 @@ SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal
-struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
+struct GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
colPtr::Vi
rowVal::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
end
@@ -41,7 +41,7 @@ Base.size(g::GPUSparseDeviceMatrixCSC) = g.dims
SparseArrays.nnz(g::GPUSparseDeviceMatrixCSC) = g.nnz
SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
-SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = @inbounds SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
+SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = @inbounds SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col + 1] - 1)
SparseArrays.nonzeros(g::GPUSparseDeviceMatrixCSC) = g.nzVal
const GPUSparseDeviceColumnView{Tv, Ti, Vi, Vv, A} = SubArray{Tv, 1, GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv, A}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}}
@@ -66,7 +66,7 @@ function SparseArrays.nnz(x::GPUSparseDeviceColumnView)
return length(SparseArrays.nzrange(A, colidx))
end
-struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCSR{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
@@ -74,9 +74,13 @@ struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{
nnz::Ti
end
-@inline function _getindex(arg::Union{GPUSparseDeviceMatrixCSR{Tv},
- GPUSparseDeviceMatrixCSC{Tv},
- GPUSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
+@inline function _getindex(
+ arg::Union{
+ GPUSparseDeviceMatrixCSR{Tv},
+ GPUSparseDeviceMatrixCSC{Tv},
+ GPUSparseDeviceVector{Tv},
+ }, I, ptr
+ )::Tv where {Tv}
if ptr == 0
return zero(Tv)
else
@@ -84,21 +88,21 @@ end
end
end
-struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixBSR{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowPtr::Vi
colVal::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
blockDim::Ti
dir::Char
nnz::Ti
end
-struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCOO{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
rowInd::Vi
colInd::Vi
nzVal::Vv
- dims::NTuple{2,Int}
+ dims::NTuple{2, Int}
nnz::Ti
end
@@ -115,9 +119,9 @@ struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M, A} <: AbstractSparseArray{T
nnz::Ti
end
-function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
+function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N, <:Integer}) where {Tv, Ti <: Integer, M, Vi <: AbstractDeviceArray{<:Integer, M}, Vv <: AbstractDeviceArray{Tv, M}, N}
@assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
- GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
+ return GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
end
Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
@@ -130,42 +134,42 @@ SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
println(io, "$(length(A))-element device sparse vector at:")
println(io, " iPtr: $(A.iPtr)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
println(io, "$(length(A))-element device sparse matrix CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
println(io, "$(length(A))-element device sparse matrix CSC at:")
println(io, " colPtr: $(A.colPtr)")
println(io, " rowVal: $(A.rowVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
println(io, "$(length(A))-element device sparse matrix BSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
println(io, "$(length(A))-element device sparse matrix COO at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colInd: $(A.colInd)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
println(io, "$(length(A))-element device sparse array CSR at:")
println(io, " rowPtr: $(A.rowPtr)")
println(io, " colVal: $(A.colVal)")
- print(io, " nzVal: $(A.nzVal)")
+ return print(io, " nzVal: $(A.nzVal)")
end
# COV_EXCL_STOP
diff --git a/src/host/sparse.jl b/src/host/sparse.jl
index 1bee2ce..80e816d 100644
--- a/src/host/sparse.jl
+++ b/src/host/sparse.jl
@@ -10,13 +10,13 @@ abstract type AbstractGPUSparseMatrixCSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, T
abstract type AbstractGPUSparseMatrixCOO{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
abstract type AbstractGPUSparseMatrixBSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
-const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector,AbstractGPUSparseMatrix}
+const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector, AbstractGPUSparseMatrix}
-SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
-SparseArrays.nonzeros(g::T) where {T<:AbstractGPUSparseArray} = g.nzVal
+SparseArrays.nnz(g::T) where {T <: AbstractGPUSparseArray} = g.nnz
+SparseArrays.nonzeros(g::T) where {T <: AbstractGPUSparseArray} = g.nzVal
-SparseArrays.nonzeroinds(g::T) where {T<:AbstractGPUSparseVector} = g.iPtr
-SparseArrays.rowvals(g::T) where {T<:AbstractGPUSparseVector} = SparseArrays.nonzeroinds(g)
+SparseArrays.nonzeroinds(g::T) where {T <: AbstractGPUSparseVector} = g.iPtr
+SparseArrays.rowvals(g::T) where {T <: AbstractGPUSparseVector} = SparseArrays.nonzeroinds(g)
SparseArrays.rowvals(g::AbstractGPUSparseMatrixCSC) = g.rowVal
SparseArrays.getcolptr(S::AbstractGPUSparseMatrixCSC) = S.colPtr
@@ -30,7 +30,7 @@ Base.collect(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
Base.collect(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
Base.collect(x::AbstractGPUSparseMatrixCOO) = collect(SparseMatrixCSC(x))
-Base.Array(x::AbstractGPUSparseVector) = collect(SparseVector(x))
+Base.Array(x::AbstractGPUSparseVector) = collect(SparseVector(x))
Base.Array(x::AbstractGPUSparseMatrixCSC) = collect(SparseMatrixCSC(x))
Base.Array(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
Base.Array(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
@@ -40,44 +40,44 @@ SparseArrays.SparseVector(x::AbstractGPUSparseVector) = SparseVector(length(x),
SparseArrays.SparseMatrixCSC(x::AbstractGPUSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(SparseArrays.getcolptr(x)), Array(SparseArrays.rowvals(x)), Array(SparseArrays.nonzeros(x)))
# similar
-Base.similar(Vec::V) where {V<:AbstractGPUSparseVector} = V(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec)), length(Vec))
-Base.similar(Mat::M) where {M<:AbstractGPUSparseMatrixCSC} = M(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat)), size(Mat))
+Base.similar(Vec::V) where {V <: AbstractGPUSparseVector} = V(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec)), length(Vec))
+Base.similar(Mat::M) where {M <: AbstractGPUSparseMatrixCSC} = M(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat)), size(Mat))
-Base.similar(Vec::V, T::Type) where {Tv, Ti, V<:AbstractGPUSparseVector{Tv, Ti}} = sparse_array_type(V){T, Ti}(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec), T), length(Vec))
-Base.similar(Mat::M, T::Type) where {M<:AbstractGPUSparseMatrixCSC} = sparse_array_type(M)(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat), T), size(Mat))
+Base.similar(Vec::V, T::Type) where {Tv, Ti, V <: AbstractGPUSparseVector{Tv, Ti}} = sparse_array_type(V){T, Ti}(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec), T), length(Vec))
+Base.similar(Mat::M, T::Type) where {M <: AbstractGPUSparseMatrixCSC} = sparse_array_type(M)(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat), T), size(Mat))
-dense_array_type(sa::SparseVector) = SparseVector
+dense_array_type(sa::SparseVector) = SparseVector
dense_array_type(::Type{SparseVector}) = SparseVector
sparse_array_type(sa::SparseVector) = SparseVector
dense_vector_type(sa::AbstractSparseArray) = Vector
-dense_vector_type(sa::AbstractArray) = Vector
+dense_vector_type(sa::AbstractArray) = Vector
dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
-dense_vector_type(::Type{<:AbstractArray}) = Vector
-dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+dense_vector_type(::Type{<:AbstractArray}) = Vector
+dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
-sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
+sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
function sparse_array_type(sa::AbstractGPUSparseArray) end
function dense_array_type(sa::AbstractGPUSparseArray) end
function coo_type(sa::AbstractGPUSparseArray) end
-coo_type(::SA) where {SA<:AbstractGPUSparseMatrixCSC} = SA
+coo_type(::SA) where {SA <: AbstractGPUSparseMatrixCSC} = SA
function _spadjoint end
function _sptranspose end
-function LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSR, p::Real=2)
+function LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSR, p::Real = 2)
if p == Inf
- return maximum(sum(abs, A; dims=2))
+ return maximum(sum(abs, A; dims = 2))
elseif p == 1
- return maximum(sum(abs, A; dims=1))
+ return maximum(sum(abs, A; dims = 1))
else
throw(ArgumentError("p=$p is not supported"))
end
end
-LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSC, p::Real=2) = opnorm(_csr_type(A)(A), p)
+LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSC, p::Real = 2) = opnorm(_csr_type(A)(A), p)
-function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real=2) where T
+function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real = 2) where {T}
if p == Inf
return maximum(abs.(SparseArrays.nonzeros(A)))
elseif p == -Inf
@@ -85,7 +85,7 @@ function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real=2) where T
elseif p == 0
return Float64(SparseArrays.nnz(A))
else
- return sum(abs.(SparseArrays.nonzeros(A)).^p)^(1/p)
+ return sum(abs.(SparseArrays.nonzeros(A)) .^ p)^(1 / p)
end
end
@@ -105,22 +105,24 @@ function SparseArrays.findnz(S::MT) where {MT <: AbstractGPUSparseMatrix}
end
### WRAPPED ARRAYS
-LinearAlgebra.issymmetric(M::Union{AbstractGPUSparseMatrixCSC,AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false
-LinearAlgebra.ishermitian(M::Union{AbstractGPUSparseMatrixCSC,AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false
+LinearAlgebra.issymmetric(M::Union{AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false
+LinearAlgebra.ishermitian(M::Union{AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false
-LinearAlgebra.istriu(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
-LinearAlgebra.istril(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
-LinearAlgebra.istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
-LinearAlgebra.istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
+LinearAlgebra.istriu(M::UpperTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
+LinearAlgebra.istril(M::UpperTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
+LinearAlgebra.istriu(M::LowerTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
+LinearAlgebra.istril(M::LowerTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
-Hermitian{T}(Mat::AbstractGPUSparseMatrix{T}) where {T} = Hermitian{eltype(Mat),typeof(Mat)}(Mat,'U')
+Hermitian{T}(Mat::AbstractGPUSparseMatrix{T}) where {T} = Hermitian{eltype(Mat), typeof(Mat)}(Mat, 'U')
# work around upstream breakage from JuliaLang/julia#55547
@static if VERSION >= v"1.11.2"
const GPUSparseUpperOrUnitUpperTriangular = LinearAlgebra.UpperOrUnitUpperTriangular{
- <:Any,<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}}
+ <:Any, <:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}},
+ }
const GPUSparseLowerOrUnitLowerTriangular = LinearAlgebra.LowerOrUnitLowerTriangular{
- <:Any,<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}}
+ <:Any, <:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}},
+ }
LinearAlgebra.istriu(::GPUSparseUpperOrUnitUpperTriangular) = true
LinearAlgebra.istril(::GPUSparseUpperOrUnitUpperTriangular) = false
LinearAlgebra.istriu(::GPUSparseLowerOrUnitLowerTriangular) = false
@@ -129,67 +131,67 @@ end
for SparseMatrixType in [:AbstractGPUSparseMatrixCSC, :AbstractGPUSparseMatrixCSR]
@eval begin
- LinearAlgebra.triu(A::ST, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( triu(coo_type(A)(A), k) )
- LinearAlgebra.triu(A::Transpose{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( triu(coo_type(A)(_sptranspose(parent(A))), k) )
- LinearAlgebra.triu(A::Adjoint{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( triu(coo_type(A)(_spadjoint(parent(A))), k) )
-
- LinearAlgebra.tril(A::ST, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( tril(coo_type(A)(A), k) )
- LinearAlgebra.tril(A::Transpose{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( tril(coo_type(A)(_sptranspose(parent(A))), k) )
- LinearAlgebra.tril(A::Adjoint{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
- ST( tril(coo_type(A)(_spadjoint(parent(A))), k) )
-
- LinearAlgebra.triu(A::Union{ST, Transpose{T,<:ST}, Adjoint{T,<:ST}}) where {T, ST<:$SparseMatrixType{T}} =
- ST( triu(coo_type(A)(A), 0) )
- LinearAlgebra.tril(A::Union{ST,Transpose{T,<:ST}, Adjoint{T,<:ST}}) where {T, ST<:$SparseMatrixType{T}} =
- ST( tril(coo_type(A)(A), 0) )
-
- LinearAlgebra.kron(A::ST, B::ST) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(A), coo_type(B)(B)) )
- LinearAlgebra.kron(A::ST, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(A), B) )
- LinearAlgebra.kron(A::Diagonal, B::ST) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(A, coo_type(B)(B)) )
-
- LinearAlgebra.kron(A::Transpose{T,<:ST}, B::ST) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(_sptranspose(parent(A))), coo_type(B)(B)) )
- LinearAlgebra.kron(A::ST, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(A), coo_type(parent(B))(_sptranspose(parent(B)))) )
- LinearAlgebra.kron(A::Transpose{T,<:ST}, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(_sptranspose(parent(A))), coo_type(parent(B))(_sptranspose(parent(B)))) )
- LinearAlgebra.kron(A::Transpose{T,<:ST}, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(_sptranspose(parent(A))), B) )
- LinearAlgebra.kron(A::Diagonal, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(A, coo_type(B)(_sptranspose(parent(B)))) )
-
- LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::ST) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(_spadjoint(parent(A))), coo_type(B)(B)) )
- LinearAlgebra.kron(A::ST, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(A)(A), coo_type(parent(B))(_spadjoint(parent(B)))) )
- LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(parent(A))(_spadjoint(parent(A))), coo_type(parent(B))(_spadjoint(parent(B)))) )
- LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(coo_type(parent(A))(_spadjoint(parent(A))), B) )
- LinearAlgebra.kron(A::Diagonal, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
- ST( kron(A, coo_type(parent(B))(_spadjoint(parent(B)))) )
-
-
- function Base.reshape(A::ST, dims::Dims) where {ST<:$SparseMatrixType}
+ LinearAlgebra.triu(A::ST, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(triu(coo_type(A)(A), k))
+ LinearAlgebra.triu(A::Transpose{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(triu(coo_type(A)(_sptranspose(parent(A))), k))
+ LinearAlgebra.triu(A::Adjoint{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(triu(coo_type(A)(_spadjoint(parent(A))), k))
+
+ LinearAlgebra.tril(A::ST, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(tril(coo_type(A)(A), k))
+ LinearAlgebra.tril(A::Transpose{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(tril(coo_type(A)(_sptranspose(parent(A))), k))
+ LinearAlgebra.tril(A::Adjoint{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+ ST(tril(coo_type(A)(_spadjoint(parent(A))), k))
+
+ LinearAlgebra.triu(A::Union{ST, Transpose{T, <:ST}, Adjoint{T, <:ST}}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(triu(coo_type(A)(A), 0))
+ LinearAlgebra.tril(A::Union{ST, Transpose{T, <:ST}, Adjoint{T, <:ST}}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(tril(coo_type(A)(A), 0))
+
+ LinearAlgebra.kron(A::ST, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(A), coo_type(B)(B)))
+ LinearAlgebra.kron(A::ST, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(A), B))
+ LinearAlgebra.kron(A::Diagonal, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(A, coo_type(B)(B)))
+
+ LinearAlgebra.kron(A::Transpose{T, <:ST}, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(_sptranspose(parent(A))), coo_type(B)(B)))
+ LinearAlgebra.kron(A::ST, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(A), coo_type(parent(B))(_sptranspose(parent(B)))))
+ LinearAlgebra.kron(A::Transpose{T, <:ST}, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(_sptranspose(parent(A))), coo_type(parent(B))(_sptranspose(parent(B)))))
+ LinearAlgebra.kron(A::Transpose{T, <:ST}, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(_sptranspose(parent(A))), B))
+ LinearAlgebra.kron(A::Diagonal, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(A, coo_type(B)(_sptranspose(parent(B)))))
+
+ LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(_spadjoint(parent(A))), coo_type(B)(B)))
+ LinearAlgebra.kron(A::ST, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(A)(A), coo_type(parent(B))(_spadjoint(parent(B)))))
+ LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(parent(A))(_spadjoint(parent(A))), coo_type(parent(B))(_spadjoint(parent(B)))))
+ LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(coo_type(parent(A))(_spadjoint(parent(A))), B))
+ LinearAlgebra.kron(A::Diagonal, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+ ST(kron(A, coo_type(parent(B))(_spadjoint(parent(B)))))
+
+
+ function Base.reshape(A::ST, dims::Dims) where {ST <: $SparseMatrixType}
B = coo_type(A)(A)
- ST(reshape(B, dims))
+ return ST(reshape(B, dims))
end
- function SparseArrays.droptol!(A::ST, tol::Real) where {ST<:$SparseMatrixType}
+ function SparseArrays.droptol!(A::ST, tol::Real) where {ST <: $SparseMatrixType}
B = coo_type(A)(A)
droptol!(B, tol)
- copyto!(A, ST(B))
+ return copyto!(A, ST(B))
end
- function LinearAlgebra.exp(A::ST; threshold = 1e-7, nonzero_tol = 1e-14) where {ST<:$SparseMatrixType}
+ function LinearAlgebra.exp(A::ST; threshold = 1.0e-7, nonzero_tol = 1.0e-14) where {ST <: $SparseMatrixType}
rows = LinearAlgebra.checksquare(A) # Throws exception if not square
typeA = eltype(A)
@@ -211,40 +213,40 @@ for SparseMatrixType in [:AbstractGPUSparseMatrixCSC, :AbstractGPUSparseMatrixCS
copyto!(P, P + next_term)
n = n + 1
end
- for n = 1:log2(scaling_factor)
- P = P * P;
+ for n in 1:log2(scaling_factor)
+ P = P * P
if nnz(P) / length(P) < 0.25
droptol!(P, nonzero_tol)
end
end
- P
+ return P
end
end
end
### INDEXING
-Base.getindex(A::AbstractGPUSparseVector, ::Colon) = copy(A)
+Base.getindex(A::AbstractGPUSparseVector, ::Colon) = copy(A)
Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, ::Colon) = copy(A)
-Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:size(A, 2))
-Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
-Base.getindex(A::AbstractGPUSparseMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
+Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:size(A, 2))
+Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
+Base.getindex(A::AbstractGPUSparseMatrix, I::Tuple{Integer, Integer}) = getindex(A, I[1], I[2])
function Base.getindex(A::AbstractGPUSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
@boundscheck checkbounds(A, i)
ii = searchsortedfirst(SparseArrays.nonzeroinds(A), convert(Ti, i))
(ii > SparseArrays.nnz(A) || SparseArrays.nonzeroinds(A)[ii] != i) && return zero(Tv)
- SparseArrays.nonzeros(A)[ii]
+ return SparseArrays.nonzeros(A)[ii]
end
-function Base.getindex(A::AbstractGPUSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
+function Base.getindex(A::AbstractGPUSparseMatrixCSC{T}, i0::Integer, i1::Integer) where {T}
@boundscheck checkbounds(A, i0, i1)
r1 = Int(SparseArrays.getcolptr(A)[i1])
- r2 = Int(SparseArrays.getcolptr(A)[i1+1]-1)
+ r2 = Int(SparseArrays.getcolptr(A)[i1 + 1] - 1)
(r1 > r2) && return zero(T)
r1 = searchsortedfirst(SparseArrays.rowvals(A), i0, r1, r2, Base.Order.Forward)
(r1 > r2 || SparseArrays.rowvals(A)[r1] != i0) && return zero(T)
- SparseArrays.nonzeros(A)[r1]
+ return SparseArrays.nonzeros(A)[r1]
end
## copying between sparse GPU arrays
@@ -259,7 +261,7 @@ function Base.copyto!(dst::AbstractGPUSparseVector, src::AbstractGPUSparseVector
copyto!(SparseArrays.nonzeroinds(dst), SparseArrays.nonzeroinds(src))
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
dst.nnz = src.nnz
- dst
+ return dst
end
function Base.copyto!(dst::AbstractGPUSparseMatrixCSC, src::AbstractGPUSparseMatrixCSC)
@@ -273,7 +275,7 @@ function Base.copyto!(dst::AbstractGPUSparseMatrixCSC, src::AbstractGPUSparseMat
copyto!(SparseArrays.rowvals(dst), SparseArrays.rowvals(src))
copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
dst.nnz = src.nnz
- dst
+ return dst
end
### BROADCAST
@@ -283,7 +285,7 @@ struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle()
-const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
+const SPVM = Union{GPUSparseVecStyle, GPUSparseMatStyle}
# GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions.
# GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions.
@@ -291,11 +293,11 @@ const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle()
GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle()
GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseVecStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle()
GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle()
GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseMatStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{1}) = GPUSparseVecStyle()
Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle()
@@ -322,37 +324,37 @@ end
# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+ capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
# This definition is identical to the one above and necessary only for
# avoiding method ambiguity.
- capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+ capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{AbstractGPUSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
-@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
- capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
+ capturescalars((a1, args...) -> f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
+@inline capturescalars(f, mixedargs::Tuple{Union{Ref, AbstractArray{<:Any, 0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
+ capturescalars((args...) -> f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
scalararg(::Number) = true
scalararg(::Any) = false
-scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true
+scalarwrappedarg(::Union{AbstractArray{<:Any, 0}, Ref}) = true
scalarwrappedarg(::Any) = false
@inline function _capturescalars()
return (), () -> ()
end
@inline function _capturescalars(arg, mixedargs...)
- let (rest, f) = _capturescalars(mixedargs...)
+ return let (rest, f) = _capturescalars(mixedargs...)
if scalararg(arg)
- return rest, @inline function(tail...)
- (arg, f(tail...)...)
+ return rest, @inline function (tail...)
+ return (arg, f(tail...)...)
end # add back scalararg after (in makeargs)
elseif scalarwrappedarg(arg)
- return rest, @inline function(tail...)
- (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
+ return rest, @inline function (tail...)
+ return (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
end # unwrap and add back scalararg after (in makeargs)
else
- return (arg, rest...), @inline function(head, tail...)
- (head, f(tail...)...)
+ return (arg, rest...), @inline function (head, tail...)
+ return (head, f(tail...)...)
end # pass-through to broadcast
end
end
@@ -388,13 +390,13 @@ access the elements themselves.
For convenience, this iterator can be passed non-sparse arguments as well, which will be
ignored (with the returned `col`/`ptr` values set to 0).
"""
-struct CSRIterator{Ti,N,ATs}
+struct CSRIterator{Ti, N, ATs}
row::Ti
col_ends::NTuple{N, Ti}
args::ATs
end
-function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
+function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti, N}
# check that `row` is valid for all arguments
@boundscheck begin
ntuple(Val(N)) do i
@@ -406,16 +408,16 @@ function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
col_ends = ntuple(Val(N)) do i
arg = @inbounds args[i]
if arg isa GPUSparseDeviceMatrixCSR
- @inbounds(arg.rowPtr[row+1])
+ @inbounds(arg.rowPtr[row + 1])
else
zero(Ti)
end
end
- CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
+ return CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
end
-@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSRIterator{Ti, N}, state = nothing) where {Ti, N}
# helper function to get the column of a sparse array at a specific pointer
@inline function get_col(i, ptr)
arg = @inbounds iter.args[i]
@@ -425,13 +427,14 @@ end
return @inbounds arg.colVal[ptr] % Ti
end
end
- typemax(Ti)
+ return typemax(Ti)
end
# initialize the state
# - ptr: the current index into the colVal/nzVal arrays
# - col: the current column index (cached so that we don't have to re-read each time)
- state = something(state,
+ state = something(
+ state,
ntuple(Val(N)) do i
arg = @inbounds iter.args[i]
if arg isa GPUSparseDeviceMatrixCSR
@@ -473,13 +476,13 @@ end
return (cur_col, ptrs), new_state
end
-struct CSCIterator{Ti,N,ATs}
+struct CSCIterator{Ti, N, ATs}
col::Ti
row_ends::NTuple{N, Ti}
args::ATs
end
-function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
+function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti, N}
# check that `col` is valid for all arguments
@boundscheck begin
ntuple(Val(N)) do i
@@ -491,17 +494,17 @@ function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
row_ends = ntuple(Val(N)) do i
arg = @inbounds args[i]
x = if arg isa GPUSparseDeviceMatrixCSC
- @inbounds(arg.colPtr[col+1])
+ @inbounds(arg.colPtr[col + 1])
else
zero(Ti)
end
x
end
- CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
+ return CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
end
-@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSCIterator{Ti, N}, state = nothing) where {Ti, N}
# helper function to get the column of a sparse array at a specific pointer
@inline function get_col(i, ptr)
arg = @inbounds iter.args[i]
@@ -511,13 +514,14 @@ end
return @inbounds arg.rowVal[ptr] % Ti
end
end
- typemax(Ti)
+ return typemax(Ti)
end
# initialize the state
# - ptr: the current index into the rowVal/nzVal arrays
# - row: the current row index (cached so that we don't have to re-read each time)
- state = something(state,
+ state = something(
+ state,
ntuple(Val(N)) do i
arg = @inbounds iter.args[i]
if arg isa GPUSparseDeviceMatrixCSC
@@ -560,8 +564,8 @@ end
end
# helpers to index a sparse or dense array
-function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR,GPUSparseDeviceMatrixCSC}, I, ptr)
- if ptr == 0
+function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR, GPUSparseDeviceMatrixCSC}, I, ptr)
+ return if ptr == 0
zero(eltype(arg))
else
@inbounds arg.nzVal[ptr]
@@ -589,9 +593,11 @@ function _has_row(A::GPUSparseDeviceVector, offsets, row, ::Bool)
return 0
end
-@kernel function compute_offsets_kernel(::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
- fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Ti, N}
+@kernel function compute_offsets_kernel(
+ ::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
+ fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Ti, N}
my_ix = @index(Global, Linear)
row = my_ix + first_row - one(eltype(my_ix))
if row ≤ last_row
@@ -610,11 +616,13 @@ end
end
# kernel to count the number of non-zeros in a row, to determine the row offsets
-@kernel function compute_offsets_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
- offsets::AbstractVector{Ti}, args...) where Ti
+@kernel function compute_offsets_kernel(
+ T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
+ offsets::AbstractVector{Ti}, args...
+ ) where {Ti}
# every thread processes an entire row
leading_dim = @index(Global, Linear)
- if leading_dim ≤ length(offsets)-1
+ if leading_dim ≤ length(offsets) - 1
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
# count the nonzero leading_dims of all inputs
@@ -629,19 +637,21 @@ end
if leading_dim == 1
offsets[1] = 1
end
- offsets[leading_dim+1] = accum
+ offsets[leading_dim + 1] = accum
end
end
end
-@kernel function sparse_to_sparse_broadcast_kernel(f::F, output::GPUSparseDeviceVector{Tv,Ti},
- offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Tv, Ti, N, F}
+@kernel function sparse_to_sparse_broadcast_kernel(
+ f::F, output::GPUSparseDeviceVector{Tv, Ti},
+ offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Tv, Ti, N, F}
row_ix = @index(Global, Linear)
if row_ix ≤ output.nnz
row_and_ptrs = @inbounds offsets[row_ix]
- row = @inbounds row_and_ptrs[1]
- arg_ptrs = @inbounds row_and_ptrs[2]
+ row = @inbounds row_and_ptrs[1]
+ arg_ptrs = @inbounds row_and_ptrs[2]
vals = ntuple(Val(N)) do i
@inline
arg = @inbounds args[i]
@@ -651,14 +661,20 @@ end
_getindex(arg, row, ptr)
end
output_val = f(vals...)
- @inbounds output.iPtr[row_ix] = row
+ @inbounds output.iPtr[row_ix] = row
@inbounds output.nzVal[row_ix] = output_val
end
end
-@kernel function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{<:AbstractArray,Nothing},
- args...) where {Ti, T<:Union{GPUSparseDeviceMatrixCSR{<:Any,Ti},
- GPUSparseDeviceMatrixCSC{<:Any,Ti}}}
+@kernel function sparse_to_sparse_broadcast_kernel(
+ f, output::T, offsets::Union{<:AbstractArray, Nothing},
+ args...
+ ) where {
+ Ti, T <: Union{
+ GPUSparseDeviceMatrixCSR{<:Any, Ti},
+ GPUSparseDeviceMatrixCSC{<:Any, Ti},
+ },
+ }
# every thread processes an entire row
leading_dim = @index(Global, Linear)
leading_dim_size = output isa GPUSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -666,19 +682,19 @@ end
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
- output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
+ output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
output_ivals = output isa GPUSparseDeviceMatrixCSR ? output.colVal : output.rowVal
# fetch the row offset, and write it to the output
@inbounds begin
output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
if leading_dim == leading_dim_size
- output_ptrs[leading_dim+one(eltype(leading_dim))] = offsets[leading_dim+one(eltype(leading_dim))]
+ output_ptrs[leading_dim + one(eltype(leading_dim))] = offsets[leading_dim + one(eltype(leading_dim))]
end
end
# set the values for this row
for (sub_leading_dim, ptrs) in iter
- index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
+ index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
index_second = output isa GPUSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
I = CartesianIndex(index_first, index_second)
vals = ntuple(Val(length(args))) do i
@@ -693,9 +709,15 @@ end
end
end
end
-@kernel function sparse_to_dense_broadcast_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR{Tv, Ti},
- AbstractGPUSparseMatrixCSC{Tv, Ti}}},
- f, output::AbstractArray, args...) where {Tv, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+ T::Type{
+ <:Union{
+ AbstractGPUSparseMatrixCSR{Tv, Ti},
+ AbstractGPUSparseMatrixCSC{Tv, Ti},
+ },
+ },
+ f, output::AbstractArray, args...
+ ) where {Tv, Ti}
# every thread processes an entire row
leading_dim = @index(Global, Linear)
leading_dim_size = T <: AbstractGPUSparseMatrixCSR ? size(output, 1) : size(output, 2)
@@ -704,7 +726,7 @@ end
# set the values for this row
for (sub_leading_dim, ptrs) in iter
- index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
+ index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
index_second = T <: AbstractGPUSparseMatrixCSR ? sub_leading_dim : leading_dim
I = CartesianIndex(index_first, index_second)
vals = ntuple(Val(length(args))) do i
@@ -718,16 +740,18 @@ end
end
end
-@kernel function sparse_to_dense_broadcast_kernel(::Type{<:AbstractGPUSparseVector}, f::F,
- output::AbstractArray{Tv},
- offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
- args...) where {Tv, F, N, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+ ::Type{<:AbstractGPUSparseVector}, f::F,
+ output::AbstractArray{Tv},
+ offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+ args...
+ ) where {Tv, F, N, Ti}
# every thread processes an entire row
row_ix = @index(Global, Linear)
if row_ix ≤ length(output)
row_and_ptrs = @inbounds offsets[row_ix]
- row = @inbounds row_and_ptrs[1]
- arg_ptrs = @inbounds row_and_ptrs[2]
+ row = @inbounds row_and_ptrs[1]
+ arg_ptrs = @inbounds row_and_ptrs[2]
vals = ntuple(Val(length(args))) do i
@inline
arg = @inbounds args[i]
@@ -742,39 +766,41 @@ end
end
## COV_EXCL_STOP
-function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatStyle}})
+function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle, GPUSparseMatStyle}})
# find the sparse inputs
bc = Broadcast.flatten(bc)
sparse_args = findall(bc.args) do arg
arg isa AbstractGPUSparseArray
end
- sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args))
+ sparse_types = unique(map(i -> nameof(typeof(bc.args[i])), sparse_args))
if length(sparse_types) > 1
error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
end
sparse_typ = typeof(bc.args[first(sparse_args)])
- sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC,AbstractGPUSparseVector} ||
+ sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC, AbstractGPUSparseVector} ||
error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
Ti = if sparse_typ <: AbstractGPUSparseMatrixCSR
- reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].rowPtr), sparse_args))
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
- reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].colPtr), sparse_args))
elseif sparse_typ <: AbstractGPUSparseVector
- reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
+ reduce(promote_type, map(i -> eltype(bc.args[i].iPtr), sparse_args))
end
# determine the output type
Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args))
if !Base.isconcretetype(Tv)
- error("""GPU sparse broadcast resulted in non-concrete element type $Tv.
- This probably means that the function you are broadcasting contains an error or type instability.""")
+ error(
+ """GPU sparse broadcast resulted in non-concrete element type $Tv.
+ This probably means that the function you are broadcasting contains an error or type instability."""
+ )
end
# partially-evaluate the function, removing scalars.
parevalf, passedsrcargstup = capturescalars(bc.f, bc.args)
# check if the partially-evaluated function preserves zeros. if so, we'll only need to
# apply it to the sparse input arguments, preserving the sparse structure.
- if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup)
+ if all(arg -> isa(arg, AbstractSparseArray), passedsrcargstup)
fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...)
fpreszeros = _iszero(fofzeros)
else
@@ -783,7 +809,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# the kernels below parallelize across rows or cols, not elements, so it's unlikely
# we'll launch many threads. to maximize utilization, parallelize across blocks first.
- rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
+ rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
# `size(bc, ::Int)` is missing
# for AbstractGPUSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
# but the only rows present in any sparse vector input are between 2 and 128, no need to
@@ -798,7 +824,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# either we have dense inputs, or the function isn't preserving zeros,
# so use a dense output to broadcast into.
val_array = nonzeros(sparse_arg)
- output = similar(val_array, Tv, size(bc))
+ output = similar(val_array, Tv, size(bc))
# since we'll be iterating the sparse inputs, we need to pre-fill the dense output
# with appropriate values (while setting the sparse inputs to zero). we do this by
# re-using the dense broadcast implementation.
@@ -816,24 +842,24 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
# this avoids a kernel launch and costly synchronization.
if sparse_typ <: AbstractGPUSparseMatrixCSR
offsets = rowPtr = sparse_arg.rowPtr
- colVal = similar(sparse_arg.colVal)
- nzVal = similar(sparse_arg.nzVal, Tv)
- output = sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
+ colVal = similar(sparse_arg.colVal)
+ nzVal = similar(sparse_arg.nzVal, Tv)
+ output = sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
offsets = colPtr = sparse_arg.colPtr
- rowVal = similar(sparse_arg.rowVal)
- nzVal = similar(sparse_arg.nzVal, Tv)
- output = sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
+ rowVal = similar(sparse_arg.rowVal)
+ nzVal = similar(sparse_arg.nzVal, Tv)
+ output = sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
end
else
# determine the number of non-zero elements per row so that we can create an
# appropriately-structured output container
offsets = if sparse_typ <: AbstractGPUSparseMatrixCSR
ptr_array = sparse_arg.rowPtr
- similar(ptr_array, Ti, rows+1)
+ similar(ptr_array, Ti, rows + 1)
elseif sparse_typ <: AbstractGPUSparseMatrixCSC
ptr_array = sparse_arg.colPtr
- similar(ptr_array, Ti, cols+1)
+ similar(ptr_array, Ti, cols + 1)
elseif sparse_typ <: AbstractGPUSparseVector
ptr_array = sparse_arg.iPtr
@allowscalar begin
@@ -847,7 +873,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
end
end
overall_first_row = min(arg_first_rows...)
- overall_last_row = max(arg_last_rows...)
+ overall_last_row = max(arg_last_rows...)
similar(ptr_array, Pair{Ti, NTuple{length(bc.args), Ti}}, overall_last_row - overall_first_row + 1)
end
let
@@ -857,7 +883,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
(sparse_typ, offsets, bc.args...)
end
kernel = compute_offsets_kernel(get_backend(bc.args[first(sparse_args)]))
- kernel(args...; ndrange=length(offsets))
+ kernel(args...; ndrange = length(offsets))
end
# accumulate these values so that we can use them directly as row pointer offsets,
# as well a...*[Comment body truncated]* |
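A large part of the diff above defines `CSRIterator`/`CSCIterator`, which walk the stored column (or row) indices of several sparse inputs in lockstep, visiting every index that is nonzero in at least one input. A self-contained CPU sketch of that merge, simplified to two plain index vectors (`row_union_cols` is an illustrative name, not PR API):

```julia
# For one CSR row: merge the sorted stored-column lists of two inputs,
# emitting each column that appears in at least one of them.
function row_union_cols(colsA::Vector{Int}, colsB::Vector{Int})
    out = Int[]
    ia, ib = 1, 1
    while ia <= length(colsA) || ib <= length(colsB)
        ca = ia <= length(colsA) ? colsA[ia] : typemax(Int)
        cb = ib <= length(colsB) ? colsB[ib] : typemax(Int)
        c = min(ca, cb)
        push!(out, c)
        ia += (ca == c)   # advance only the inputs that had this column
        ib += (cb == c)
    end
    return out
end

row_union_cols([1, 3, 7], [3, 4])  # == [1, 3, 4, 7]
```

The real iterators additionally carry the `nzVal` pointers along so the kernels can fetch each input's value (or an implicit zero) at the merged index.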
src/device/sparse.jl
Outdated
struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
    iPtr::Vi
    nzVal::Vv
    len::Int
    nnz::Ti
end
It's a bit inconsistent that we keep the host sparse object layout to the back-end, but define the device one concretely. I'm not sure if it's better to entirely move the definitions away from (or rather into) GPUArrays.jl though. I guess back-ends may want additional control over the object layout in order to facilitate vendor library interactions, but maybe we should then also leave the device-side version up to the back-end and only implement things here in terms of SparseArrays interfaces (rowvals, getcolptr, etc). Thoughts?
Yeah I was torn on this too. The advantage here is that libraries get a working device-side implementation "for free" -- they are able to implement their own (better) one and just give Adapt.jl information about how to move their host-side structs to it.
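For reference, the SparseArrays accessor interface mentioned above already lets generic code stay agnostic of the concrete storage layout. A quick CPU demonstration on a stdlib `SparseMatrixCSC` (the values here are arbitrary):

```julia
using SparseArrays

# Build a small CSC matrix: (1,1)=10, (2,1)=20, (1,2)=30.
A = sparse([1, 2, 1], [1, 1, 2], [10.0, 20.0, 30.0], 2, 2)

# Generic kernels can go through these accessors instead of struct fields,
# leaving the layout of the backing arrays up to each back-end:
SparseArrays.getcolptr(A)  # == [1, 3, 4]: per-column offsets into storage
rowvals(A)                 # == [1, 2, 1]: row index of each stored entry
nonzeros(A)                # == [10.0, 20.0, 30.0]: stored values
nnz(A)                     # == 3
```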
Now with added support for |
Made some small updates based on tests "in the wild" with CUDA |
Now may be a good time to start and document the exact interface methods that back-ends should implement. Could you add a section to the README on the added SparseArray methods?
## Interface methods
...
### Sparse Array support (optional)
...
V = S2.nzVal

# To make it compatible with the SparseArrays.jl version
idxs = sortperm(J)
We don't have sorting here yet, so I take it this doesn't work everywhere?
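To illustrate the step under discussion: `sortperm(J)` computes a permutation that puts COO triplets into the column-major order the SparseArrays.jl constructors expect. A CPU sketch with made-up triplet buffers (`I`, `J`, `V` here are hypothetical row/column/value arrays, not the PR's actual variables):

```julia
# Hypothetical COO buffers: rows, columns, values of three stored entries.
I = [2, 1, 1]; J = [2, 1, 2]; V = [4.0, 1.0, 3.0]

idxs = sortperm(J)            # permutation that orders the triplets by column
I, J, V = I[idxs], J[idxs], V[idxs]
# Now J == [1, 2, 2]: nondecreasing columns, as SparseArrays.jl expects.
```

On a GPU array this requires a working `sortperm`, which is presumably the portability concern raised in the comment above.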
Co-authored-by: Tim Besard <[email protected]>
Co-authored-by: Tim Besard <[email protected]>
This ports the (duplicated) sparse broadcasting support from CUDA.jl and AMDGPU.jl to GPUArrays.jl. It should allow all the "child" GPU libraries to use one unified set of on-device sparse types and broadcasting kernels. I implemented the appropriate types in `JLArrays`, and the tests there are passing. If this merges, we should be able to strip out most of the sparse broadcasting code from downstream packages.
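One detail worth calling out from the ported broadcast code: before launching a kernel, it checks whether the scalar-captured function maps all-zero inputs to zero, and only keeps a sparse output in that case. A minimal CPU sketch of that check (`preserves_zeros` is an illustrative name, not the PR's internal `_zeros_eltypes`/`_iszero` helpers):

```julia
# If f(0, 0, ...) == 0, the broadcast result is zero wherever all sparse
# inputs are zero, so the sparsity pattern can be reused; otherwise the
# ported kernels fall back to broadcasting into a dense output.
preserves_zeros(f, Ts::Type...) = iszero(f(map(zero, Ts)...))

preserves_zeros(*, Float64, Float64)       # true:  0.0 * 0.0 == 0.0
preserves_zeros(x -> x + 1.0, Float64)     # false: 0.0 + 1.0 == 1.0
```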
JLArraysand tests there are passing. If this merges, we should be able to strip out most of the sparse broadcasting code from downstream packages.