Add Kronecker product (kron, kron!) for GPU vectors and matrices.

Covers AbstractGPUVector operands and every combination of plain, Transpose
and Adjoint AbstractGPUMatrix operands, adds tests, and ignores /.vscode.

NOTE(review): line breaks and indentation of this patch were reconstructed;
confirm the context lines against the target tree before applying (or apply
with --ignore-whitespace).

diff --git a/.gitignore b/.gitignore
index 9b73c9742..349b1c4ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,5 @@ Manifest.toml
 
 # MacOS generated files
 *.DS_Store
+
+/.vscode
diff --git a/src/host/linalg.jl b/src/host/linalg.jl
index b4acc8e71..2c9287479 100644
--- a/src/host/linalg.jl
+++ b/src/host/linalg.jl
@@ -736,3 +736,107 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
 
     Array(y)[]
 end
+
+## Kronecker product
+
+# Kronecker product of two GPU vectors, stored in `z` (which is returned):
+# `z[(i - 1) * length(y) + j] = x[i] * y[j]`.
+# Throws `DimensionMismatch` unless `length(z) == length(x) * length(y)`.
+function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, y::AbstractGPUVector{T3}) where {T1,T2,T3}
+    expected = length(x) * length(y)
+    length(z) == expected || throw(DimensionMismatch(
+        string("kron!: destination has length ", length(z), ", expected ", expected)))
+
+    # Nothing to compute for an empty result; skip the kernel launch.
+    isempty(z) && return z
+
+    # One work-item per (i, j) pair of indices into `x` and `y`.
+    @kernel function kron_kernel!(z, @Const(x), @Const(y))
+        i, j = @index(Global, NTuple)
+
+        @inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
+    end
+
+    backend = KernelAbstractions.get_backend(z)
+    kernel = kron_kernel!(backend)
+
+    kernel(z, x, y, ndrange=(length(x), length(y)))
+
+    return z
+end
+
+# Allocating Kronecker product of two GPU vectors; the result is created with
+# `similar(x, ...)` and has element type `promote_type(T1, T2)`.
+function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
+    T = promote_type(T1, T2)
+    z = similar(x, T, length(x) * length(y))
+    return LinearAlgebra.kron!(z, x, y)
+end
+
+# One entry per supported operand form (plain, `Transpose`, `Adjoint`), each a triple of:
+#   1. a builder for the dispatch type expression, given the eltype symbol,
+#   2. a function mapping the eltype to a flag: 'N' (use as is), 'T' (swap indices),
+#      'C' (swap indices and conjugate),
+#   3. a builder for the expression that unwraps the operand before it is passed to the kernel.
+trans_adj_wrappers = ((T -> :(AbstractGPUMatrix{$T}), T -> 'N', identity),
+                      (T -> :(Transpose{$T, <:AbstractGPUMatrix{$T}}), T -> 'T', A -> :(parent($A))),
+                      (T -> :(Adjoint{$T, <:AbstractGPUMatrix{$T}}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
+
+# Define `kron!` and `kron` for every combination of operand forms of `A` and `B`.
+for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in trans_adj_wrappers
+    TypeA = wrapa(:(T1))
+    TypeB = wrapb(:(T2))
+    TypeC = :(AbstractGPUMatrix{T3})
+
+    # In-place Kronecker product: fills `C` with `kron(A, B)` and returns `C`.
+    # Throws `DimensionMismatch` unless `size(C)` is the dimension-wise product of
+    # `size(A)` and `size(B)`.
+    @eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1,T2,T3}
+        expected = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
+        size(C) == expected || throw(DimensionMismatch(
+            string("kron!: destination has size ", size(C), ", expected ", expected)))
+
+        # Nothing to compute for an empty result; skip the kernel launch.
+        isempty(C) && return C
+
+        ta = $transa(T1)
+        tb = $transb(T2)
+
+        # The kernel receives the unwrapped arrays, so it applies the index swap /
+        # conjugation selected by `ta` and `tb` itself.
+        @kernel function kron_kernel!(C, @Const(A), @Const(B))
+            ai, aj = @index(Global, NTuple)  # Indices in the result matrix
+
+            # Logical dimensions of B: `size(B)` here is the unwrapped array's size,
+            # so it is reversed unless tb == 'N'
+            lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))
+
+            # Map global indices (ai, aj) to submatrices of the Kronecker product
+            i_a = (ai - 1) ÷ lb1 + 1  # Corresponding row index in A
+            i_b = (ai - 1) % lb1 + 1  # Corresponding row index in B
+            j_a = (aj - 1) ÷ lb2 + 1  # Corresponding col index in A
+            j_b = (aj - 1) % lb2 + 1  # Corresponding col index in B
+
+            @inbounds begin
+                a_ij = ta == 'N' ? A[i_a, j_a] : (ta == 'T' ? A[j_a, i_a] : conj(A[j_a, i_a]))
+                b_ij = tb == 'N' ? B[i_b, j_b] : (tb == 'T' ? B[j_b, i_b] : conj(B[j_b, i_b]))
+
+                C[ai, aj] = a_ij * b_ij
+            end
+        end
+
+        backend = KernelAbstractions.get_backend(C)
+        kernel = kron_kernel!(backend)
+
+        kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))
+
+        return C
+    end
+
+    # Allocating Kronecker product; the result is created with `similar(A, ...)`
+    # and has element type `promote_type(T1, T2)`.
+    @eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
+        T = promote_type(T1, T2)
+        size_C = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
+        C = similar(A, T, size_C...)
+        return LinearAlgebra.kron!(C, A, B)
+    end
+end
diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl
index 8de318549..d4833d3e9 100644
--- a/test/testsuite/linalg.jl
+++ b/test/testsuite/linalg.jl
@@ -310,6 +310,15 @@
         @test iszero(A)
         @test isone(A) == false
     end
+
+    @testset "kron" begin
+        for T in eltypes
+            @test compare(kron, AT, rand(T, 32), rand(T, 64))
+            for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
+                @test compare(kron, AT, opa(rand(T, 32, 64)), opb(rand(T, 128, 16)))
+            end
+        end
+    end
 end
 
 @testsuite "linalg/mul!/vector-matrix" (AT, eltypes)->begin