@@ -739,17 +739,17 @@ end
739739
740740# # Kronecker product
741741
# Kernel for the Kronecker product of two vectors.
#
# Meant to be launched over a 2-D `ndrange`: the work-item with global index
# `(ix, iy)` writes `x[ix] * y[iy]` into `z[(ix - 1) * length(y) + iy]`.
# `x` and `y` are declared `@Const`, i.e. they are only read here.
@kernel function kron_kernel_vec!(z, @Const(x), @Const(y))
    ix, iy = @index(Global, NTuple)

    # Every element of `x` owns a contiguous run of `length(y)` slots in `z`.
    offset = (ix - 1) * length(y)
    @inbounds z[offset + iy] = x[ix] * y[iy]
end
747+
742748function LinearAlgebra. kron! (z:: AbstractGPUVector{T1} , x:: AbstractGPUVector{T2} , y:: AbstractGPUVector{T3} ) where {T1,T2,T3}
743749 @assert length (z) == length (x) * length (y)
744750
745- @kernel function kron_kernel! (z, @Const (x), @Const (y))
746- i, j = @index (Global, NTuple)
747-
748- @inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749- end
750-
751751 backend = KernelAbstractions. get_backend (z)
752- kernel = kron_kernel ! (backend)
752+ kernel = kron_kernel_vec ! (backend)
753753
754754 kernel (z, x, y, ndrange= (length (x), length (y)))
755755
@@ -759,57 +759,51 @@ end
# Kronecker product of two GPU vectors.
#
# Allocates the output with `similar(x, ...)`, using the promoted element type
# of both inputs and length `length(x) * length(y)`, then delegates to `kron!`.
function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
    out = similar(x, promote_type(T1, T2), length(x) * length(y))
    return LinearAlgebra.kron!(out, x, y)
end
764+
# Kernel for the Kronecker product `C = A ⊗ B`.
#
# Meant to be launched over a 2-D `ndrange` covering the result: the work-item
# with global index `(row, col)` computes the single entry `C[row, col]`.
# `A` and `B` are declared `@Const`, i.e. they are only read here.
@kernel function kron_kernel!(C, @Const(A), @Const(B))
    row, col = @index(Global, NTuple)

    # Query each dimension separately instead of destructuring `size(B)`,
    # so a one-dimensional `B` still yields a column count (of 1).
    rows_b = size(B, 1)
    cols_b = size(B, 2)

    # Split each 1-based result index into (block index -> entry of A,
    # position inside the block -> entry of B).
    row_a, row_b = fldmod1(row, rows_b)
    col_a, col_b = fldmod1(col, cols_b)

    @inbounds C[row, col] = A[row_a, col_a] * B[row_b, col_b]
end
764780
# Builders for the argument-type expressions used when generating methods:
# each entry maps an element-type symbol `T` to the expression of a GPU
# vector-or-matrix type, either plain, wrapped in `Transpose`, or wrapped in
# `Adjoint`.
trans_adj_wrappers = let plain = T -> :(AbstractGPUVecOrMat{$T})
    (
        plain,
        T -> :(Transpose{$T, <:$(plain(T))}),
        T -> :(Adjoint{$T, <:$(plain(T))}),
    )
end
768786
# Generate one `kron!` / `kron` method pair for every combination of operand
# wrappers listed in `trans_adj_wrappers` (first operand x second operand).
for wrapa in trans_adj_wrappers, wrapb in trans_adj_wrappers
    TypeA = wrapa(:T1)
    TypeB = wrapb(:T2)
    TypeC = :(AbstractGPUVecOrMat{T3})

    @eval begin
        # In-place Kronecker product: fills `C` with `A ⊗ B` and returns `C`.
        function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1, T2, T3}
            @assert size(C, 1) == size(A, 1) * size(B, 1)
            @assert size(C, 2) == size(A, 2) * size(B, 2)

            # `A` and `B` are handed to the kernel as they are; the kernel
            # reads them through ordinary two-index `getindex`.
            launch = kron_kernel!(KernelAbstractions.get_backend(C))
            launch(C, A, B; ndrange = (size(C, 1), size(C, 2)))

            return C
        end

        # Allocating Kronecker product: the result is created with
        # `similar(A, ...)` using the promoted element type of both operands.
        function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
            nrows = size(A, 1) * size(B, 1)
            ncols = size(A, 2) * size(B, 2)
            C = similar(A, promote_type(T1, T2), nrows, ncols)
            return LinearAlgebra.kron!(C, A, B)
        end
    end
end
0 commit comments