fixes the kron implementation for sparse + diagonal matrix #2804


Open · wants to merge 11 commits into master

Conversation

@tam724 tam724 commented Jun 27, 2025

This generalizes the implementation of kron for the combination of a (CUDA) sparse matrix and a diagonal matrix.

Currently the Diagonal type is treated as I (UniformScaling) with a certain dimension (n×n). This is not the intended use of Diagonal (see the Julia docs), and the following code returns a wrong result when using CUDA:

julia> using CUDA, SparseArrays, LinearAlgebra
julia> A = sparse(ones(1, 1))
1×1 SparseMatrixCSC{Float64, Int64} with 1 stored entry:
 1.0
julia> B = Diagonal(rand(1))
1×1 Diagonal{Float64, Vector{Float64}}:
 0.4618063241112559
julia> kron(A, B)
1×1 SparseMatrixCSC{Float64, Int64} with 1 stored entry:
 0.4618063241112559
julia> kron(A |> cu, B |> cu)
1×1 CUDA.CUSPARSE.CuSparseMatrixCSC{Float32, Int32} with 1 stored entry:
 1.0

This implementation keeps the old behaviour: kron with I(3)::Diagonal{Bool, Vector{Bool}} is still interpreted as kron with the identity.
kron with a general Diagonal matrix is now fixed. I also added some tests.
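For reference, the intended semantics (a minimal CPU sketch, not the PR's CUDA kernel; kron_coo_diag is a hypothetical helper name): for B = Diagonal(d), every stored entry A[i, j] contributes C[(i-1)*nB+k, (j-1)*nB+k] = A[i, j] * d[k], whereas the old behaviour effectively used 1 in place of d[k].

using SparseArrays, LinearAlgebra

function kron_coo_diag(A::SparseMatrixCSC, B::Diagonal)
    mA, nA = size(A)
    nB = length(B.diag)
    Tv = promote_type(eltype(A), eltype(B))
    I_A, J_A, V_A = findnz(A)
    I_C, J_C, V_C = Int[], Int[], Tv[]
    for (i, j, v) in zip(I_A, J_A, V_A), k in 1:nB
        push!(I_C, (i - 1) * nB + k)   # row index within the Kronecker block
        push!(J_C, (j - 1) * nB + k)   # column index within the Kronecker block
        push!(V_C, v * B.diag[k])      # the buggy path effectively computed v * 1
    end
    return sparse(I_C, J_C, V_C, mA * nB, nA * nB)
end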


github-actions bot commented Jun 27, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/lib/cusparse/linalg.jl b/lib/cusparse/linalg.jl
index e72d1a084..73c6f9ca3 100644
--- a/lib/cusparse/linalg.jl
+++ b/lib/cusparse/linalg.jl
@@ -62,13 +62,13 @@ _kron_CuSparseMatrixCOO_components(At::Transpose{<:Number, <:CuSparseMatrixCOO})
 _kron_CuSparseMatrixCOO_components(Ah::Adjoint{<:Number, <:CuSparseMatrixCOO}) = parent(Ah).colInd, parent(Ah).rowInd, parent(Ah).nzVal, adjoint, Int(parent(Ah).nnz)
 
 function LinearAlgebra.kron(
-    A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
-    B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
+        A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
+        B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
     ) where {TvA, TiA, TvB, TiB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = promote_type(TiA, TiB)
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nzOp, A_nnz = _kron_CuSparseMatrixCOO_components(A)
     B_rowInd, B_colInd, B_nzVal, B_nzOp, B_nnz = _kron_CuSparseMatrixCOO_components(B)
@@ -87,13 +87,13 @@ function LinearAlgebra.kron(
 end
 
 function LinearAlgebra.kron(
-    A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
-    B::Diagonal{TvB, <:Union{CuVector{TvB}, Base.ReshapedArray{TvB, 1, <:Adjoint{TvB, <:CuVector{TvB}}}}}
+        A::Union{CuSparseMatrixCOO{TvA, TiA}, Transpose{TvA, <:CuSparseMatrixCOO{TvA, TiA}}, Adjoint{TvA, <:CuSparseMatrixCOO{TvA, TiA}}},
+        B::Diagonal{TvB, <:Union{CuVector{TvB}, Base.ReshapedArray{TvB, 1, <:Adjoint{TvB, <:CuVector{TvB}}}}}
     ) where {TvA, TiA, TvB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = TiA
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nzOp, A_nnz = _kron_CuSparseMatrixCOO_components(A)
     B_rowInd, B_colInd, B_nzVal, B_nnz = one(Ti):Ti(nB), one(Ti):Ti(nB), B.diag, Int(nB)
@@ -112,13 +112,13 @@ function LinearAlgebra.kron(
 end
 
 function LinearAlgebra.kron(
-    A::Diagonal{TvA, <:Union{CuVector{TvA}, Base.ReshapedArray{TvA, 1, <:Adjoint{TvA, <:CuVector{TvA}}}}},
-    B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
+        A::Diagonal{TvA, <:Union{CuVector{TvA}, Base.ReshapedArray{TvA, 1, <:Adjoint{TvA, <:CuVector{TvA}}}}},
+        B::Union{CuSparseMatrixCOO{TvB, TiB}, Transpose{TvB, <:CuSparseMatrixCOO{TvB, TiB}}, Adjoint{TvB, <:CuSparseMatrixCOO{TvB, TiB}}}
     ) where {TvA, TvB, TiB}
     mA, nA = size(A)
     mB, nB = size(B)
     Ti = TiB
-    Tv = typeof(oneunit(TvA)*oneunit(TvB))
+    Tv = typeof(oneunit(TvA) * oneunit(TvB))
 
     A_rowInd, A_colInd, A_nzVal, A_nnz = one(Ti):Ti(nA), one(Ti):Ti(nA), A.diag, Int(nA)
     B_rowInd, B_colInd, B_nzVal, B_nzOp, B_nnz = _kron_CuSparseMatrixCOO_components(B)
diff --git a/test/libraries/cusparse/linalg.jl b/test/libraries/cusparse/linalg.jl
index 30c529821..5e9948ab8 100644
--- a/test/libraries/cusparse/linalg.jl
+++ b/test/libraries/cusparse/linalg.jl
@@ -40,12 +40,12 @@ m = 10
             end
         end
         @testset "kronecker product with I opa = $opa" for opa in (identity, transpose, adjoint)
-            @test collect(kron(opa(dA), dC)) ≈ kron(opa(A), C) 
-            @test collect(kron(dC, opa(dA))) ≈ kron(C, opa(A)) 
+            @test collect(kron(opa(dA), dC)) ≈ kron(opa(A), C)
+            @test collect(kron(dC, opa(dA))) ≈ kron(C, opa(A))
             @test collect(kron(opa(dZA), dC)) ≈ kron(opa(ZA), C)
             @test collect(kron(dC, opa(dZA))) ≈ kron(C, opa(ZA))
         end
-        @testset "kronecker product with Diagonal opa = $opa" for opa in (identity, transpose, adjoint) 
+        @testset "kronecker product with Diagonal opa = $opa" for opa in (identity, transpose, adjoint)
             @test collect(kron(opa(dA), dD)) ≈ kron(opa(A), D)
             @test collect(kron(dD, opa(dA))) ≈ kron(D, opa(A))
             @test collect(kron(opa(dZA), dD)) ≈ kron(opa(ZA), D)
@@ -58,7 +58,7 @@ end
     mat_sizes = [(2, 3), (2, 0)]
     @testset "size(A) = ($(mA), $(nA)), size(B) = ($(mB), $(nB))" for (mA, nA) in mat_sizes, (mB, nB) in mat_sizes
         A = sprand(T, mA, nA, 0.5)
-        B  = sprand(T, mB, nB, 0.5)
+        B = sprand(T, mB, nB, 0.5)
 
         A_I, A_J, A_V = findnz(A)
         dA = CuSparseMatrixCOO{T, Cint}(adapt(CuVector{Cint}, A_I), adapt(CuVector{Cint}, A_J), adapt(CuVector{T}, A_V), size(A))
@@ -67,7 +67,7 @@ end
 
         @testset "kronecker (COO ⊗ COO) opa = $opa, opb = $opb" for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
             dC = kron(opa(dA), opb(dB))
-            @test collect(dC)  ≈ kron(opa(A), opb(B))
+            @test collect(dC) ≈ kron(opa(A), opb(B))
             @test eltype(dC) == typeof(oneunit(T) * oneunit(T))
             @test dC isa CuSparseMatrixCOO
         end
@@ -76,7 +76,7 @@ end
 
 @testset "TA = $TA, TvB = $TvB" for TvB in [Float32, Float64, ComplexF32, ComplexF64], TA in [Bool, TvB]
     A = Diagonal(rand(TA, 2))
-    B  = sprand(TvB, 3, 4, 0.5)
+    B = sprand(TvB, 3, 4, 0.5)
     dA = adapt(CuArray, A)
 
     B_I, B_J, B_V = findnz(B)
@@ -84,14 +84,14 @@ end
 
     @testset "kronecker (diagonal ⊗ COO) opa = $opa, opb = $opb" for opa in (identity, adjoint), opb in (identity, transpose, adjoint)
         dC = kron(opa(dA), opb(dB))
-        @test collect(dC)  ≈ kron(opa(A), opb(B))
+        @test collect(dC) ≈ kron(opa(A), opb(B))
         @test eltype(dC) == typeof(oneunit(TA) * oneunit(TvB))
         @test dC isa CuSparseMatrixCOO
     end
 
     @testset "kronecker (COO ⊗ diagonal) opa = $opa, opb = $opb" for opa in (identity, adjoint), opb in (identity, transpose, adjoint)
         dC = kron(opb(dB), opa(dA))
-        @test collect(dC)  ≈ kron(opb(B), opa(A))
+        @test collect(dC) ≈ kron(opb(B), opa(A))
         @test eltype(dC) == typeof(oneunit(TvB) * oneunit(TA))
         @test dC isa CuSparseMatrixCOO
     end


tam724 commented Jun 27, 2025

For reference: the behaviour of treating a Diagonal{Bool} as the identity matrix is mentioned here: JuliaGPU/GPUArrays.jl#585

However, this implementation of kron is not correct even for all instances of Diagonal{Bool}, because the diagonal could contain false elements:

julia> Diagonal([true, false, true])
3×3 Diagonal{Bool, Vector{Bool}}:
 1  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1
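A quick CPU check (values chosen purely for illustration) confirming that such a Diagonal{Bool} is not the identity under kron:

julia> using SparseArrays, LinearAlgebra

julia> A = sparse(ones(2, 2));

julia> kron(A, Diagonal([true, false])) == kron(A, I(2))
false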

github-actions bot left a comment

CUDA.jl Benchmarks

Benchmark suite (columns: current @ d675c5b, previous @ 205c238, ratio)
latency/precompile 43025806155.5 ns 42934926801 ns 1.00
latency/ttfp 7074219824 ns 7008552789 ns 1.01
latency/import 3594760805 ns 3569139582 ns 1.01
integration/volumerhs 9624374 ns 9606581 ns 1.00
integration/byval/slices=1 147102 ns 147311 ns 1.00
integration/byval/slices=3 425799 ns 426127 ns 1.00
integration/byval/reference 145010 ns 145282 ns 1.00
integration/byval/slices=2 286557 ns 286537 ns 1.00
integration/cudadevrt 103547 ns 103674 ns 1.00
kernel/indexing 14240 ns 14638.5 ns 0.97
kernel/indexing_checked 14935 ns 15045 ns 0.99
kernel/occupancy 693.5906040268457 ns 669.9465408805031 ns 1.04
kernel/launch 2187.5 ns 2202.4444444444443 ns 0.99
kernel/rand 14621 ns 17466 ns 0.84
array/reverse/1d 19991 ns 20143 ns 0.99
array/reverse/2d 24727 ns 24692 ns 1.00
array/reverse/1d_inplace 10935 ns 11332 ns 0.96
array/reverse/2d_inplace 13205 ns 13662 ns 0.97
array/copy 21010 ns 21281 ns 0.99
array/iteration/findall/int 158289 ns 159966.5 ns 0.99
array/iteration/findall/bool 139961 ns 141602 ns 0.99
array/iteration/findfirst/int 157177.5 ns 163419 ns 0.96
array/iteration/findfirst/bool 158056.5 ns 165377 ns 0.96
array/iteration/scalar 72457 ns 76152 ns 0.95
array/iteration/logical 215837 ns 219912.5 ns 0.98
array/iteration/findmin/1d 46345 ns 47580 ns 0.97
array/iteration/findmin/2d 96386.5 ns 97060 ns 0.99
array/reductions/reduce/Int64/1d 42288 ns 43742.5 ns 0.97
array/reductions/reduce/Int64/dims=1 47080 ns 47519.5 ns 0.99
array/reductions/reduce/Int64/dims=2 61959 ns 62503 ns 0.99
array/reductions/reduce/Int64/dims=1L 89219 ns 89134 ns 1.00
array/reductions/reduce/Int64/dims=2L 86912 ns 87634.5 ns 0.99
array/reductions/reduce/Float32/1d 34155.5 ns 35637 ns 0.96
array/reductions/reduce/Float32/dims=1 41602 ns 51967.5 ns 0.80
array/reductions/reduce/Float32/dims=2 59864 ns 59824 ns 1.00
array/reductions/reduce/Float32/dims=1L 52461 ns 52680 ns 1.00
array/reductions/reduce/Float32/dims=2L 70375 ns 70568 ns 1.00
array/reductions/mapreduce/Int64/1d 41681.5 ns 43514 ns 0.96
array/reductions/mapreduce/Int64/dims=1 47300.5 ns 46605.5 ns 1.01
array/reductions/mapreduce/Int64/dims=2 61915.5 ns 62143.5 ns 1.00
array/reductions/mapreduce/Int64/dims=1L 89234 ns 89174 ns 1.00
array/reductions/mapreduce/Int64/dims=2L 86652 ns 87305.5 ns 0.99
array/reductions/mapreduce/Float32/1d 34137 ns 35464 ns 0.96
array/reductions/mapreduce/Float32/dims=1 41874.5 ns 42505.5 ns 0.99
array/reductions/mapreduce/Float32/dims=2 60225 ns 60252 ns 1.00
array/reductions/mapreduce/Float32/dims=1L 52762 ns 52803 ns 1.00
array/reductions/mapreduce/Float32/dims=2L 70590 ns 70795 ns 1.00
array/broadcast 20257 ns 20737 ns 0.98
array/copyto!/gpu_to_gpu 12846 ns 13192 ns 0.97
array/copyto!/cpu_to_gpu 215252 ns 217123 ns 0.99
array/copyto!/gpu_to_cpu 283170 ns 287100 ns 0.99
array/accumulate/Int64/1d 124592 ns 126109 ns 0.99
array/accumulate/Int64/dims=1 83243 ns 84201 ns 0.99
array/accumulate/Int64/dims=2 157820 ns 158968 ns 0.99
array/accumulate/Int64/dims=1L 1720180 ns 1710638 ns 1.01
array/accumulate/Int64/dims=2L 967741 ns 967410.5 ns 1.00
array/accumulate/Float32/1d 109131.5 ns 109994 ns 0.99
array/accumulate/Float32/dims=1 80544.5 ns 81343 ns 0.99
array/accumulate/Float32/dims=2 147835.5 ns 148659 ns 0.99
array/accumulate/Float32/dims=1L 1618457 ns 1619411 ns 1.00
array/accumulate/Float32/dims=2L 698496 ns 699433 ns 1.00
array/construct 1297.6 ns 1288.5 ns 1.01
array/random/randn/Float32 43987 ns 45344 ns 0.97
array/random/randn!/Float32 25001 ns 25330 ns 0.99
array/random/rand!/Int64 27384 ns 27554 ns 0.99
array/random/rand!/Float32 8807.333333333334 ns 8908.333333333334 ns 0.99
array/random/rand/Int64 29933 ns 30218 ns 0.99
array/random/rand/Float32 12892.5 ns 13361 ns 0.96
array/permutedims/4d 60596 ns 60397 ns 1.00
array/permutedims/2d 54406 ns 54394 ns 1.00
array/permutedims/3d 55082 ns 55362 ns 0.99
array/sorting/1d 2755635 ns 2758561 ns 1.00
array/sorting/by 3342611 ns 3368461 ns 0.99
array/sorting/2d 1080126 ns 1089562 ns 0.99
cuda/synchronization/stream/auto 1041.5 ns 1066.6 ns 0.98
cuda/synchronization/stream/nonblocking 7436.4 ns 7691.3 ns 0.97
cuda/synchronization/stream/blocking 821.0430107526881 ns 844.0121951219512 ns 0.97
cuda/synchronization/context/auto 1159.2 ns 1211.4 ns 0.96
cuda/synchronization/context/nonblocking 7904.9 ns 6881.1 ns 1.15
cuda/synchronization/context/blocking 896.42 ns 924.7692307692307 ns 0.97

This comment was automatically generated by a workflow using github-action-benchmark.

@maleadt maleadt left a comment


Thanks! Just a single nit.


maleadt commented Jul 29, 2025

Sorry for the delay here, JuliaCon and everything...

IIUC all kinds of diagonals (I, Bool, and non-I) now work properly? Anything else that needs to change here?


tam724 commented Jul 29, 2025

I'm in the process of understanding the Adapt.jl/GPUArrays.jl/CUDA.jl interplay for dealing with the structured matrices from LinearAlgebra.jl. I threw in 5e51422 here, but now I think adjoint(::Diagonal) should be treated differently. I'll revert the commit and test locally why it was required. Then the PR should be ready.
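(For context, and hedged since I am reading the LinearAlgebra definitions from memory: for complex element types, adjoint(::Diagonal) wraps the diagonal as vec(adjoint(d)), i.e. a Base.ReshapedArray around an Adjoint, which is exactly the wrapper the Union in the diff above has to accept. A minimal illustration:)

julia> using LinearAlgebra

julia> D = Diagonal(rand(ComplexF64, 2));

julia> adjoint(D).diag isa Base.ReshapedArray
true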


tam724 commented Jul 30, 2025

Since CI is more or less happy now, I think this is good to go.
The last commit was required because of a bug in the CSC -> COO conversion (if CUDA.runtime_version() == v"12.0"). Probably the same reason as this:

if !(v"12.0" <= CUSPARSE.version() < v"12.1")

Also, I was hoping to find a better way to dispatch for adjoint(::Diagonal{Complex, <:CuVector}) than the current implementation.

Also note that the index computation only works if the (COO) sparse matrix uses one-based indices (the default). There is no way to determine this from the matrix itself, so we would need an additional argument to kron (and possibly other routines) to support anything else.
It would be nice to have this as part of the type; maybe this PR needs a follow-up.
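(A minimal sketch of the index arithmetic in question, assuming one-based COO indices; kron_index is a hypothetical helper, not code from this PR:)

# Each pair of stored entries (iA, jA, vA) of A and (iB, jB, vB) of B maps
# to an entry of C = kron(A, B); the row index is computed as
kron_index(iA, iB, mB) = (iA - 1) * mB + iB
# (and analogously for columns). With zero-based input indices the "- 1"
# would have to disappear, so this formula silently produces wrong positions;
# since the index base is not recorded in the matrix type, kron cannot
# distinguish the two cases.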


codecov bot commented Jul 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 89.79%. Comparing base (205c238) to head (f5baedb).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #2804   +/-   ##
=======================================
  Coverage   89.78%   89.79%           
=======================================
  Files         150      150           
  Lines       13229    13228    -1     
=======================================
  Hits        11878    11878           
+ Misses       1351     1350    -1     

