@@ -736,3 +736,80 @@ function Base.isone(x::AbstractGPUMatrix{T}) where {T}
736
736
737
737
Array (y)[]
738
738
end
739
+
740
+ # # Kronecker product
741
+
742
+ function LinearAlgebra. kron! (z:: AbstractGPUVector{T1} , x:: AbstractGPUVector{T2} , y:: AbstractGPUVector{T3} ) where {T1,T2,T3}
743
+ @assert length (z) == length (x) * length (y)
744
+
745
+ @kernel function kron_kernel! (z, @Const (x), @Const (y))
746
+ i, j = @index (Global, NTuple)
747
+
748
+ @inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749
+ end
750
+
751
+ backend = KernelAbstractions. get_backend (z)
752
+ kernel = kron_kernel! (backend)
753
+
754
+ kernel (z, x, y, ndrange= (length (x), length (y)))
755
+
756
+ return z
757
+ end
758
+
759
+ function LinearAlgebra. kron (x:: AbstractGPUVector{T1} , y:: AbstractGPUVector{T2} ) where {T1,T2}
760
+ T = promote_type (T1, T2)
761
+ z = similar (x, T, length (x) * length (y))
762
+ return LinearAlgebra. kron! (z, x, y)
763
+ end
764
+
765
+ trans_adj_wrappers = ((T -> :(AbstractGPUMatrix{$ T}), T -> ' N' , identity),
766
+ (T -> :(Transpose{$ T, <: AbstractGPUMatrix{$T} }), T -> ' T' , A -> :(parent ($ A))),
767
+ (T -> :(Adjoint{$ T, <: AbstractGPUMatrix{$T} }), T -> T <: Real ? ' T' : ' C' , A -> :(parent ($ A))))
768
+
769
+ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in trans_adj_wrappers
770
+ TypeA = wrapa (:(T1))
771
+ TypeB = wrapb (:(T2))
772
+ TypeC = :(AbstractGPUMatrix{T3})
773
+
774
+ @eval function LinearAlgebra. kron! (C:: $TypeC , A:: $TypeA , B:: $TypeB ) where {T1,T2,T3}
775
+ @assert size (C, 1 ) == size (A, 1 ) * size (B, 1 )
776
+ @assert size (C, 2 ) == size (A, 2 ) * size (B, 2 )
777
+
778
+ ta = $ transa (T1)
779
+ tb = $ transb (T2)
780
+
781
+ @kernel function kron_kernel! (C, @Const (A), @Const (B))
782
+ ai, aj = @index (Global, NTuple) # Indices in the result matrix
783
+
784
+ # lb1, lb2 = size(B) # Dimensions of B
785
+ lb1, lb2 = tb == ' N' ? size (B) : reverse (size (B))
786
+
787
+ # Map global indices (ai, aj) to submatrices of the Kronecker product
788
+ i_a = (ai - 1 ) ÷ lb1 + 1 # Corresponding row index in A
789
+ i_b = (ai - 1 ) % lb1 + 1 # Corresponding row index in B
790
+ j_a = (aj - 1 ) ÷ lb2 + 1 # Corresponding col index in A
791
+ j_b = (aj - 1 ) % lb2 + 1 # Corresponding col index in B
792
+
793
+ @inbounds begin
794
+ a_ij = ta == ' N' ? A[i_a, j_a] : (ta == ' T' ? A[j_a, i_a] : conj (A[j_a, i_a]))
795
+ b_ij = tb == ' N' ? B[i_b, j_b] : (tb == ' T' ? B[j_b, i_b] : conj (B[j_b, i_b]))
796
+
797
+ C[ai, aj] = a_ij * b_ij
798
+ end
799
+ end
800
+
801
+ backend = KernelAbstractions. get_backend (C)
802
+ kernel = kron_kernel! (backend)
803
+
804
+ kernel (C, $ (unwrapa (:A )), $ (unwrapb (:B )), ndrange= (size (C, 1 ), size (C, 2 )))
805
+
806
+ return C
807
+ end
808
+
809
+ @eval function LinearAlgebra. kron (A:: $TypeA , B:: $TypeB ) where {T1, T2}
810
+ T = promote_type (T1, T2)
811
+ size_C = (size (A, 1 ) * size (B, 1 ), size (A, 2 ) * size (B, 2 ))
812
+ C = similar (A, T, size_C... )
813
+ return kron! (C, A, B)
814
+ end
815
+ end
0 commit comments