Skip to content

Commit 87b95a9

Browse files
Add vector-vector and matrix-matrix Kronecker product (#575)
1 parent 8094ded commit 87b95a9

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ Manifest.toml
1212

1313
# MacOS generated files
1414
*.DS_Store
15+
16+
/.vscode

src/host/linalg.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,80 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
736736

737737
Array(y)[]
738738
end
739+
740+
## Kronecker product
741+
742+
function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, y::AbstractGPUVector{T3}) where {T1,T2,T3}
743+
@assert length(z) == length(x) * length(y)
744+
745+
@kernel function kron_kernel!(z, @Const(x), @Const(y))
746+
i, j = @index(Global, NTuple)
747+
748+
@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
749+
end
750+
751+
backend = KernelAbstractions.get_backend(z)
752+
kernel = kron_kernel!(backend)
753+
754+
kernel(z, x, y, ndrange=(length(x), length(y)))
755+
756+
return z
757+
end
758+
759+
function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
760+
T = promote_type(T1, T2)
761+
z = similar(x, T, length(x) * length(y))
762+
return LinearAlgebra.kron!(z, x, y)
763+
end
764+
765+
trans_adj_wrappers = ((T -> :(AbstractGPUMatrix{$T}), T -> 'N', identity),
766+
(T -> :(Transpose{$T, <:AbstractGPUMatrix{$T}}), T -> 'T', A -> :(parent($A))),
767+
(T -> :(Adjoint{$T, <:AbstractGPUMatrix{$T}}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
768+
769+
for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in trans_adj_wrappers
770+
TypeA = wrapa(:(T1))
771+
TypeB = wrapb(:(T2))
772+
TypeC = :(AbstractGPUMatrix{T3})
773+
774+
@eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1,T2,T3}
775+
@assert size(C, 1) == size(A, 1) * size(B, 1)
776+
@assert size(C, 2) == size(A, 2) * size(B, 2)
777+
778+
ta = $transa(T1)
779+
tb = $transb(T2)
780+
781+
@kernel function kron_kernel!(C, @Const(A), @Const(B))
782+
ai, aj = @index(Global, NTuple) # Indices in the result matrix
783+
784+
# lb1, lb2 = size(B) # Dimensions of B
785+
lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))
786+
787+
# Map global indices (ai, aj) to submatrices of the Kronecker product
788+
i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
789+
i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
790+
j_a = (aj - 1) ÷ lb2 + 1 # Corresponding col index in A
791+
j_b = (aj - 1) % lb2 + 1 # Corresponding col index in B
792+
793+
@inbounds begin
794+
a_ij = ta == 'N' ? A[i_a, j_a] : (ta == 'T' ? A[j_a, i_a] : conj(A[j_a, i_a]))
795+
b_ij = tb == 'N' ? B[i_b, j_b] : (tb == 'T' ? B[j_b, i_b] : conj(B[j_b, i_b]))
796+
797+
C[ai, aj] = a_ij * b_ij
798+
end
799+
end
800+
801+
backend = KernelAbstractions.get_backend(C)
802+
kernel = kron_kernel!(backend)
803+
804+
kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))
805+
806+
return C
807+
end
808+
809+
@eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
810+
T = promote_type(T1, T2)
811+
size_C = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
812+
C = similar(A, T, size_C...)
813+
return kron!(C, A, B)
814+
end
815+
end

test/testsuite/linalg.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,15 @@
310310
@test iszero(A)
311311
@test isone(A) == false
312312
end
313+
314+
@testset "kron" begin
315+
for T in eltypes
316+
@test compare(kron, AT, rand(T, 32), rand(T, 64))
317+
for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
318+
@test compare(kron, AT, opa(rand(T, 32, 64)), opb(rand(T, 128, 16)))
319+
end
320+
end
321+
end
313322
end
314323

315324
@testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin

0 commit comments

Comments
 (0)