Skip to content

Commit 0093468

Browse files
authored
Improve kron implementation (#600)
Significantly simplify indexing by not applying transposes and adjoints manually and by using `fld1` and `mod1`. Also add support for combinations that mix vectors and matrices, for generality.
1 parent 602976f commit 0093468

File tree

2 files changed

+39
-46
lines changed

2 files changed

+39
-46
lines changed

src/host/linalg.jl

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -739,17 +739,17 @@ end
739739

740740
## Kronecker product
741741

742+
# Kernel for the Kronecker product of two vectors: each work item (i, j) computes a
# single product x[i] * y[j] and stores it in `z`. `x` and `y` are marked `@Const`;
# `z` is the only argument written to.
@kernel function kron_kernel_vec!(z, @Const(x), @Const(y))
743+
i, j = @index(Global, NTuple) # i is used to index x, j to index y
744+
745+
# The products for a fixed i fill a contiguous run of length(y) entries of z,
# starting at (i - 1) * length(y) + 1. Bounds checks are skipped via `@inbounds`,
# so the launch range must not exceed (length(x), length(y)) and z must hold
# length(x) * length(y) elements.
@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
746+
end
747+
742748
function LinearAlgebra.kron!(z::AbstractGPUVector{T1}, x::AbstractGPUVector{T2}, y::AbstractGPUVector{T3}) where {T1,T2,T3}
743749
@assert length(z) == length(x) * length(y)
744750

745-
@kernel function kron_kernel!(z, @Const(x), @Const(y))
746-
i, j = @index(Global, NTuple)
747-
748-
@inbounds z[(i - 1) * length(y) + j] = x[i] * y[j]
749-
end
750-
751751
backend = KernelAbstractions.get_backend(z)
752-
kernel = kron_kernel!(backend)
752+
kernel = kron_kernel_vec!(backend)
753753

754754
kernel(z, x, y, ndrange=(length(x), length(y)))
755755

@@ -759,57 +759,51 @@ end
759759
function LinearAlgebra.kron(x::AbstractGPUVector{T1}, y::AbstractGPUVector{T2}) where {T1,T2}
760760
T = promote_type(T1, T2)
761761
z = similar(x, T, length(x) * length(y))
762-
return LinearAlgebra.kron!(z, x, y)
762+
return kron!(z, x, y)
763+
end
764+
765+
# Kernel for the Kronecker product C = kron(A, B): each work item (ai, aj) computes
# one entry of `C` as the product of one entry of `A` and one entry of `B`.
# `A` and `B` are marked `@Const`; `C` is the only argument written to.
@kernel function kron_kernel!(C, @Const(A), @Const(B))
766+
ai, aj = @index(Global, NTuple) # Row and column of the entry of C to compute
767+
768+
# Each dimension of B is queried separately instead of destructuring `size(B)`;
# `size(B, 2)` is 1 for a vector, so this form presumably also covers vector B.
769+
lb1 = size(B, 1)
770+
lb2 = size(B, 2)
771+
772+
# C is made of blocks of size (lb1, lb2), one block per entry of A. `fld1` picks
# the block (1-based), i.e. the entry of A; `mod1` picks the 1-based position
# inside that block, i.e. the entry of B.
773+
i_a = fld1(ai, lb1) # Row index in A (block row of C)
774+
i_b = mod1(ai, lb1) # Row index in B (row within the block)
775+
j_a = fld1(aj, lb2) # Column index in A (block column of C)
776+
j_b = mod1(aj, lb2) # Column index in B (column within the block)
777+
778+
# Bounds checks are skipped via `@inbounds`, so the launch range must not exceed
# size(C) and C must have size(A, d) * size(B, d) entries along each dimension d.
@inbounds C[ai, aj] = A[i_a, j_a] * B[i_b, j_b]
763779
end
764780

765-
trans_adj_wrappers = ((T -> :(AbstractGPUMatrix{$T}), T -> 'N', identity),
766-
(T -> :(Transpose{$T, <:AbstractGPUMatrix{$T}}), T -> 'T', A -> :(parent($A))),
767-
(T -> :(Adjoint{$T, <:AbstractGPUMatrix{$T}}), T -> T <: Real ? 'T' : 'C', A -> :(parent($A))))
781+
trans_adj_wrappers = (
782+
T -> :(AbstractGPUVecOrMat{$T}),
783+
T -> :(Transpose{$T, <:AbstractGPUVecOrMat{$T}}),
784+
T -> :(Adjoint{$T, <:AbstractGPUVecOrMat{$T}}),
785+
)
768786

769-
for (wrapa, transa, unwrapa) in trans_adj_wrappers, (wrapb, transb, unwrapb) in trans_adj_wrappers
770-
TypeA = wrapa(:(T1))
771-
TypeB = wrapb(:(T2))
772-
TypeC = :(AbstractGPUMatrix{T3})
787+
for wrapa in trans_adj_wrappers, wrapb in trans_adj_wrappers
788+
TypeA = wrapa(:T1)
789+
TypeB = wrapb(:T2)
790+
TypeC = :(AbstractGPUVecOrMat{T3})
773791

774-
@eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1,T2,T3}
792+
@eval function LinearAlgebra.kron!(C::$TypeC, A::$TypeA, B::$TypeB) where {T1, T2, T3}
775793
@assert size(C, 1) == size(A, 1) * size(B, 1)
776794
@assert size(C, 2) == size(A, 2) * size(B, 2)
777795

