@@ -30,14 +30,14 @@ struct KroneckerArray{T,N,A1<:AbstractArray{T,N},A2<:AbstractArray{T,N}} <:
3030 arg1:: A1
3131 arg2:: A2
3232end
33- function KroneckerArray (a :: AbstractArray , b :: AbstractArray )
34- if ndims (a ) != ndims (b )
33+ function KroneckerArray (a1 :: AbstractArray , a2 :: AbstractArray )
34+ if ndims (a1 ) != ndims (a2 )
3535 throw (
3636 ArgumentError (" Kronecker product requires arrays of the same number of dimensions." )
3737 )
3838 end
39- elt = promote_type (eltype (a ), eltype (b ))
40- return _convert (AbstractArray{elt}, a ) ⊗ _convert (AbstractArray{elt}, b )
39+ elt = promote_type (eltype (a1 ), eltype (a2 ))
40+ return _convert (AbstractArray{elt}, a1 ) ⊗ _convert (AbstractArray{elt}, a2 )
4141end
4242const KroneckerMatrix{T,A1<: AbstractMatrix{T} ,A2<: AbstractMatrix{T} } = KroneckerArray{
4343 T,2 ,A1,A2
@@ -204,8 +204,8 @@ function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N}
204204 sz = reverse (ntuple (i -> size (a, i) * size (b, i), N))
205205 return permutedims (reshape (c′, sz), reverse (ntuple (identity, N)))
206206end
207- kron_nd (a :: AbstractMatrix , b :: AbstractMatrix ) = kron (a, b )
208- kron_nd (a :: AbstractVector , b :: AbstractVector ) = kron (a, b )
207+ kron_nd (a1 :: AbstractMatrix , a2 :: AbstractMatrix ) = kron (a1, a2 )
208+ kron_nd (a1 :: AbstractVector , a2 :: AbstractVector ) = kron (a1, a2 )
209209
210210# Eagerly collect arguments to make more general on GPU.
211211Base. collect (a:: KroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
@@ -265,10 +265,10 @@ function Base.show(io::IO, a::KroneckerArray)
265265 return nothing
266266end
267267
268- ⊗ (a :: AbstractArray , b :: AbstractArray ) = KroneckerArray (a, b )
269- ⊗ (a :: Number , b :: Number ) = a * b
270- ⊗ (a :: Number , b :: AbstractArray ) = a * b
271- ⊗ (a :: AbstractArray , b :: Number ) = a * b
268+ ⊗ (a1 :: AbstractArray , a2 :: AbstractArray ) = KroneckerArray (a1, a2 )
269+ ⊗ (a1 :: Number , a2 :: Number ) = a1 * a2
270+ ⊗ (a1 :: Number , a2 :: AbstractArray ) = a1 * a2
271+ ⊗ (a1 :: AbstractArray , a2 :: Number ) = a1 * a2
272272
273273function Base. getindex (a:: KroneckerArray , i:: Integer )
274274 return a[CartesianIndices (a)[i]]
@@ -374,11 +374,11 @@ arg1(::Type{<:KroneckerStyle{<:Any,A1}}) where {A1} = A1
374374arg1 (style:: KroneckerStyle ) = arg1 (typeof (style))
375375arg2 (:: Type{<:KroneckerStyle{<:Any,<:Any,A2}} ) where {A2} = A2
376376arg2 (style:: KroneckerStyle ) = arg2 (typeof (style))
377- function KroneckerStyle {N} (a :: BroadcastStyle , b :: BroadcastStyle ) where {N}
378- return KroneckerStyle {N,a,b } ()
377+ function KroneckerStyle {N} (a1 :: BroadcastStyle , a2 :: BroadcastStyle ) where {N}
378+ return KroneckerStyle {N,a1,a2 } ()
379379end
380- function KroneckerStyle (a :: AbstractArrayStyle{N} , b :: AbstractArrayStyle{N} ) where {N}
381- return KroneckerStyle {N} (a, b )
380+ function KroneckerStyle (a1 :: AbstractArrayStyle{N} , a2 :: AbstractArrayStyle{N} ) where {N}
381+ return KroneckerStyle {N} (a1, a2 )
382382end
383383function KroneckerStyle {N,A1,A2} (v:: Val{M} ) where {N,A1,A2,M}
384384 return KroneckerStyle {M,typeof(A1)(v),typeof(A2)(v)} ()
@@ -447,11 +447,11 @@ function Broadcast.broadcasted(::KroneckerStyle, f, as...)
447447end
448448
449449# Linear operations.
450- function Broadcast. broadcasted (:: KroneckerStyle , :: typeof (+ ), a, b )
451- return Summed (a ) + Summed (b )
450+ function Broadcast. broadcasted (:: KroneckerStyle , :: typeof (+ ), a1, a2 )
451+ return Summed (a1 ) + Summed (a2 )
452452end
453- function Broadcast. broadcasted (:: KroneckerStyle , :: typeof (- ), a, b )
454- return Summed (a ) - Summed (b )
453+ function Broadcast. broadcasted (:: KroneckerStyle , :: typeof (- ), a1, a2 )
454+ return Summed (a1 ) - Summed (a2 )
455455end
456456function Broadcast. broadcasted (:: KroneckerStyle , :: typeof (* ), c:: Number , a)
457457 return c * Summed (a)
@@ -512,9 +512,9 @@ struct KroneckerBroadcasted{A1,A2}
512512end
513513@inline arg1 (a:: KroneckerBroadcasted ) = getfield (a, :arg1 )
514514@inline arg2 (a:: KroneckerBroadcasted ) = getfield (a, :arg2 )
515- ⊗ (a :: Broadcasted , b :: Broadcasted ) = KroneckerBroadcasted (a, b )
516- ⊗ (a :: Broadcasted , b ) = KroneckerBroadcasted (a, b )
517- ⊗ (a, b :: Broadcasted ) = KroneckerBroadcasted (a, b )
515+ ⊗ (a1 :: Broadcasted , a2 :: Broadcasted ) = KroneckerBroadcasted (a1, a2 )
516+ ⊗ (a1 :: Broadcasted , a2 ) = KroneckerBroadcasted (a1, a2 )
517+ ⊗ (a1, a2 :: Broadcasted ) = KroneckerBroadcasted (a1, a2 )
518518Broadcast. materialize (a:: KroneckerBroadcasted ) = copy (a)
519519Broadcast. materialize! (dest, a:: KroneckerBroadcasted ) = copyto! (dest, a)
520520Broadcast. broadcastable (a:: KroneckerBroadcasted ) = a
0 commit comments