You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function generic_matmatmul!(C::AbstractGPUMatrix{R}, A::AbstractGPUMatrix{T}, B::AbstractGPUMatrix{S}, add::MulAddMul) where {T<:Number,S<:Number,R<:Number}
336
-
N = size(A,1)
337
-
Q = size(A,2)
338
-
M = size(B,2)
339
-
if Q != size(B,1)
340
-
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
341
-
end
342
-
if size(C,1) != N || size(C,2) != M
343
-
throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((N,M))"))
coalesced_matmul_kernel!(get_backend(C), (MAX_TILE_DIM, MAX_TILE_DIM))(C, A, B, N, Q, M;ndrange=map(x -> ceil(Int,x/MAX_TILE_DIM)*MAX_TILE_DIM, size(C)))
412
-
C
413
-
end
414
333
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
415
334
if size(A,2) != size(B,1)
416
335
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
@@ -825,7 +744,7 @@ function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2},
825
744
826
745
# Kronecker-product kernel for two vectors: each work-item takes one (i, j)
# pair from the 2-D global index and stores x[i] * y[j] into the output z.
#
# Arguments:
#   z - output vector; entry (i - 1) * length(y) + j receives x[i] * y[j]
#   x - left operand, indexed by the first global index (declared @Const)
#   y - right operand, indexed by the second global index (declared @Const)
#
# NOTE(review): presumably launched with ndrange == (length(x), length(y)) and
# with length(z) >= length(x) * length(y); the launch site is outside this
# view, so confirm there. @inbounds disables the bounds checks that would
# otherwise catch a mismatch.
@kernel function kron_kernel!(z, @Const(x), @Const(y))
    i, j = @index(Global, NTuple)
    # For a fixed i, the length(y) products x[i] * y[j] are contiguous in z.
    @inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
end
831
750
@@ -858,13 +777,13 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
858
777
859
778
ta =$transa(T1)
860
779
tb =$transb(T2)
861
-
780
+
862
781
@kernel function kron_kernel!(C, @Const(A), @Const(B))
863
782
ai, aj = @index(Global, NTuple) # Indices in the result matrix
864
-
783
+
865
784
# lb1, lb2 = size(B) # Dimensions of B
866
785
lb1, lb2 = tb =='N'? size(B) : reverse(size(B))
867
-
786
+
868
787
# Map global indices (ai, aj) to submatrices of the Kronecker product
869
788
i_a = (ai -1) ÷ lb1 +1# Corresponding row index in A
870
789
i_b = (ai -1) % lb1 +1# Corresponding row index in B
@@ -878,12 +797,12 @@ for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in
0 commit comments