Conversation

@kshyatt
Member

@kshyatt kshyatt commented Oct 8, 2025

This ports the (duplicated) sparse broadcasting support from CUDA.jl and AMDGPU.jl to GPUArrays.jl. It should allow all the "child" GPU libraries to use one unified set of on-device sparse types and broadcasting kernels. I implemented the appropriate types in JLArrays, and the tests there are passing. If this merges, we should be able to strip out most of the sparse broadcasting code from downstream packages.
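For a sense of how downstream code would use the unified types, here is a minimal sketch based on the JLArrays reference implementation in the diff below. Whether the sparse types are exported and the exact broadcast return types are assumptions, not something this PR guarantees:

```julia
using SparseArrays, JLArrays

A  = sprand(Float32, 10, 10, 0.3)       # host SparseMatrixCSC
dA = JLArrays.JLSparseMatrixCSC(A)      # CSC wrapper type added in this PR
dB = 2 .* dA                            # should dispatch to the shared GPUArrays broadcast kernels
@assert SparseMatrixCSC(dB) ≈ 2 .* A    # convert back and compare with the host result
```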

@kshyatt kshyatt requested review from amontoison and maleadt October 8, 2025 13:56
@kshyatt kshyatt force-pushed the ksh/sparse branch 2 times, most recently from 844f20c to 9a74b4d on October 8, 2025 14:07
@kshyatt
Member Author

kshyatt commented Oct 8, 2025

I had to @allowscalar around the accumulate! and sort! calls since KA doesn't have device-agnostic implementations of these, IIRC. They could be added as part of this PR if people prefer.
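For context, the workaround looks roughly like this. It is a sketch, not the exact PR code, and the buffer names are hypothetical:

```julia
using GPUArrays: @allowscalar

# Fall back to scalar indexing for the prefix sum over offsets and the sort of
# output indices, until device-agnostic accumulate!/sort! kernels are available.
@allowscalar begin
    accumulate!(+, offsets, counts)   # hypothetical offset/count buffers
    sort!(rows)                       # hypothetical index buffer
end
```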

@maleadt
Member

maleadt commented Oct 8, 2025

> I had to @allowscalar around the accumulate! and sort! calls since KA doesn't have device-agnostic implementations of these, IIRC

You mean GPUArrays.jl itself? I wouldn't expect those to be defined in KA.jl (maybe AK.jl, but with different signatures).

@kshyatt kshyatt marked this pull request as ready for review October 8, 2025 18:03
@kshyatt
Member Author

kshyatt commented Oct 8, 2025

All tests are now uncommented, and SparseVector and SparseMatrixCSC are tested as well.
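Roughly the shape of those tests, as a sketch rather than the literal test code (the choice of zero-preserving functions and the broadcast return types are assumptions):

```julia
using Test, SparseArrays, JLArrays

x  = sprand(Float32, 100, 0.2)
dx = JLArrays.JLSparseVector(x)
@test SparseVector(2 .* dx) ≈ 2 .* x         # sparse vector broadcast

A  = sprand(Float32, 20, 20, 0.3)
dA = JLArrays.JLSparseMatrixCSC(A)
@test SparseMatrixCSC(abs.(dA)) ≈ abs.(A)    # sparse CSC broadcast
```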

@github-actions
Contributor

github-actions bot commented Oct 8, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/lib/JLArrays/src/JLArrays.jl b/lib/JLArrays/src/JLArrays.jl
index 097f015..4ffe1d4 100644
--- a/lib/JLArrays/src/JLArrays.jl
+++ b/lib/JLArrays/src/JLArrays.jl
@@ -24,7 +24,7 @@ end
 
 module AS
 
-const Generic  = 0
+    const Generic = 0
 
 end
 
@@ -129,28 +129,32 @@ mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, T
     len::Int
     nnz::Ti
 
-    function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
-                                    len::Integer) where {Tv, Ti <: Integer}
-        new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
+    function JLSparseVector{Tv, Ti}(
+            iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
+            len::Integer
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
     end
 end
-SparseArrays.nnz(x::JLSparseVector)          = x.nnz 
-SparseArrays.nonzeroinds(x::JLSparseVector)  = x.iPtr
-SparseArrays.nonzeros(x::JLSparseVector)     = x.nzVal
+SparseArrays.nnz(x::JLSparseVector) = x.nnz
+SparseArrays.nonzeroinds(x::JLSparseVector) = x.iPtr
+SparseArrays.nonzeros(x::JLSparseVector) = x.nzVal
 
 mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
     colPtr::JLArray{Ti, 1}
     rowVal::JLArray{Ti, 1}
     nzVal::JLArray{Tv, 1}
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 
-    function JLSparseMatrixCSC{Tv, Ti}(colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
-                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
-        new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
+    function JLSparseMatrixCSC{Tv, Ti}(
+            colPtr::JLArray{<:Integer, 1}, rowVal::JLArray{<:Integer, 1},
+            nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
     end
 end
-function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
     return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
 end
 SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
@@ -160,28 +164,30 @@ JLSparseMatrixCSC(A::JLSparseMatrixCSC) = A
 function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
     @boundscheck checkbounds(A, i, j)
     r1 = Int(@inbounds A.colPtr[j])
-    r2 = Int(@inbounds A.colPtr[j+1]-1)
+    r2 = Int(@inbounds A.colPtr[j + 1] - 1)
     (r1 > r2) && return zero(Tv)
     r1 = searchsortedfirst(view(A.rowVal, r1:r2), i) + r1 - 1
-    ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
+    return ((r1 > r2) || (A.rowVal[r1] != i)) ? zero(Tv) : A.nzVal[r1]
 end
 
 mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
     rowPtr::JLArray{Ti, 1}
     colVal::JLArray{Ti, 1}
     nzVal::JLArray{Tv, 1}
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 
-    function JLSparseMatrixCSR{Tv, Ti}(rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
-                                       nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti<:Integer}
-        new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
+    function JLSparseMatrixCSR{Tv, Ti}(
+            rowPtr::JLArray{<:Integer, 1}, colVal::JLArray{<:Integer, 1},
+            nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}
+        ) where {Tv, Ti <: Integer}
+        return new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
     end
 end
-function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer}
+function JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1}, dims::NTuple{2, <:Integer}) where {Tv, Ti <: Integer}
     return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
 end
-function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR) 
+function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
     x_transpose = SparseMatrixCSC(size(x, 2), size(x, 1), Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
     return SparseMatrixCSC(transpose(x_transpose))
 end
@@ -199,22 +205,26 @@ function Base.size(g::JLSparseMatrixCSR, d::Integer)
 end
 
 JLSparseMatrixCSR(Mat::Transpose{Tv, <:SparseMatrixCSC}) where {Tv} =
-    JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
-                         JLVector(parent(Mat).nzval), size(Mat))
+    JLSparseMatrixCSR(
+    JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
+    JLVector(parent(Mat).nzval), size(Mat)
+)
 JLSparseMatrixCSR(Mat::Adjoint{Tv, <:SparseMatrixCSC}) where {Tv} =
-    JLSparseMatrixCSR(JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
-                         JLVector(conj.(parent(Mat).nzval)), size(Mat))
+    JLSparseMatrixCSR(
+    JLVector{Cint}(parent(Mat).colptr), JLVector{Cint}(parent(Mat).rowval),
+    JLVector(conj.(parent(Mat).nzval)), size(Mat)
+)
 
 JLSparseMatrixCSR(A::JLSparseMatrixCSR) = A
 
 function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) where {Tv, Ti}
     @boundscheck checkbounds(A, i0, i1)
     c1 = Int(A.rowPtr[i0])
-    c2 = Int(A.rowPtr[i0+1]-1)
+    c2 = Int(A.rowPtr[i0 + 1] - 1)
     (c1 > c2) && return zero(Tv)
     c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
     (c1 > c2 || A.colVal[c1] != i1) && return zero(Tv)
-    nonzeros(A)[c1]
+    return nonzeros(A)[c1]
 end
 
 GPUArrays.storage(a::JLArray) = a.data
@@ -230,12 +240,12 @@ GPUArrays.sparse_array_type(::Type{<:JLSparseMatrixCSR}) = JLSparseMatrixCSR
 GPUArrays.sparse_array_type(sa::JLSparseVector) = JLSparseVector
 GPUArrays.sparse_array_type(::Type{<:JLSparseVector}) = JLSparseVector
 
-GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray 
-GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray 
-GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray 
-GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray 
-GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray 
-GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray 
+GPUArrays.dense_array_type(sa::JLSparseVector) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseVector}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSC) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSC}) = JLArray
+GPUArrays.dense_array_type(sa::JLSparseMatrixCSR) = JLArray
+GPUArrays.dense_array_type(::Type{<:JLSparseMatrixCSR}) = JLArray
 
 GPUArrays.csc_type(sa::JLSparseMatrixCSR) = JLSparseMatrixCSC
 GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
@@ -243,19 +253,19 @@ GPUArrays.csr_type(sa::JLSparseMatrixCSC) = JLSparseMatrixCSR
 Base.similar(Mat::JLSparseMatrixCSR) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat)), size(Mat))
 Base.similar(Mat::JLSparseMatrixCSR, T::Type) = JLSparseMatrixCSR(copy(Mat.rowPtr), copy(Mat.colVal), similar(nonzeros(Mat), T), size(Mat))
 
-Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) =  JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
-Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) =  JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
+Base.similar(Mat::JLSparseMatrixCSC, T::Type, N::Int, M::Int) = JLSparseMatrixCSC(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
+Base.similar(Mat::JLSparseMatrixCSR, T::Type, N::Int, M::Int) = JLSparseMatrixCSR(JLVector([zero(Int32)]), JLVector{Int32}(undef, 0), JLVector{T}(undef, 0), (N, M))
 
-Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) 
-Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M) 
+Base.similar(Mat::JLSparseMatrixCSC{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
+Base.similar(Mat::JLSparseMatrixCSR{Tv, Ti}, N::Int, M::Int) where {Tv, Ti} = similar(Mat, Tv, N, M)
 
-Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) 
-Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...) 
+Base.similar(Mat::JLSparseMatrixCSC, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
+Base.similar(Mat::JLSparseMatrixCSR, T::Type, dims::Tuple{Int, Int}) = similar(Mat, T, dims...)
 
-Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...) 
-Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...) 
+Base.similar(Mat::JLSparseMatrixCSC, dims::Tuple{Int, Int}) = similar(Mat, dims...)
+Base.similar(Mat::JLSparseMatrixCSR, dims::Tuple{Int, Int}) = similar(Mat, dims...)
 
-JLArray(x::JLSparseVector)    = JLArray(collect(SparseVector(x)))
+JLArray(x::JLSparseVector) = JLArray(collect(SparseVector(x)))
 JLArray(x::JLSparseMatrixCSC) = JLArray(collect(SparseMatrixCSC(x)))
 JLArray(x::JLSparseMatrixCSR) = JLArray(collect(SparseMatrixCSC(x)))
 
@@ -360,11 +370,11 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
 JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
 
 function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
-    iPtr  = JLVector{Ti}(undef, length(xs.nzind))
+    iPtr = JLVector{Ti}(undef, length(xs.nzind))
     nzVal = JLVector{Tv}(undef, length(xs.nzval))
     copyto!(iPtr, convert(Vector{Ti}, xs.nzind))
     copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
-    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs),)
+    return JLSparseVector{Tv, Ti}(iPtr, nzVal, length(xs))
 end
 Base.length(x::JLSparseVector) = x.len
 Base.size(x::JLSparseVector) = (x.len,)
@@ -372,13 +382,13 @@ Base.size(x::JLSparseVector) = (x.len,)
 function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
     colPtr = JLVector{Ti}(undef, length(xs.colptr))
     rowVal = JLVector{Ti}(undef, length(xs.rowval))
-    nzVal  = JLVector{Tv}(undef, length(xs.nzval))
+    nzVal = JLVector{Tv}(undef, length(xs.nzval))
     copyto!(colPtr, convert(Vector{Ti}, xs.colptr))
     copyto!(rowVal, convert(Vector{Ti}, xs.rowval))
-    copyto!(nzVal,  convert(Vector{Tv}, xs.nzval))
+    copyto!(nzVal, convert(Vector{Tv}, xs.nzval))
     return JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, (xs.m, xs.n))
 end
-JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs)) 
+JLSparseMatrixCSC(xs::SparseVector) = JLSparseMatrixCSC(SparseMatrixCSC(xs))
 Base.length(x::JLSparseMatrixCSC) = prod(x.dims)
 Base.size(x::JLSparseMatrixCSC) = x.dims
 
@@ -386,10 +396,10 @@ function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
     csr_xs = SparseMatrixCSC(transpose(xs))
     rowPtr = JLVector{Ti}(undef, length(csr_xs.colptr))
     colVal = JLVector{Ti}(undef, length(csr_xs.rowval))
-    nzVal  = JLVector{Tv}(undef, length(csr_xs.nzval))
+    nzVal = JLVector{Tv}(undef, length(csr_xs.nzval))
     copyto!(rowPtr, convert(Vector{Ti}, csr_xs.colptr))
     copyto!(colVal, convert(Vector{Ti}, csr_xs.rowval))
-    copyto!(nzVal,  convert(Vector{Tv}, csr_xs.nzval))
+    copyto!(nzVal, convert(Vector{Tv}, csr_xs.nzval))
     return JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, (xs.m, xs.n))
 end
 JLSparseMatrixCSR(xs::SparseVector{Tv, Ti}) where {Ti, Tv} = JLSparseMatrixCSR(SparseMatrixCSC(xs))
@@ -410,26 +420,26 @@ function Base.copyto!(dst::JLSparseMatrixCSR, src::JLSparseMatrixCSR)
     copyto!(dst.colVal, src.colVal)
     copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
     dst.nnz = src.nnz
-    dst
+    return dst
 end
 Base.length(x::JLSparseMatrixCSR) = prod(x.dims)
 Base.size(x::JLSparseMatrixCSR) = x.dims
 
 function GPUArrays._spadjoint(A::JLSparseMatrixCSR)
     Aᴴ = JLSparseMatrixCSC(A.rowPtr, A.colVal, conj(A.nzVal), reverse(size(A)))
-    JLSparseMatrixCSR(Aᴴ)
+    return JLSparseMatrixCSR(Aᴴ)
 end
 function GPUArrays._sptranspose(A::JLSparseMatrixCSR)
     Aᵀ = JLSparseMatrixCSC(A.rowPtr, A.colVal, A.nzVal, reverse(size(A)))
-    JLSparseMatrixCSR(Aᵀ)
+    return JLSparseMatrixCSR(Aᵀ)
 end
 function _spadjoint(A::JLSparseMatrixCSC)
     Aᴴ = JLSparseMatrixCSR(A.colPtr, A.rowVal, conj(A.nzVal), reverse(size(A)))
-    JLSparseMatrixCSC(Aᴴ)
+    return JLSparseMatrixCSC(Aᴴ)
 end
 function _sptranspose(A::JLSparseMatrixCSC)
     Aᵀ = JLSparseMatrixCSR(A.colPtr, A.rowVal, A.nzVal, reverse(size(A)))
-    JLSparseMatrixCSC(Aᵀ)
+    return JLSparseMatrixCSC(Aᵀ)
 end
 
 # idempotency
@@ -573,17 +583,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
     R
 end
 
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
-Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti} =
-GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceMatrixCSC{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.colPtr), adapt(to, x.rowVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceMatrixCSR{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.rowPtr), adapt(to, x.colVal), adapt(to, x.nzVal), x.dims, x.nnz)
+Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv, Ti}) where {Tv, Ti} =
+    GPUSparseDeviceVector{Tv, Ti, JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}, AS.Generic}(adapt(to, x.iPtr), adapt(to, x.nzVal), x.len, x.nnz)
 
 ## KernelAbstractions interface
 
 KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
-KernelAbstractions.get_backend(a::JLA) where JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector} = JLBackend()
+KernelAbstractions.get_backend(a::JLA) where {JLA <: Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector}} = JLBackend()
 
 function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
     return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
diff --git a/src/device/sparse.jl b/src/device/sparse.jl
index b8346ea..22b4752 100644
--- a/src/device/sparse.jl
+++ b/src/device/sparse.jl
@@ -10,12 +10,12 @@ using SparseArrays
 # core types
 
 export GPUSparseDeviceVector, GPUSparseDeviceMatrixCSC, GPUSparseDeviceMatrixCSR,
-       GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
+    GPUSparseDeviceMatrixBSR, GPUSparseDeviceMatrixCOO
 
 abstract type AbstractGPUSparseDeviceMatrix{Tv, Ti} <: AbstractSparseMatrix{Tv, Ti} end
 
 
-struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv, A} <: AbstractSparseVector{Tv,Ti}
+struct GPUSparseDeviceVector{Tv, Ti, Vi, Vv, A} <: AbstractSparseVector{Tv, Ti}
     iPtr::Vi
     nzVal::Vv
     len::Int
@@ -28,11 +28,11 @@ SparseArrays.nnz(g::GPUSparseDeviceVector) = g.nnz
 SparseArrays.nonzeroinds(g::GPUSparseDeviceVector) = g.iPtr
 SparseArrays.nonzeros(g::GPUSparseDeviceVector) = g.nzVal
 
-struct GPUSparseDeviceMatrixCSC{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
+struct GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     colPtr::Vi
     rowVal::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 end
 
@@ -41,7 +41,7 @@ Base.size(g::GPUSparseDeviceMatrixCSC) = g.dims
 SparseArrays.nnz(g::GPUSparseDeviceMatrixCSC) = g.nnz
 SparseArrays.rowvals(g::GPUSparseDeviceMatrixCSC) = g.rowVal
 SparseArrays.getcolptr(g::GPUSparseDeviceMatrixCSC) = g.colPtr
-SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = @inbounds SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col+1]-1)
+SparseArrays.nzrange(g::GPUSparseDeviceMatrixCSC, col::Integer) = @inbounds SparseArrays.getcolptr(g)[col]:(SparseArrays.getcolptr(g)[col + 1] - 1)
 SparseArrays.nonzeros(g::GPUSparseDeviceMatrixCSC) = g.nzVal
 
 const GPUSparseDeviceColumnView{Tv, Ti, Vi, Vv, A} = SubArray{Tv, 1, GPUSparseDeviceMatrixCSC{Tv, Ti, Vi, Vv, A}, Tuple{Base.Slice{Base.OneTo{Int}}, Int}}
@@ -66,7 +66,7 @@ function SparseArrays.nnz(x::GPUSparseDeviceColumnView)
     return length(SparseArrays.nzrange(A, colidx))
 end
 
-struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCSR{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowPtr::Vi
     colVal::Vi
     nzVal::Vv
@@ -74,9 +74,13 @@ struct GPUSparseDeviceMatrixCSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{
     nnz::Ti
 end
 
-@inline function _getindex(arg::Union{GPUSparseDeviceMatrixCSR{Tv},
-                                      GPUSparseDeviceMatrixCSC{Tv},
-                                      GPUSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
+@inline function _getindex(
+        arg::Union{
+            GPUSparseDeviceMatrixCSR{Tv},
+            GPUSparseDeviceMatrixCSC{Tv},
+            GPUSparseDeviceVector{Tv},
+        }, I, ptr
+    )::Tv where {Tv}
     if ptr == 0
         return zero(Tv)
     else
@@ -84,21 +88,21 @@ end
     end
 end
 
-struct GPUSparseDeviceMatrixBSR{Tv,Ti,Vi,Vv,A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixBSR{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowPtr::Vi
     colVal::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     blockDim::Ti
     dir::Char
     nnz::Ti
 end
 
-struct GPUSparseDeviceMatrixCOO{Tv,Ti,Vi,Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv,Ti}
+struct GPUSparseDeviceMatrixCOO{Tv, Ti, Vi, Vv, A} <: AbstractGPUSparseDeviceMatrix{Tv, Ti}
     rowInd::Vi
     colInd::Vi
     nzVal::Vv
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 end
 
@@ -115,9 +119,9 @@ struct GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M, A} <: AbstractSparseArray{T
     nnz::Ti
 end
 
-function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N,<:Integer}) where {Tv, Ti<:Integer, M, Vi<:AbstractDeviceArray{<:Integer,M}, Vv<:AbstractDeviceArray{Tv, M}, N}
+function GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N}(rowPtr::Vi, colVal::Vi, nzVal::Vv, dims::NTuple{N, <:Integer}) where {Tv, Ti <: Integer, M, Vi <: AbstractDeviceArray{<:Integer, M}, Vv <: AbstractDeviceArray{Tv, M}, N}
     @assert M == N - 1 "GPUSparseDeviceArrayCSR requires ndims(rowPtr) == ndims(colVal) == ndims(nzVal) == length(dims) - 1"
-    GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
+    return GPUSparseDeviceArrayCSR{Tv, Ti, Vi, Vv, N, M}(rowPtr, colVal, nzVal, dims, length(nzVal))
 end
 
 Base.length(g::GPUSparseDeviceArrayCSR) = prod(g.dims)
@@ -130,42 +134,42 @@ SparseArrays.getnzval(g::GPUSparseDeviceArrayCSR) = g.nzVal
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceVector)
     println(io, "$(length(A))-element device sparse vector at:")
     println(io, "  iPtr: $(A.iPtr)")
-    print(io,   "  nzVal: $(A.nzVal)")
+    return print(io, "  nzVal: $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSR)
     println(io, "$(length(A))-element device sparse matrix CSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCSC)
     println(io, "$(length(A))-element device sparse matrix CSC at:")
     println(io, "  colPtr: $(A.colPtr)")
     println(io, "  rowVal: $(A.rowVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixBSR)
     println(io, "$(length(A))-element device sparse matrix BSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceMatrixCOO)
     println(io, "$(length(A))-element device sparse matrix COO at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colInd: $(A.colInd)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 function Base.show(io::IO, ::MIME"text/plain", A::GPUSparseDeviceArrayCSR)
     println(io, "$(length(A))-element device sparse array CSR at:")
     println(io, "  rowPtr: $(A.rowPtr)")
     println(io, "  colVal: $(A.colVal)")
-    print(io,   "  nzVal:  $(A.nzVal)")
+    return print(io, "  nzVal:  $(A.nzVal)")
 end
 
 # COV_EXCL_STOP
diff --git a/src/host/sparse.jl b/src/host/sparse.jl
index 1bee2ce..80e816d 100644
--- a/src/host/sparse.jl
+++ b/src/host/sparse.jl
@@ -10,13 +10,13 @@ abstract type AbstractGPUSparseMatrixCSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, T
 abstract type AbstractGPUSparseMatrixCOO{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
 abstract type AbstractGPUSparseMatrixBSR{Tv, Ti} <: AbstractGPUSparseArray{Tv, Ti, 2} end
 
-const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector,AbstractGPUSparseMatrix}
+const AbstractGPUSparseVecOrMat = Union{AbstractGPUSparseVector, AbstractGPUSparseMatrix}
 
-SparseArrays.nnz(g::T) where {T<:AbstractGPUSparseArray} = g.nnz
-SparseArrays.nonzeros(g::T) where {T<:AbstractGPUSparseArray} = g.nzVal
+SparseArrays.nnz(g::T) where {T <: AbstractGPUSparseArray} = g.nnz
+SparseArrays.nonzeros(g::T) where {T <: AbstractGPUSparseArray} = g.nzVal
 
-SparseArrays.nonzeroinds(g::T) where {T<:AbstractGPUSparseVector} = g.iPtr
-SparseArrays.rowvals(g::T) where {T<:AbstractGPUSparseVector} = SparseArrays.nonzeroinds(g)
+SparseArrays.nonzeroinds(g::T) where {T <: AbstractGPUSparseVector} = g.iPtr
+SparseArrays.rowvals(g::T) where {T <: AbstractGPUSparseVector} = SparseArrays.nonzeroinds(g)
 
 SparseArrays.rowvals(g::AbstractGPUSparseMatrixCSC) = g.rowVal
 SparseArrays.getcolptr(S::AbstractGPUSparseMatrixCSC) = S.colPtr
@@ -30,7 +30,7 @@ Base.collect(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
 Base.collect(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
 Base.collect(x::AbstractGPUSparseMatrixCOO) = collect(SparseMatrixCSC(x))
 
-Base.Array(x::AbstractGPUSparseVector)    = collect(SparseVector(x))
+Base.Array(x::AbstractGPUSparseVector) = collect(SparseVector(x))
 Base.Array(x::AbstractGPUSparseMatrixCSC) = collect(SparseMatrixCSC(x))
 Base.Array(x::AbstractGPUSparseMatrixCSR) = collect(SparseMatrixCSC(x))
 Base.Array(x::AbstractGPUSparseMatrixBSR) = collect(SparseMatrixCSC(x))
@@ -40,44 +40,44 @@ SparseArrays.SparseVector(x::AbstractGPUSparseVector) = SparseVector(length(x),
 SparseArrays.SparseMatrixCSC(x::AbstractGPUSparseMatrixCSC) = SparseMatrixCSC(size(x)..., Array(SparseArrays.getcolptr(x)), Array(SparseArrays.rowvals(x)), Array(SparseArrays.nonzeros(x)))
 
 # similar
-Base.similar(Vec::V) where {V<:AbstractGPUSparseVector} = V(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec)), length(Vec))
-Base.similar(Mat::M) where {M<:AbstractGPUSparseMatrixCSC} = M(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat)), size(Mat))
+Base.similar(Vec::V) where {V <: AbstractGPUSparseVector} = V(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec)), length(Vec))
+Base.similar(Mat::M) where {M <: AbstractGPUSparseMatrixCSC} = M(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat)), size(Mat))
 
-Base.similar(Vec::V, T::Type) where {Tv, Ti, V<:AbstractGPUSparseVector{Tv, Ti}} = sparse_array_type(V){T, Ti}(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec), T), length(Vec))
-Base.similar(Mat::M, T::Type) where {M<:AbstractGPUSparseMatrixCSC} = sparse_array_type(M)(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat), T), size(Mat))
+Base.similar(Vec::V, T::Type) where {Tv, Ti, V <: AbstractGPUSparseVector{Tv, Ti}} = sparse_array_type(V){T, Ti}(copy(SparseArrays.nonzeroinds(Vec)), similar(SparseArrays.nonzeros(Vec), T), length(Vec))
+Base.similar(Mat::M, T::Type) where {M <: AbstractGPUSparseMatrixCSC} = sparse_array_type(M)(copy(SparseArrays.getcolptr(Mat)), copy(SparseArrays.rowvals(Mat)), similar(SparseArrays.nonzeros(Mat), T), size(Mat))
 
-dense_array_type(sa::SparseVector)     = SparseVector
+dense_array_type(sa::SparseVector) = SparseVector
 dense_array_type(::Type{SparseVector}) = SparseVector
 sparse_array_type(sa::SparseVector) = SparseVector
 dense_vector_type(sa::AbstractSparseArray) = Vector
-dense_vector_type(sa::AbstractArray)       = Vector
+dense_vector_type(sa::AbstractArray) = Vector
 dense_vector_type(::Type{<:AbstractSparseArray}) = Vector
-dense_vector_type(::Type{<:AbstractArray})       = Vector
-dense_array_type(sa::SparseMatrixCSC)     = SparseMatrixCSC
+dense_vector_type(::Type{<:AbstractArray}) = Vector
+dense_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
 dense_array_type(::Type{SparseMatrixCSC}) = SparseMatrixCSC
-sparse_array_type(sa::SparseMatrixCSC)    = SparseMatrixCSC
+sparse_array_type(sa::SparseMatrixCSC) = SparseMatrixCSC
 
 function sparse_array_type(sa::AbstractGPUSparseArray) end
 function dense_array_type(sa::AbstractGPUSparseArray) end
 function coo_type(sa::AbstractGPUSparseArray) end
-coo_type(::SA) where {SA<:AbstractGPUSparseMatrixCSC} = SA
+coo_type(::SA) where {SA <: AbstractGPUSparseMatrixCSC} = SA
 
 function _spadjoint end
 function _sptranspose end
 
-function LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSR, p::Real=2)
+function LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSR, p::Real = 2)
     if p == Inf
-        return maximum(sum(abs, A; dims=2))
+        return maximum(sum(abs, A; dims = 2))
     elseif p == 1
-        return maximum(sum(abs, A; dims=1))
+        return maximum(sum(abs, A; dims = 1))
     else
         throw(ArgumentError("p=$p is not supported"))
     end
 end
 
-LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSC, p::Real=2) = opnorm(_csr_type(A)(A), p)
+LinearAlgebra.opnorm(A::AbstractGPUSparseMatrixCSC, p::Real = 2) = opnorm(_csr_type(A)(A), p)
 
-function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real=2) where T
+function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real = 2) where {T}
     if p == Inf
         return maximum(abs.(SparseArrays.nonzeros(A)))
     elseif p == -Inf
@@ -85,7 +85,7 @@ function LinearAlgebra.norm(A::AbstractGPUSparseMatrix{T}, p::Real=2) where T
     elseif p == 0
         return Float64(SparseArrays.nnz(A))
     else
-        return sum(abs.(SparseArrays.nonzeros(A)).^p)^(1/p)
+        return sum(abs.(SparseArrays.nonzeros(A)) .^ p)^(1 / p)
     end
 end
 
@@ -105,22 +105,24 @@ function SparseArrays.findnz(S::MT) where {MT <: AbstractGPUSparseMatrix}
 end
 
 ### WRAPPED ARRAYS
-LinearAlgebra.issymmetric(M::Union{AbstractGPUSparseMatrixCSC,AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false
-LinearAlgebra.ishermitian(M::Union{AbstractGPUSparseMatrixCSC,AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false
+LinearAlgebra.issymmetric(M::Union{AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - transpose(M), Inf) == 0 : false
+LinearAlgebra.ishermitian(M::Union{AbstractGPUSparseMatrixCSC, AbstractGPUSparseMatrixCSR}) = size(M, 1) == size(M, 2) ? norm(M - adjoint(M), Inf) == 0 : false
 
-LinearAlgebra.istriu(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
-LinearAlgebra.istril(M::UpperTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
-LinearAlgebra.istriu(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
-LinearAlgebra.istril(M::LowerTriangular{T,S}) where {T<:BlasFloat, S<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
+LinearAlgebra.istriu(M::UpperTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
+LinearAlgebra.istril(M::UpperTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
+LinearAlgebra.istriu(M::LowerTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = false
+LinearAlgebra.istril(M::LowerTriangular{T, S}) where {T <: BlasFloat, S <: Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}} = true
 
-Hermitian{T}(Mat::AbstractGPUSparseMatrix{T}) where {T} = Hermitian{eltype(Mat),typeof(Mat)}(Mat,'U')
+Hermitian{T}(Mat::AbstractGPUSparseMatrix{T}) where {T} = Hermitian{eltype(Mat), typeof(Mat)}(Mat, 'U')
 
 # work around upstream breakage from JuliaLang/julia#55547
 @static if VERSION >= v"1.11.2"
     const GPUSparseUpperOrUnitUpperTriangular = LinearAlgebra.UpperOrUnitUpperTriangular{
-        <:Any,<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}}
+        <:Any, <:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}},
+    }
     const GPUSparseLowerOrUnitLowerTriangular = LinearAlgebra.LowerOrUnitLowerTriangular{
-        <:Any,<:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}}}
+        <:Any, <:Union{<:AbstractGPUSparseMatrix, Adjoint{<:Any, <:AbstractGPUSparseMatrix}, Transpose{<:Any, <:AbstractGPUSparseMatrix}},
+    }
     LinearAlgebra.istriu(::GPUSparseUpperOrUnitUpperTriangular) = true
     LinearAlgebra.istril(::GPUSparseUpperOrUnitUpperTriangular) = false
     LinearAlgebra.istriu(::GPUSparseLowerOrUnitLowerTriangular) = false
@@ -129,67 +131,67 @@ end
 
 for SparseMatrixType in [:AbstractGPUSparseMatrixCSC, :AbstractGPUSparseMatrixCSR]
     @eval begin
-        LinearAlgebra.triu(A::ST, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( triu(coo_type(A)(A), k) )
-        LinearAlgebra.triu(A::Transpose{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( triu(coo_type(A)(_sptranspose(parent(A))), k) )
-        LinearAlgebra.triu(A::Adjoint{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( triu(coo_type(A)(_spadjoint(parent(A))), k) )
-
-        LinearAlgebra.tril(A::ST, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( tril(coo_type(A)(A), k) )
-        LinearAlgebra.tril(A::Transpose{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( tril(coo_type(A)(_sptranspose(parent(A))), k) )
-        LinearAlgebra.tril(A::Adjoint{T,<:ST}, k::Integer) where {T, ST<:$SparseMatrixType{T}} =
-            ST( tril(coo_type(A)(_spadjoint(parent(A))), k) )
-
-        LinearAlgebra.triu(A::Union{ST, Transpose{T,<:ST}, Adjoint{T,<:ST}}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( triu(coo_type(A)(A), 0) )
-        LinearAlgebra.tril(A::Union{ST,Transpose{T,<:ST}, Adjoint{T,<:ST}}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( tril(coo_type(A)(A), 0) )
-
-        LinearAlgebra.kron(A::ST, B::ST) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(A), coo_type(B)(B)) )
-        LinearAlgebra.kron(A::ST, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(A), B) )
-        LinearAlgebra.kron(A::Diagonal, B::ST) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(A, coo_type(B)(B)) )
-
-        LinearAlgebra.kron(A::Transpose{T,<:ST}, B::ST) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(_sptranspose(parent(A))), coo_type(B)(B)) )
-        LinearAlgebra.kron(A::ST, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(A), coo_type(parent(B))(_sptranspose(parent(B)))) )
-        LinearAlgebra.kron(A::Transpose{T,<:ST}, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(_sptranspose(parent(A))), coo_type(parent(B))(_sptranspose(parent(B)))) )
-        LinearAlgebra.kron(A::Transpose{T,<:ST}, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(_sptranspose(parent(A))), B) )
-        LinearAlgebra.kron(A::Diagonal, B::Transpose{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(A, coo_type(B)(_sptranspose(parent(B)))) )
-
-        LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::ST) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(_spadjoint(parent(A))), coo_type(B)(B)) )
-        LinearAlgebra.kron(A::ST, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(A)(A), coo_type(parent(B))(_spadjoint(parent(B)))) )
-        LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(parent(A))(_spadjoint(parent(A))), coo_type(parent(B))(_spadjoint(parent(B)))) )
-        LinearAlgebra.kron(A::Adjoint{T,<:ST}, B::Diagonal) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(coo_type(parent(A))(_spadjoint(parent(A))), B) )
-        LinearAlgebra.kron(A::Diagonal, B::Adjoint{T,<:ST}) where {T, ST<:$SparseMatrixType{T}} =
-            ST( kron(A, coo_type(parent(B))(_spadjoint(parent(B)))) )
-
-
-        function Base.reshape(A::ST, dims::Dims) where {ST<:$SparseMatrixType}
+        LinearAlgebra.triu(A::ST, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(triu(coo_type(A)(A), k))
+        LinearAlgebra.triu(A::Transpose{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(triu(coo_type(A)(_sptranspose(parent(A))), k))
+        LinearAlgebra.triu(A::Adjoint{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(triu(coo_type(A)(_spadjoint(parent(A))), k))
+
+        LinearAlgebra.tril(A::ST, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(tril(coo_type(A)(A), k))
+        LinearAlgebra.tril(A::Transpose{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(tril(coo_type(A)(_sptranspose(parent(A))), k))
+        LinearAlgebra.tril(A::Adjoint{T, <:ST}, k::Integer) where {T, ST <: $SparseMatrixType{T}} =
+            ST(tril(coo_type(A)(_spadjoint(parent(A))), k))
+
+        LinearAlgebra.triu(A::Union{ST, Transpose{T, <:ST}, Adjoint{T, <:ST}}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(triu(coo_type(A)(A), 0))
+        LinearAlgebra.tril(A::Union{ST, Transpose{T, <:ST}, Adjoint{T, <:ST}}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(tril(coo_type(A)(A), 0))
+
+        LinearAlgebra.kron(A::ST, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(A), coo_type(B)(B)))
+        LinearAlgebra.kron(A::ST, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(A), B))
+        LinearAlgebra.kron(A::Diagonal, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(A, coo_type(B)(B)))
+
+        LinearAlgebra.kron(A::Transpose{T, <:ST}, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(_sptranspose(parent(A))), coo_type(B)(B)))
+        LinearAlgebra.kron(A::ST, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(A), coo_type(parent(B))(_sptranspose(parent(B)))))
+        LinearAlgebra.kron(A::Transpose{T, <:ST}, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(_sptranspose(parent(A))), coo_type(parent(B))(_sptranspose(parent(B)))))
+        LinearAlgebra.kron(A::Transpose{T, <:ST}, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(_sptranspose(parent(A))), B))
+        LinearAlgebra.kron(A::Diagonal, B::Transpose{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(A, coo_type(B)(_sptranspose(parent(B)))))
+
+        LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::ST) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(_spadjoint(parent(A))), coo_type(B)(B)))
+        LinearAlgebra.kron(A::ST, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(A)(A), coo_type(parent(B))(_spadjoint(parent(B)))))
+        LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(parent(A))(_spadjoint(parent(A))), coo_type(parent(B))(_spadjoint(parent(B)))))
+        LinearAlgebra.kron(A::Adjoint{T, <:ST}, B::Diagonal) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(coo_type(parent(A))(_spadjoint(parent(A))), B))
+        LinearAlgebra.kron(A::Diagonal, B::Adjoint{T, <:ST}) where {T, ST <: $SparseMatrixType{T}} =
+            ST(kron(A, coo_type(parent(B))(_spadjoint(parent(B)))))
+
+
+        function Base.reshape(A::ST, dims::Dims) where {ST <: $SparseMatrixType}
             B = coo_type(A)(A)
-            ST(reshape(B, dims))
+            return ST(reshape(B, dims))
         end
 
-        function SparseArrays.droptol!(A::ST, tol::Real) where {ST<:$SparseMatrixType}
+        function SparseArrays.droptol!(A::ST, tol::Real) where {ST <: $SparseMatrixType}
             B = coo_type(A)(A)
             droptol!(B, tol)
-            copyto!(A, ST(B))
+            return copyto!(A, ST(B))
         end
 
-        function LinearAlgebra.exp(A::ST; threshold = 1e-7, nonzero_tol = 1e-14) where {ST<:$SparseMatrixType}
+        function LinearAlgebra.exp(A::ST; threshold = 1.0e-7, nonzero_tol = 1.0e-14) where {ST <: $SparseMatrixType}
             rows = LinearAlgebra.checksquare(A) # Throws exception if not square
             typeA = eltype(A)
 
@@ -211,40 +213,40 @@ for SparseMatrixType in [:AbstractGPUSparseMatrixCSC, :AbstractGPUSparseMatrixCS
                 copyto!(P, P + next_term)
                 n = n + 1
             end
-            for n = 1:log2(scaling_factor)
-                P = P * P;
+            for n in 1:log2(scaling_factor)
+                P = P * P
                 if nnz(P) / length(P) < 0.25
                     droptol!(P, nonzero_tol)
                 end
             end
-            P
+            return P
         end
     end
 end
 
 
 ### INDEXING
-Base.getindex(A::AbstractGPUSparseVector, ::Colon)          = copy(A)
+Base.getindex(A::AbstractGPUSparseVector, ::Colon) = copy(A)
 Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, ::Colon) = copy(A)
-Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon)       = getindex(A, i, 1:size(A, 2))
-Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, i)       = getindex(A, 1:size(A, 1), i)
-Base.getindex(A::AbstractGPUSparseMatrix, I::Tuple{Integer,Integer}) = getindex(A, I[1], I[2])
+Base.getindex(A::AbstractGPUSparseMatrix, i, ::Colon) = getindex(A, i, 1:size(A, 2))
+Base.getindex(A::AbstractGPUSparseMatrix, ::Colon, i) = getindex(A, 1:size(A, 1), i)
+Base.getindex(A::AbstractGPUSparseMatrix, I::Tuple{Integer, Integer}) = getindex(A, I[1], I[2])
 
 function Base.getindex(A::AbstractGPUSparseVector{Tv, Ti}, i::Integer) where {Tv, Ti}
     @boundscheck checkbounds(A, i)
     ii = searchsortedfirst(SparseArrays.nonzeroinds(A), convert(Ti, i))
     (ii > SparseArrays.nnz(A) || SparseArrays.nonzeroinds(A)[ii] != i) && return zero(Tv)
-    SparseArrays.nonzeros(A)[ii]
+    return SparseArrays.nonzeros(A)[ii]
 end
 
-function Base.getindex(A::AbstractGPUSparseMatrixCSC{T}, i0::Integer, i1::Integer) where T
+function Base.getindex(A::AbstractGPUSparseMatrixCSC{T}, i0::Integer, i1::Integer) where {T}
     @boundscheck checkbounds(A, i0, i1)
     r1 = Int(SparseArrays.getcolptr(A)[i1])
-    r2 = Int(SparseArrays.getcolptr(A)[i1+1]-1)
+    r2 = Int(SparseArrays.getcolptr(A)[i1 + 1] - 1)
     (r1 > r2) && return zero(T)
     r1 = searchsortedfirst(SparseArrays.rowvals(A), i0, r1, r2, Base.Order.Forward)
     (r1 > r2 || SparseArrays.rowvals(A)[r1] != i0) && return zero(T)
-    SparseArrays.nonzeros(A)[r1]
+    return SparseArrays.nonzeros(A)[r1]
 end
 
 ## copying between sparse GPU arrays
@@ -259,7 +261,7 @@ function Base.copyto!(dst::AbstractGPUSparseVector, src::AbstractGPUSparseVector
     copyto!(SparseArrays.nonzeroinds(dst), SparseArrays.nonzeroinds(src))
     copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
     dst.nnz = src.nnz
-    dst
+    return dst
 end
 
 function Base.copyto!(dst::AbstractGPUSparseMatrixCSC, src::AbstractGPUSparseMatrixCSC)
@@ -273,7 +275,7 @@ function Base.copyto!(dst::AbstractGPUSparseMatrixCSC, src::AbstractGPUSparseMat
     copyto!(SparseArrays.rowvals(dst), SparseArrays.rowvals(src))
     copyto!(SparseArrays.nonzeros(dst), SparseArrays.nonzeros(src))
     dst.nnz = src.nnz
-    dst
+    return dst
 end
 
 ### BROADCAST
@@ -283,7 +285,7 @@ struct GPUSparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
 struct GPUSparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
 Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseVector}) = GPUSparseVecStyle()
 Broadcast.BroadcastStyle(::Type{<:AbstractGPUSparseMatrix}) = GPUSparseMatStyle()
-const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
+const SPVM = Union{GPUSparseVecStyle, GPUSparseMatStyle}
 
 # GPUSparseVecStyle handles 0-1 dimensions, GPUSparseMatStyle 0-2 dimensions.
 # GPUSparseVecStyle promotes to GPUSparseMatStyle for 2 dimensions.
@@ -291,11 +293,11 @@ const SPVM = Union{GPUSparseVecStyle,GPUSparseMatStyle}
 GPUSparseVecStyle(::Val{0}) = GPUSparseVecStyle()
 GPUSparseVecStyle(::Val{1}) = GPUSparseVecStyle()
 GPUSparseVecStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseVecStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
 GPUSparseMatStyle(::Val{0}) = GPUSparseMatStyle()
 GPUSparseMatStyle(::Val{1}) = GPUSparseMatStyle()
 GPUSparseMatStyle(::Val{2}) = GPUSparseMatStyle()
-GPUSparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+GPUSparseMatStyle(::Val{N}) where {N} = Broadcast.DefaultArrayStyle{N}()
 
 Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{1}) = GPUSparseVecStyle()
 Broadcast.BroadcastStyle(::GPUSparseVecStyle, ::AbstractGPUArrayStyle{2}) = GPUSparseMatStyle()
@@ -322,37 +324,37 @@ end
 
 # Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
 @inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+    capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
 @inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
     # This definition is identical to the one above and necessary only for
     # avoiding method ambiguity.
-    capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+    capturescalars((args...) -> f(T, args...), Base.tail(mixedargs))
 @inline capturescalars(f, mixedargs::Tuple{AbstractGPUSparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
-@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
-    capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
+    capturescalars((a1, args...) -> f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
+@inline capturescalars(f, mixedargs::Tuple{Union{Ref, AbstractArray{<:Any, 0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
+    capturescalars((args...) -> f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
 
 scalararg(::Number) = true
 scalararg(::Any) = false
-scalarwrappedarg(::Union{AbstractArray{<:Any,0},Ref}) = true
+scalarwrappedarg(::Union{AbstractArray{<:Any, 0}, Ref}) = true
 scalarwrappedarg(::Any) = false
 
 @inline function _capturescalars()
     return (), () -> ()
 end
 @inline function _capturescalars(arg, mixedargs...)
-    let (rest, f) = _capturescalars(mixedargs...)
+    return let (rest, f) = _capturescalars(mixedargs...)
         if scalararg(arg)
-            return rest, @inline function(tail...)
-                (arg, f(tail...)...)
+            return rest, @inline function (tail...)
+                    return (arg, f(tail...)...)
             end # add back scalararg after (in makeargs)
         elseif scalarwrappedarg(arg)
-            return rest, @inline function(tail...)
-                (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
+            return rest, @inline function (tail...)
+                    return (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
             end # unwrap and add back scalararg after (in makeargs)
         else
-            return (arg, rest...), @inline function(head, tail...)
-                (head, f(tail...)...)
+            return (arg, rest...), @inline function (head, tail...)
+                    return (head, f(tail...)...)
             end # pass-through to broadcast
         end
     end
@@ -388,13 +390,13 @@ access the elements themselves.
 For convenience, this iterator can be passed non-sparse arguments as well, which will be
 ignored (with the returned `col`/`ptr` values set to 0).
 """
-struct CSRIterator{Ti,N,ATs}
+struct CSRIterator{Ti, N, ATs}
     row::Ti
     col_ends::NTuple{N, Ti}
     args::ATs
 end
 
-function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
+function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti, N}
     # check that `row` is valid for all arguments
     @boundscheck begin
         ntuple(Val(N)) do i
@@ -406,16 +408,16 @@ function CSRIterator{Ti}(row, args::Vararg{Any, N}) where {Ti,N}
     col_ends = ntuple(Val(N)) do i
         arg = @inbounds args[i]
         if arg isa GPUSparseDeviceMatrixCSR
-            @inbounds(arg.rowPtr[row+1])
+            @inbounds(arg.rowPtr[row + 1])
         else
             zero(Ti)
         end
     end
 
-    CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
+    return CSRIterator{Ti, N, typeof(args)}(row, col_ends, args)
 end
 
-@inline function Base.iterate(iter::CSRIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSRIterator{Ti, N}, state = nothing) where {Ti, N}
     # helper function to get the column of a sparse array at a specific pointer
     @inline function get_col(i, ptr)
         arg = @inbounds iter.args[i]
@@ -425,13 +427,14 @@ end
                 return @inbounds arg.colVal[ptr] % Ti
             end
         end
-        typemax(Ti)
+        return typemax(Ti)
     end
 
     # initialize the state
     # - ptr: the current index into the colVal/nzVal arrays
     # - col: the current column index (cached so that we don't have to re-read each time)
-    state = something(state,
+    state = something(
+        state,
         ntuple(Val(N)) do i
             arg = @inbounds iter.args[i]
             if arg isa GPUSparseDeviceMatrixCSR
@@ -473,13 +476,13 @@ end
     return (cur_col, ptrs), new_state
 end
 
-struct CSCIterator{Ti,N,ATs}
+struct CSCIterator{Ti, N, ATs}
     col::Ti
     row_ends::NTuple{N, Ti}
     args::ATs
 end
 
-function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
+function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti, N}
     # check that `col` is valid for all arguments
     @boundscheck begin
         ntuple(Val(N)) do i
@@ -491,17 +494,17 @@ function CSCIterator{Ti}(col, args::Vararg{Any, N}) where {Ti,N}
     row_ends = ntuple(Val(N)) do i
         arg = @inbounds args[i]
         x = if arg isa GPUSparseDeviceMatrixCSC
-            @inbounds(arg.colPtr[col+1])
+            @inbounds(arg.colPtr[col + 1])
         else
             zero(Ti)
         end
         x
     end
 
-    CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
+    return CSCIterator{Ti, N, typeof(args)}(col, row_ends, args)
 end
 
-@inline function Base.iterate(iter::CSCIterator{Ti,N}, state=nothing) where {Ti,N}
+@inline function Base.iterate(iter::CSCIterator{Ti, N}, state = nothing) where {Ti, N}
     # helper function to get the column of a sparse array at a specific pointer
     @inline function get_col(i, ptr)
         arg = @inbounds iter.args[i]
@@ -511,13 +514,14 @@ end
                 return @inbounds arg.rowVal[ptr] % Ti
             end
         end
-        typemax(Ti)
+        return typemax(Ti)
     end
 
     # initialize the state
     # - ptr: the current index into the rowVal/nzVal arrays
     # - row: the current row index (cached so that we don't have to re-read each time)
-    state = something(state,
+    state = something(
+        state,
         ntuple(Val(N)) do i
             arg = @inbounds iter.args[i]
             if arg isa GPUSparseDeviceMatrixCSC
@@ -560,8 +564,8 @@ end
 end
 
 # helpers to index a sparse or dense array
-function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR,GPUSparseDeviceMatrixCSC}, I, ptr)
-    if ptr == 0
+function _getindex(arg::Union{<:GPUSparseDeviceMatrixCSR, GPUSparseDeviceMatrixCSC}, I, ptr)
+    return if ptr == 0
         zero(eltype(arg))
     else
         @inbounds arg.nzVal[ptr]
@@ -589,9 +593,11 @@ function _has_row(A::GPUSparseDeviceVector, offsets, row, ::Bool)
     return 0
 end
 
-@kernel function compute_offsets_kernel(::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
-                                        fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                        args...) where {Ti, N}
+@kernel function compute_offsets_kernel(
+        ::Type{<:AbstractGPUSparseVector}, first_row::Ti, last_row::Ti,
+        fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Ti, N}
     my_ix = @index(Global, Linear)
     row = my_ix + first_row - one(eltype(my_ix))
     if row ≤ last_row
@@ -610,11 +616,13 @@ end
 end
 
 # kernel to count the number of non-zeros in a row, to determine the row offsets
-@kernel function compute_offsets_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
-                                        offsets::AbstractVector{Ti}, args...) where Ti
+@kernel function compute_offsets_kernel(
+        T::Type{<:Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC}},
+        offsets::AbstractVector{Ti}, args...
+    ) where {Ti}
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
-    if leading_dim ≤ length(offsets)-1 
+    if leading_dim ≤ length(offsets) - 1
         iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
 
         # count the nonzero leading_dims of all inputs
@@ -629,19 +637,21 @@ end
             if leading_dim == 1
                 offsets[1] = 1
             end
-            offsets[leading_dim+1] = accum
+            offsets[leading_dim + 1] = accum
         end
     end
 end
 
-@kernel function sparse_to_sparse_broadcast_kernel(f::F, output::GPUSparseDeviceVector{Tv,Ti},
-                                                   offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                                   args...) where {Tv, Ti, N, F}
+@kernel function sparse_to_sparse_broadcast_kernel(
+        f::F, output::GPUSparseDeviceVector{Tv, Ti},
+        offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Tv, Ti, N, F}
     row_ix = @index(Global, Linear)
     if row_ix ≤ output.nnz
         row_and_ptrs = @inbounds offsets[row_ix]
-        row          = @inbounds row_and_ptrs[1]
-        arg_ptrs     = @inbounds row_and_ptrs[2]
+        row = @inbounds row_and_ptrs[1]
+        arg_ptrs = @inbounds row_and_ptrs[2]
         vals = ntuple(Val(N)) do i
             @inline
             arg = @inbounds args[i]
@@ -651,14 +661,20 @@ end
             _getindex(arg, row, ptr)
         end
         output_val = f(vals...)
-        @inbounds output.iPtr[row_ix]  = row
+        @inbounds output.iPtr[row_ix] = row
         @inbounds output.nzVal[row_ix] = output_val
     end
 end
 
-@kernel function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{<:AbstractArray,Nothing},
-                                                   args...) where {Ti, T<:Union{GPUSparseDeviceMatrixCSR{<:Any,Ti},
-                                                                                GPUSparseDeviceMatrixCSC{<:Any,Ti}}}
+@kernel function sparse_to_sparse_broadcast_kernel(
+        f, output::T, offsets::Union{<:AbstractArray, Nothing},
+        args...
+    ) where {
+        Ti, T <: Union{
+            GPUSparseDeviceMatrixCSR{<:Any, Ti},
+            GPUSparseDeviceMatrixCSC{<:Any, Ti},
+        },
+    }
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
     leading_dim_size = output isa GPUSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -666,19 +682,19 @@ end
         iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
 
 
-        output_ptrs  = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
+        output_ptrs = output isa GPUSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
         output_ivals = output isa GPUSparseDeviceMatrixCSR ? output.colVal : output.rowVal
         # fetch the row offset, and write it to the output
         @inbounds begin
             output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
             if leading_dim == leading_dim_size
-                output_ptrs[leading_dim+one(eltype(leading_dim))] = offsets[leading_dim+one(eltype(leading_dim))]
+                output_ptrs[leading_dim + one(eltype(leading_dim))] = offsets[leading_dim + one(eltype(leading_dim))]
             end
         end
 
         # set the values for this row
         for (sub_leading_dim, ptrs) in iter
-            index_first  = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
+            index_first = output isa GPUSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
             index_second = output isa GPUSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
             I = CartesianIndex(index_first, index_second)
             vals = ntuple(Val(length(args))) do i
@@ -693,9 +709,15 @@ end
         end
     end
 end
-@kernel function sparse_to_dense_broadcast_kernel(T::Type{<:Union{AbstractGPUSparseMatrixCSR{Tv, Ti},
-                                                                  AbstractGPUSparseMatrixCSC{Tv, Ti}}},
-                                                  f, output::AbstractArray, args...) where {Tv, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+        T::Type{
+            <:Union{
+                AbstractGPUSparseMatrixCSR{Tv, Ti},
+                AbstractGPUSparseMatrixCSC{Tv, Ti},
+            },
+        },
+        f, output::AbstractArray, args...
+    ) where {Tv, Ti}
     # every thread processes an entire row
     leading_dim = @index(Global, Linear)
     leading_dim_size = T <: AbstractGPUSparseMatrixCSR ? size(output, 1) : size(output, 2)
@@ -704,7 +726,7 @@ end
 
         # set the values for this row
         for (sub_leading_dim, ptrs) in iter
-            index_first  = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
+            index_first = T <: AbstractGPUSparseMatrixCSR ? leading_dim : sub_leading_dim
             index_second = T <: AbstractGPUSparseMatrixCSR ? sub_leading_dim : leading_dim
             I = CartesianIndex(index_first, index_second)
             vals = ntuple(Val(length(args))) do i
@@ -718,16 +740,18 @@ end
     end
 end
 
-@kernel function sparse_to_dense_broadcast_kernel(::Type{<:AbstractGPUSparseVector}, f::F,
-                                                  output::AbstractArray{Tv},
-                                                  offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
-                                                  args...) where {Tv, F, N, Ti}
+@kernel function sparse_to_dense_broadcast_kernel(
+        ::Type{<:AbstractGPUSparseVector}, f::F,
+        output::AbstractArray{Tv},
+        offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
+        args...
+    ) where {Tv, F, N, Ti}
     # every thread processes an entire row
     row_ix = @index(Global, Linear)
     if row_ix ≤ length(output)
         row_and_ptrs = @inbounds offsets[row_ix]
-        row          = @inbounds row_and_ptrs[1]
-        arg_ptrs     = @inbounds row_and_ptrs[2]
+        row = @inbounds row_and_ptrs[1]
+        arg_ptrs = @inbounds row_and_ptrs[2]
         vals = ntuple(Val(length(args))) do i
             @inline
             arg = @inbounds args[i]
@@ -742,39 +766,41 @@ end
 end
 ## COV_EXCL_STOP
 
-function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatStyle}})
+function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle, GPUSparseMatStyle}})
     # find the sparse inputs
     bc = Broadcast.flatten(bc)
     sparse_args = findall(bc.args) do arg
         arg isa AbstractGPUSparseArray
     end
-    sparse_types = unique(map(i->nameof(typeof(bc.args[i])), sparse_args))
+    sparse_types = unique(map(i -> nameof(typeof(bc.args[i])), sparse_args))
     if length(sparse_types) > 1
         error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
     end
     sparse_typ = typeof(bc.args[first(sparse_args)])
-    sparse_typ <: Union{AbstractGPUSparseMatrixCSR,AbstractGPUSparseMatrixCSC,AbstractGPUSparseVector} ||
+    sparse_typ <: Union{AbstractGPUSparseMatrixCSR, AbstractGPUSparseMatrixCSC, AbstractGPUSparseVector} ||
         error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
     Ti = if sparse_typ <: AbstractGPUSparseMatrixCSR
-        reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].rowPtr), sparse_args))
     elseif sparse_typ <: AbstractGPUSparseMatrixCSC
-        reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].colPtr), sparse_args))
     elseif sparse_typ <: AbstractGPUSparseVector
-        reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
+        reduce(promote_type, map(i -> eltype(bc.args[i].iPtr), sparse_args))
     end
 
     # determine the output type
     Tv = Broadcast.combine_eltypes(bc.f, eltype.(bc.args))
     if !Base.isconcretetype(Tv)
-        error("""GPU sparse broadcast resulted in non-concrete element type $Tv.
-                 This probably means that the function you are broadcasting contains an error or type instability.""")
+        error(
+            """GPU sparse broadcast resulted in non-concrete element type $Tv.
+            This probably means that the function you are broadcasting contains an error or type instability."""
+        )
     end
 
     # partially-evaluate the function, removing scalars.
     parevalf, passedsrcargstup = capturescalars(bc.f, bc.args)
     # check if the partially-evaluated function preserves zeros. if so, we'll only need to
     # apply it to the sparse input arguments, preserving the sparse structure.
-    if all(arg->isa(arg, AbstractSparseArray), passedsrcargstup)
+    if all(arg -> isa(arg, AbstractSparseArray), passedsrcargstup)
         fofzeros = parevalf(_zeros_eltypes(passedsrcargstup...)...)
         fpreszeros = _iszero(fofzeros)
     else
@@ -783,7 +809,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
 
     # the kernels below parallelize across rows or cols, not elements, so it's unlikely
     # we'll launch many threads. to maximize utilization, parallelize across blocks first.
-    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) 
+    rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1)
     # `size(bc, ::Int)` is missing
     # for AbstractGPUSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
     # but the only rows present in any sparse vector input are between 2 and 128, no need to
@@ -798,7 +824,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
         # either we have dense inputs, or the function isn't preserving zeros,
         # so use a dense output to broadcast into.
         val_array = nonzeros(sparse_arg)
-        output    = similar(val_array, Tv, size(bc))
+        output = similar(val_array, Tv, size(bc))
         # since we'll be iterating the sparse inputs, we need to pre-fill the dense output
         # with appropriate values (while setting the sparse inputs to zero). we do this by
         # re-using the dense broadcast implementation.
@@ -816,24 +842,24 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
         # this avoids a kernel launch and costly synchronization.
         if sparse_typ <: AbstractGPUSparseMatrixCSR
             offsets = rowPtr = sparse_arg.rowPtr
-            colVal  = similar(sparse_arg.colVal)
-            nzVal   = similar(sparse_arg.nzVal, Tv)
-            output  = sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
+            colVal = similar(sparse_arg.colVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = sparse_array_type(sparse_typ)(rowPtr, colVal, nzVal, size(bc))
         elseif sparse_typ <: AbstractGPUSparseMatrixCSC
             offsets = colPtr = sparse_arg.colPtr
-            rowVal  = similar(sparse_arg.rowVal)
-            nzVal   = similar(sparse_arg.nzVal, Tv)
-            output  = sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
+            rowVal = similar(sparse_arg.rowVal)
+            nzVal = similar(sparse_arg.nzVal, Tv)
+            output = sparse_array_type(sparse_typ)(colPtr, rowVal, nzVal, size(bc))
         end
     else
         # determine the number of non-zero elements per row so that we can create an
         # appropriately-structured output container
         offsets = if sparse_typ <: AbstractGPUSparseMatrixCSR
             ptr_array = sparse_arg.rowPtr
-            similar(ptr_array, Ti, rows+1)
+            similar(ptr_array, Ti, rows + 1)
         elseif sparse_typ <: AbstractGPUSparseMatrixCSC
             ptr_array = sparse_arg.colPtr
-            similar(ptr_array, Ti, cols+1)
+            similar(ptr_array, Ti, cols + 1)
         elseif sparse_typ <: AbstractGPUSparseVector
             ptr_array = sparse_arg.iPtr
             @allowscalar begin
@@ -847,7 +873,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
                 end
             end
             overall_first_row = min(arg_first_rows...)
-            overall_last_row  = max(arg_last_rows...)
+            overall_last_row = max(arg_last_rows...)
             similar(ptr_array, Pair{Ti, NTuple{length(bc.args), Ti}}, overall_last_row - overall_first_row + 1)
         end
         let
@@ -857,7 +883,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{GPUSparseVecStyle,GPUSparseMatSt
                 (sparse_typ, offsets, bc.args...)
             end
             kernel = compute_offsets_kernel(get_backend(bc.args[first(sparse_args)]))
-            kernel(args...; ndrange=length(offsets))
+            kernel(args...; ndrange = length(offsets))
         end
         # accumulate these values so that we can use them directly as row pointer offsets,
         # as well a...*[Comment body truncated]*

@kshyatt kshyatt changed the title from [WIP] Sparse GPU array and broadcasting support to Sparse GPU array and broadcasting support Oct 9, 2025
Comment on lines 18 to 23
struct GPUSparseDeviceVector{Tv,Ti,Vi,Vv} <: AbstractSparseVector{Tv,Ti}
    iPtr::Vi
    nzVal::Vv
    len::Int
    nnz::Ti
end
Copy link
Member

It's a bit inconsistent that we keep the host sparse object layout to the back-end, but define the device one concretely. I'm not sure if it's better to entirely move the definitions away from (or rather into) GPUArrays.jl though. I guess back-ends may want additional control over the object layout in order to facilitate vendor library interactions, but maybe we should then also leave the device-side version up to the back-end and only implement things here in terms of SparseArrays interfaces (rowvals, getcolptr, etc). Thoughts?
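
For illustration, a minimal sketch (not part of this PR) of what "implementing things in terms of SparseArrays interfaces" could look like; stored_value is a hypothetical helper, and the only assumption is a CSC-style device type that provides the three accessors:

using SparseArrays: rowvals, nonzeros
using SparseArrays: getcolptr   # unexported, but it is the accessor mentioned above

# Generic CSC lookup that never touches concrete fields (colPtr/rowVal/nzVal),
# so any back-end device type implementing the accessors would work unchanged.
@inline function stored_value(A, i::Integer, j::Integer)
    colptr, rv, nz = getcolptr(A), rowvals(A), nonzeros(A)
    @inbounds for ptr in colptr[j]:(colptr[j + 1] - one(eltype(colptr)))
        rv[ptr] == i && return nz[ptr]
    end
    return zero(eltype(nz))
end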

Copy link
Member Author

Yeah, I was torn on this too. The advantage here is that libraries get a working device-side implementation "for free" -- they can still implement their own (better) one, and only need to give Adapt.jl the information about how to move their host-side structs to it.
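
For reference, a rough sketch of the Adapt.jl rule a back-end would add (MySparseVector is a hypothetical host-side type; the field order follows the device struct quoted above, and calling the device type's constructor directly is an assumption):

using Adapt, GPUArrays

# Adapt the two array fields to their device representation and copy the scalars,
# producing the shared device-side type used by the broadcasting kernels.
function Adapt.adapt_structure(to, x::MySparseVector{Tv, Ti}) where {Tv, Ti}
    iPtr  = adapt(to, x.iPtr)
    nzVal = adapt(to, x.nzVal)
    return GPUArrays.GPUSparseDeviceVector{Tv, Ti, typeof(iPtr), typeof(nzVal)}(
        iPtr, nzVal, x.len, x.nnz,
    )
end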

@kshyatt
Copy link
Member Author

kshyatt commented Oct 18, 2025

Now with added support for mapreduce for sparse matrices!
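
A quick usage sketch for anyone following along (hedged: JLSparseMatrixCSC is assumed here to be constructible from a host SparseMatrixCSC, which may not match the final API):

using SparseArrays, JLArrays

A  = sprand(Float32, 100, 100, 0.1)
dA = JLSparseMatrixCSC(A)       # hypothetical conversion; the exact constructor may differ
mapreduce(abs2, +, dA)          # should agree with mapreduce(abs2, +, A) on the host
sum(dA)                         # should also work via the usual mapreduce fallback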

@kshyatt
Copy link
Member Author

kshyatt commented Oct 18, 2025

Made some small updates based on tests "in the wild" with CUDA

Copy link
Member

@maleadt maleadt left a comment

Now may be a good time to start documenting the exact interface methods that back-ends should implement. Could you add a section to the README on the added SparseArrays methods?

## Interface methods

...

### Sparse Array support (optional)

...
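
To make the ask a bit more concrete, the section could for instance show the accessor methods a back-end is expected to provide (a sketch with guessed names, not a spec):

# Hypothetical back-end CSC type MySparseMatrixCSC; field names are illustrative only.
SparseArrays.nnz(A::MySparseMatrixCSC)       = A.nnz
SparseArrays.nonzeros(A::MySparseMatrixCSC)  = A.nzVal
SparseArrays.rowvals(A::MySparseMatrixCSC)   = A.rowVal
SparseArrays.getcolptr(A::MySparseMatrixCSC) = A.colPtr
# ...plus an Adapt.adapt_structure rule mapping the host type to its device-side counterpart.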

V = S2.nzVal

# To make it compatible with the SparseArrays.jl version
idxs = sortperm(J)
Copy link
Member

We don't have sorting here yet, so I take it this doesn't work everywhere?
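
A purely illustrative stopgap until a device-side sort is available, assuming J and V fit on the host and that the array types can be rebuilt from host vectors:

# Host round-trip fallback: correct but slow; a proper device-side sort could replace it.
Jh, Vh = Array(J), Array(V)    # move the index/value buffers to the host
p = sortperm(Jh)
J = typeof(J)(Jh[p])           # re-upload in sorted order
V = typeof(V)(Vh[p])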
