@@ -739,17 +739,17 @@ end
739
739
740
740
# # Kronecker product
741
741
742
# Kernel for the Kronecker product of two vectors.
#
# Each work-item takes a 2-D global index `(i, j)` and writes `x[i] * y[j]` to the
# linear position `(i - 1) * length(y) + j` of `z`. `x` and `y` are only read
# (`@Const`).
@kernel function kron_kernel_vec!(z, @Const(x), @Const(y))
    i, j = @index(Global, NTuple)

    # Bounds checks are skipped: the launching code is expected to guarantee that
    # `z` has `length(x) * length(y)` elements and that the index range matches.
    @inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
end
747
+
742
748
function LinearAlgebra. kron! (z:: AbstractGPUVector{T1} , x:: AbstractGPUVector{T2} , y:: AbstractGPUVector{T3} ) where {T1,T2,T3}
743
749
@assert length (z) == length (x) * length (y)
744
750
745
- @kernel function kron_kernel! (z, @Const (x), @Const (y))
746
- i, j = @index (Global, NTuple)
747
-
748
- @inbounds z[(i - 1 ) * length (y) + j] = x[i] * y[j]
749
- end
750
-
751
751
backend = KernelAbstractions. get_backend (z)
752
- kernel = kron_kernel ! (backend)
752
+ kernel = kron_kernel_vec ! (backend)
753
753
754
754
kernel (z, x, y, ndrange= (length (x), length (y)))
755
755
@@ -759,57 +759,51 @@ end
759
759
# Out-of-place Kronecker product of two GPU vectors.
#
# Allocates the destination with `similar(x, ...)`, using the promoted element
# type of both inputs and length `length(x) * length(y)`, then returns the result
# of the in-place `kron!(z, x, y)`.
function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
    T = promote_type(T1, T2)
    z = similar(x, T, length(x) * length(y))
    return kron!(z, x, y)
end
764
+
765
# Kernel computing one entry of the Kronecker product `C = kron(A, B)` per work-item.
#
# The 2-D global index `(ai, aj)` addresses an entry of `C`. It is split into an
# index into `A` (which block) and an index into `B` (position inside the block).
# `A` and `B` are only read (`@Const`).
@kernel function kron_kernel!(C, @Const(A), @Const(B))
    ai, aj = @index(Global, NTuple)  # indices in the result matrix

    # Query each dimension separately rather than destructuring `size(B)`, so this
    # also works when `B` has a single dimension (`size(B, 2)` is then 1).
    lb1 = size(B, 1)
    lb2 = size(B, 2)

    # Map the global indices (ai, aj) to entries of A and B (all 1-based).
    i_a = fld1(ai, lb1)  # corresponding row index in A
    i_b = mod1(ai, lb1)  # corresponding row index in B
    j_a = fld1(aj, lb2)  # corresponding column index in A
    j_b = mod1(aj, lb2)  # corresponding column index in B

    # Bounds checks are skipped; the caller is expected to have validated sizes.
    @inbounds C[ai, aj] = A[i_a, j_a] * B[i_b, j_b]
end
764
780
765
# Generators for the argument types accepted by the `kron!` / `kron` methods.
# Given an element-type symbol `T`, each function returns the type expression for
# a plain, transposed, or adjoint GPU vector/matrix with that element type.
trans_adj_wrappers = (
    T -> :(AbstractGPUVecOrMat{$T}),
    T -> :(Transpose{$T, <:AbstractGPUVecOrMat{$T}}),
    T -> :(Adjoint{$T, <:AbstractGPUVecOrMat{$T}}),
)
768
786
769
# Generate `kron!` and `kron` methods for every pairing of plain, transposed, and
# adjoint GPU arguments listed in `trans_adj_wrappers`.
for wrapa in trans_adj_wrappers, wrapb in trans_adj_wrappers
    TypeA = wrapa(:T1)
    TypeB = wrapb(:T2)
    TypeC = :(AbstractGPUVecOrMat{T3})

    # In-place Kronecker product: fills `C` with `kron(A, B)` and returns `C`.
    @eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1, T2, T3}
        # NOTE(review): `@assert` may be disabled at higher optimization levels;
        # consider an explicit `DimensionMismatch` if callers can accept the
        # change of exception type.
        @assert size(C, 1) == size(A, 1) * size(B, 1)
        @assert size(C, 2) == size(A, 2) * size(B, 2)

        backend = KernelAbstractions.get_backend(C)
        kernel = kron_kernel!(backend)

        # `A` and `B` are passed as given (not unwrapped to their parent arrays);
        # one work-item is launched per entry of `C`.
        kernel(C, A, B, ndrange=(size(C, 1), size(C, 2)))

        return C
    end

    # Out-of-place Kronecker product: allocates the destination with
    # `similar(A, ...)` using the promoted element type, then delegates to `kron!`.
    @eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
        T = promote_type(T1, T2)
        C = similar(A, T, size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
        return kron!(C, A, B)
    end
end
0 commit comments