@@ -736,3 +736,80 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
736736
737737 Array(y)[]
738738end
739+
740+ # # Kronecker product
741+
# In-place Kronecker product of two GPU vectors.
#
# Fills `z` so that `z[(i - 1) * length(y) + j] == x[i] * y[j]` for every index `i` of
# `x` and `j` of `y`, then returns `z`.
#
# Throws a `DimensionMismatch` when `length(z) != length(x) * length(y)`.
function LinearAlgebra.kron!(
    z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, y::AbstractGPUVector{T3}
) where {T1,T2,T3}
    # This check must always run: the kernel below indexes with `@inbounds`, so a
    # destination of the wrong length would otherwise be accessed out of bounds.
    # `@assert` is not suitable here because assertions may be disabled.
    if length(z) != length(x) * length(y)
        throw(DimensionMismatch(string(
            "kron!: destination has length ", length(z),
            ", expected ", length(x) * length(y),
        )))
    end

    # One work-item per (i, j) pair; each writes exactly one element of `z`.
    @kernel function kron_kernel!(z, @Const(x), @Const(y))
        i, j = @index(Global, NTuple)

        @inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
    end

    backend = KernelAbstractions.get_backend(z)
    kernel = kron_kernel!(backend)

    kernel(z, x, y, ndrange=(length(x), length(y)))

    return z
end
758+
# Allocating Kronecker product of two GPU vectors.
#
# The result is created with `similar(x, ...)`, has the promoted element type of `x`
# and `y`, and has length `length(x) * length(y)`. The actual computation is delegated
# to `LinearAlgebra.kron!`.
function LinearAlgebra.kron(
    x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}
) where {T1,T2}
    out = similar(x, promote_type(T1, T2), length(x) * length(y))
    return LinearAlgebra.kron!(out, x, y)
end
764+
# Code-generation table for plain, transposed and adjoint GPU matrices.
# Each entry is a triple of callables:
#   1. element-type symbol -> type expression used in a method signature,
#   2. element type        -> operation tag: 'N' (none), 'T' (transpose) or
#                             'C' (conjugate transpose); the adjoint of a real-valued
#                             matrix is tagged 'T',
#   3. argument expression -> expression yielding the unwrapped array.
trans_adj_wrappers = let
    unwrap_parent = arg -> :(parent($arg))
    (
        (
            S -> :(AbstractGPUMatrix{$S}),
            _ -> 'N',
            identity,
        ),
        (
            S -> :(Transpose{$S,<:AbstractGPUMatrix{$S}}),
            _ -> 'T',
            unwrap_parent,
        ),
        (
            S -> :(Adjoint{$S,<:AbstractGPUMatrix{$S}}),
            S -> S <: Real ? 'T' : 'C',
            unwrap_parent,
        ),
    )
end
768+
# Generate `kron!` / `kron` methods for every pairing of plain, transposed and adjoint
# GPU matrices described by `trans_adj_wrappers`.
for (wrapa, transa, unwrapa) in trans_adj_wrappers,
    (wrapb, transb, unwrapb) in trans_adj_wrappers

    TypeA = wrapa(:(T1))
    TypeB = wrapb(:(T2))
    TypeC = :(AbstractGPUMatrix{T3})

    # In-place Kronecker product: fills `C` with the Kronecker product of `A` and `B`
    # and returns `C`. Throws a `DimensionMismatch` unless
    # `size(C) == (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))`.
    @eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1,T2,T3}
        # This check must always run: the kernel indexes with `@inbounds`, and
        # `@assert` may be disabled, so it cannot be relied on for validation.
        expected = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
        if size(C) != expected
            throw(DimensionMismatch(string(
                "kron!: destination has size ", size(C), ", expected ", expected,
            )))
        end

        # Operation tags ('N', 'T' or 'C') for each operand. The kernel receives the
        # unwrapped arrays, so it applies the transpose/conjugation itself.
        ta = $transa(T1)
        tb = $transb(T2)

        @kernel function kron_kernel!(C, @Const(A), @Const(B))
            ai, aj = @index(Global, NTuple) # Indices in the result matrix

            # Logical dimensions of B: its own size when untransposed, swapped
            # otherwise (B is the unwrapped array here).
            lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))

            # Map global indices (ai, aj) to submatrices of the Kronecker product
            i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
            i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
            j_a = (aj - 1) ÷ lb2 + 1 # Corresponding col index in A
            j_b = (aj - 1) % lb2 + 1 # Corresponding col index in B

            @inbounds begin
                # Swap the indices for 'T'/'C', and conjugate as well for 'C'.
                a_ij = ta == 'N' ? A[i_a, j_a] :
                       (ta == 'T' ? A[j_a, i_a] : conj(A[j_a, i_a]))
                b_ij = tb == 'N' ? B[i_b, j_b] :
                       (tb == 'T' ? B[j_b, i_b] : conj(B[j_b, i_b]))

                C[ai, aj] = a_ij * b_ij
            end
        end

        backend = KernelAbstractions.get_backend(C)
        kernel = kron_kernel!(backend)

        # One work-item per element of C; operands are passed unwrapped.
        kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))

        return C
    end

    # Allocating Kronecker product: the result is created with `similar(A, ...)`, uses
    # the promoted element type of `A` and `B`, and is filled by `kron!`.
    @eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1,T2}
        T = promote_type(T1, T2)
        size_C = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
        C = similar(A, T, size_C...)
        return LinearAlgebra.kron!(C, A, B)
    end
end
0 commit comments