@@ -25,9 +25,10 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
2525 return isdiag (a) ? _construct (A, a) : throw (InexactError (:convert , A, a))
2626end
2727
28- struct KroneckerArray{T,N,A<: AbstractArray{T,N} ,B<: AbstractArray{T,N} } <: AbstractArray{T,N}
29- a:: A
30- b:: B
28+ struct KroneckerArray{T,N,A1<: AbstractArray{T,N} ,A2<: AbstractArray{T,N} } < :
29+ AbstractArray{T,N}
30+ arg1:: A1
31+ arg2:: A2
3132end
3233function KroneckerArray (a:: AbstractArray , b:: AbstractArray )
3334 if ndims (a) != ndims (b)
@@ -38,11 +39,15 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
3839 elt = promote_type (eltype (a), eltype (b))
3940 return _convert (AbstractArray{elt}, a) ⊗ _convert (AbstractArray{elt}, b)
4041end
41- const KroneckerMatrix{T,A<: AbstractMatrix{T} ,B<: AbstractMatrix{T} } = KroneckerArray{T,2 ,A,B}
42- const KroneckerVector{T,A<: AbstractVector{T} ,B<: AbstractVector{T} } = KroneckerArray{T,1 ,A,B}
42+ const KroneckerMatrix{T,A1<: AbstractMatrix{T} ,A2<: AbstractMatrix{T} } = KroneckerArray{
43+ T,2 ,A1,A2
44+ }
45+ const KroneckerVector{T,A1<: AbstractVector{T} ,A2<: AbstractVector{T} } = KroneckerArray{
46+ T,1 ,A1,A2
47+ }
4348
44- arg1 (a:: KroneckerArray ) = a . a
45- arg2 (a:: KroneckerArray ) = a . b
49+ @inline arg1 (a:: KroneckerArray ) = getfield (a, :arg1 )
50+ @inline arg2 (a:: KroneckerArray ) = getfield (a, :arg2 )
4651
4752function mutate_active_args! (f!, f, dest, src)
4853 (isactive (arg1 (dest)) || isactive (arg2 (dest))) ||
@@ -81,8 +86,10 @@ function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N
8186 return mutate_active_args! (copyto!, copy, dest, src)
8287end
8388
84- function Base. convert (:: Type{KroneckerArray{T,N,A,B}} , a:: KroneckerArray ) where {T,N,A,B}
85- return _convert (A, arg1 (a)) ⊗ _convert (B, arg2 (a))
89+ function Base. convert (
90+ :: Type{KroneckerArray{T,N,A1,A2}} , a:: KroneckerArray
91+ ) where {T,N,A1,A2}
92+ return _convert (A1, arg1 (a)) ⊗ _convert (A2, arg2 (a))
8693end
8794
8895# Promote the element type if needed.
@@ -140,17 +147,17 @@ function Base.similar(
140147end
141148
142149function Base. similar (
143- arrayt:: Type{<:KroneckerArray{<:Any,<:Any,A,B }} ,
150+ arrayt:: Type{<:KroneckerArray{<:Any,<:Any,A1,A2 }} ,
144151 axs:: Tuple {
145152 CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
146153 },
147- ) where {A,B }
148- return similar (A , map (arg1, axs)) ⊗ similar (B , map (arg2, axs))
154+ ) where {A1,A2 }
155+ return similar (A1 , map (arg1, axs)) ⊗ similar (A2 , map (arg2, axs))
149156end
150157function Base. similar (
151- :: Type{<:KroneckerArray{<:Any,<:Any,A,B }} , sz:: Tuple{Int,Vararg{Int}}
152- ) where {A,B }
153- return similar (promote_type (A, B ), sz)
158+ :: Type{<:KroneckerArray{<:Any,<:Any,A1,A2 }} , sz:: Tuple{Int,Vararg{Int}}
159+ ) where {A1,A2 }
160+ return similar (promote_type (A1, A2 ), sz)
154161end
155162
156163function Base. similar (
243250arguments (a:: KroneckerArray ) = (arg1 (a), arg2 (a))
244251arguments (a:: KroneckerArray , n:: Int ) = arguments (a)[n]
245252argument_types (a:: KroneckerArray ) = argument_types (typeof (a))
246- argument_types (:: Type{<:KroneckerArray{<:Any,<:Any,A,B }} ) where {A,B } = (A, B )
253+ argument_types (:: Type{<:KroneckerArray{<:Any,<:Any,A1,A2 }} ) where {A1,A2 } = (A1, A2 )
247254
248255function Base. print_array (io:: IO , a:: KroneckerArray )
249256 Base. print_array (io, arg1 (a))
@@ -362,22 +369,22 @@ function Base.reshape(
362369end
363370
364371using Base. Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
365- struct KroneckerStyle{N,A,B } <: AbstractArrayStyle{N} end
366- arg1 (:: Type{<:KroneckerStyle{<:Any,A }} ) where {A } = A
372+ struct KroneckerStyle{N,A1,A2 } <: AbstractArrayStyle{N} end
373+ arg1 (:: Type{<:KroneckerStyle{<:Any,A1 }} ) where {A1 } = A1
367374arg1 (style:: KroneckerStyle ) = arg1 (typeof (style))
368- arg2 (:: Type{<:KroneckerStyle{<:Any,B }} ) where {B } = B
375+ arg2 (:: Type{<:KroneckerStyle{<:Any,<:Any,A2 }} ) where {A2 } = A2
369376arg2 (style:: KroneckerStyle ) = arg2 (typeof (style))
370377function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
371378 return KroneckerStyle {N,a,b} ()
372379end
373380function KroneckerStyle (a:: AbstractArrayStyle{N} , b:: AbstractArrayStyle{N} ) where {N}
374381 return KroneckerStyle {N} (a, b)
375382end
376- function KroneckerStyle {N,A,B } (v:: Val{M} ) where {N,A,B ,M}
377- return KroneckerStyle {M,typeof(A )(v),typeof(B )(v)} ()
383+ function KroneckerStyle {N,A1,A2 } (v:: Val{M} ) where {N,A1,A2 ,M}
384+ return KroneckerStyle {M,typeof(A1 )(v),typeof(A2 )(v)} ()
378385end
379- function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B }} ) where {N,A,B }
380- return KroneckerStyle {N} (BroadcastStyle (A ), BroadcastStyle (B ))
386+ function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A1,A2 }} ) where {N,A1,A2 }
387+ return KroneckerStyle {N} (BroadcastStyle (A1 ), BroadcastStyle (A2 ))
381388end
382389function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
383390 style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
@@ -386,9 +393,11 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N
386393 (style_b isa Broadcast. Unknown) && return Broadcast. Unknown ()
387394 return KroneckerStyle {N} (style_a, style_b)
388395end
389- function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type , ax) where {N,A,B}
390- bc_a = Broadcasted (A, bc. f, arg1 .(bc. args), arg1 .(ax))
391- bc_b = Broadcasted (B, bc. f, arg2 .(bc. args), arg2 .(ax))
396+ function Base. similar (
397+ bc:: Broadcasted{<:KroneckerStyle{N,A1,A2}} , elt:: Type , ax
398+ ) where {N,A1,A2}
399+ bc_a = Broadcasted (A1, bc. f, arg1 .(bc. args), arg1 .(ax))
400+ bc_b = Broadcasted (A2, bc. f, arg2 .(bc. args), arg2 .(ax))
392401 a = similar (bc_a, elt)
393402 b = similar (bc_b, elt)
394403 return a ⊗ b
@@ -497,12 +506,12 @@ using Base.Broadcast: broadcasted
497506# Represents broadcast operations that can be applied Kronecker-wise,
498507# i.e. independently to each argument of the Kronecker product.
499508# Note that not all broadcast operations can be mapped to this.
500- struct KroneckerBroadcasted{A,B }
501- a :: A
502- b :: B
509+ struct KroneckerBroadcasted{A1,A2 }
510+ arg1 :: A1
511+ arg2 :: A2
503512end
504- arg1 (a:: KroneckerBroadcasted ) = a . a
505- arg2 (a:: KroneckerBroadcasted ) = a . b
513+ @inline arg1 (a:: KroneckerBroadcasted ) = getfield (a, :arg1 )
514+ @inline arg2 (a:: KroneckerBroadcasted ) = getfield (a, :arg2 )
506515⊗ (a:: Broadcasted , b:: Broadcasted ) = KroneckerBroadcasted (a, b)
507516⊗ (a:: Broadcasted , b) = KroneckerBroadcasted (a, b)
508517⊗ (a, b:: Broadcasted ) = KroneckerBroadcasted (a, b)
@@ -525,18 +534,20 @@ function Base.axes(a::KroneckerBroadcasted)
525534end
526535
527536function Base. BroadcastStyle (
528- :: Type{<:KroneckerBroadcasted{A,B }}
529- ) where {StyleA,StyleB,A <: Broadcasted{StyleA} ,B <: Broadcasted{StyleB } }
530- @assert ndims (A ) == ndims (B )
531- N = ndims (A )
532- return KroneckerStyle {N} (StyleA (), StyleB ())
537+ :: Type{<:KroneckerBroadcasted{A1,A2 }}
538+ ) where {StyleA1,StyleA2,A1 <: Broadcasted{StyleA1} ,A2 <: Broadcasted{StyleA2 } }
539+ @assert ndims (A1 ) == ndims (A2 )
540+ N = ndims (A1 )
541+ return KroneckerStyle {N} (StyleA1 (), StyleA2 ())
533542end
534543
535544# Operations that preserve the Kronecker structure.
536545for f in [:identity , :conj ]
537546 @eval begin
538- function Broadcast. broadcasted (:: KroneckerStyle{<:Any,A,B} , :: typeof ($ f), a) where {A,B}
539- return broadcasted (A, $ f, arg1 (a)) ⊗ broadcasted (B, $ f, arg2 (a))
547+ function Broadcast. broadcasted (
548+ :: KroneckerStyle{<:Any,A1,A2} , :: typeof ($ f), a
549+ ) where {A1,A2}
550+ return broadcasted (A1, $ f, arg1 (a)) ⊗ broadcasted (A2, $ f, arg2 (a))
540551 end
541552 end
542553end
0 commit comments