778-
ta = $transa(T1)
779-
tb = $transb(T2)
780-
781-
@kernel function kron_kernel!(C, @Const(A), @Const(B))
782-
ai, aj = @index(Global, NTuple) # Indices in the result matrix
783-
784-
# lb1, lb2 = size(B) # Dimensions of B
785-
lb1, lb2 = tb == 'N' ? size(B) : reverse(size(B))
786-
787-
# Map global indices (ai, aj) to submatrices of the Kronecker product
788-
i_a = (ai - 1) ÷ lb1 + 1 # Corresponding row index in A
789-
i_b = (ai - 1) % lb1 + 1 # Corresponding row index in B
790-
j_a = (aj - 1) ÷ lb2 + 1 # Corresponding col index in A
791-
j_b = (aj - 1) % lb2 + 1 # Corresponding col index in B
792-
793-
@inbounds begin
794-
a_ij = ta == 'N' ? A[i_a, j_a] : (ta == 'T' ? A[j_a, i_a] : conj(A[j_a, i_a]))
795-
b_ij = tb == 'N' ? B[i_b, j_b] : (tb == 'T' ? B[j_b, i_b] : conj(B[j_b, i_b]))
796-
797-
C[ai, aj] = a_ij * b_ij
798-
end
799-
end
800-
801796
backend = KernelAbstractions.get_backend(C)
802797
kernel = kron_kernel!(backend)
803-
804-
kernel(C, $(unwrapa(:A)), $(unwrapb(:B)), ndrange=(size(C, 1), size(C, 2)))
805-
798+
799+
kernel(C, A, B, ndrange=(size(C, 1), size(C, 2)))
800+
806801
return C
807802
end
808803

809804
@eval function LinearAlgebra.kron(A::$TypeA, B::$TypeB) where {T1, T2}
810805
T = promote_type(T1, T2)
811-
size_C = (size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
812-
C = similar(A, T, size_C...)
806+
C = similar(A, T, size(A, 1) * size(B, 1), size(A, 2) * size(B, 2))
813807
return kron!(C, A, B)
814808
end
815809
end

test/testsuite/linalg.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,7 @@
313313

314314
@testset "kron" begin
315315
for T in eltypes
316-
@test compare(kron, AT, rand(T, 32), rand(T, 64))
317-
for opa in (identity, transpose, adjoint), opb in (identity, transpose, adjoint)
316+
for opa in (vec, identity, transpose, adjoint), opb in (vec, identity, transpose, adjoint)
318317
@test compare(kron, AT, opa(rand(T, 32, 64)), opb(rand(T, 128, 16)))
319318
end
320319
end

0 commit comments

Comments
 (0)