@@ -23,6 +23,10 @@ arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` sub
2323arg2type (x:: AbstractKroneckerArray ) = arg2type (typeof (x))
2424arg2type (:: Type{<:AbstractKroneckerArray} ) = error (" `AbstractKroneckerArray` subtypes have to implement `arg2type`." )
2525
26+ arguments (a:: AbstractKroneckerArray ) = (arg1 (a), arg2 (a))
27+ arguments (a:: AbstractKroneckerArray , n:: Int ) = arguments (a)[n]
28+ argument_types (a:: AbstractKroneckerArray ) = argument_types (typeof (a))
29+
2630function unwrap_array (a:: AbstractArray )
2731 p = parent (a)
2832 p ≡ a && return a
@@ -51,7 +55,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
5155end
5256
5357struct KroneckerArray{T, N, A1 <: AbstractArray{T, N} , A2 <: AbstractArray{T, N} } < :
54- AbstractKroneckerArray{T, N, A1, A2 }
58+ AbstractKroneckerArray{T, N}
5559 arg1:: A1
5660 arg2:: A2
5761end
@@ -76,6 +80,8 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro
7680arg1type (:: Type{KroneckerArray{T, N, A1, A2}} ) where {T, N, A1, A2} = A1
7781arg2type (:: Type{KroneckerArray{T, N, A1, A2}} ) where {T, N, A1, A2} = A2
7882
83+ argument_types (:: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} ) where {A1, A2} = (A1, A2)
84+
7985function mutate_active_args! (f!, f, dest, src)
8086 (isactive (arg1 (dest)) || isactive (arg2 (dest))) ||
8187 error (" Can't mutate immutable KroneckerArray." )
@@ -93,7 +99,7 @@ function mutate_active_args!(f!, f, dest, src)
9399end
94100
95101using Adapt: Adapt, adapt
96- function Adapt. adapt_structure (to, a:: KroneckerArray )
102+ function Adapt. adapt_structure (to, a:: AbstractKroneckerArray )
97103 # TODO : Is this a good definition? It is similar to
98104 # the definition of `similar`.
99105 return if isactive (arg1 (a)) == isactive (arg2 (a))
@@ -105,18 +111,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray)
105111 end
106112end
107113
108- function Base. copy (a:: KroneckerArray )
109- return copy (arg1 (a)) ⊗ copy (arg2 (a))
114+ Base. copy (a:: AbstractKroneckerArray ) = copy (arg1 (a)) ⊗ copy (arg2 (a))
115+ function Base. copy! (dest:: AbstractKroneckerArray , src:: AbstractKroneckerArray )
116+ return mutate_active_args! (copy!, copy, dest, src)
110117end
111118
119+ # TODO : copyto! is typically reserved for contiguous copies (i.e. also for copying from a
120+ # vector into an array), it might be better to not define that here.
112121function Base. copyto! (dest:: KroneckerArray{<:Any, N} , src:: KroneckerArray{<:Any, N} ) where {N}
113122 return mutate_active_args! (copyto!, copy, dest, src)
114123end
115124
116125function Base. convert (
117- :: Type{KroneckerArray{T, N, A1, A2}} , a:: KroneckerArray
118- ) where {T, N, A1, A2}
119- return _convert (A1, arg1 (a)) ⊗ _convert (A2, arg2 (a))
126+ :: Type{KroneckerArray{T, N, A1, A2}} , a:: AbstractKroneckerArray
127+ ):: KroneckerArray{T, N, A1, A2} where {T, N, A1, A2}
128+ typeof (a) === KroneckerArray{T, N, A1, A2} && return a
129+ return KroneckerArray (_convert (A1, arg1 (a)), _convert (A2, arg2 (a)))
120130end
121131
122132# Promote the element type if needed.
125135maybe_promot_eltype (a, elt) = eltype (a) <: elt ? a : elt .(a)
126136
127137function Base. similar (
128- a:: KroneckerArray ,
138+ a:: AbstractKroneckerArray ,
129139 elt:: Type ,
130140 axs:: Tuple {
131141 CartesianProductUnitRange{<: Integer }, Vararg{CartesianProductUnitRange{<: Integer }},
@@ -142,7 +152,7 @@ function Base.similar(
142152 maybe_promot_eltype (arg1 (a), elt) ⊗ similar (arg2 (a), elt, arg2 .(axs))
143153 end
144154end
145- function Base. similar (a:: KroneckerArray , elt:: Type )
155+ function Base. similar (a:: AbstractKroneckerArray , elt:: Type )
146156 # TODO : Is this a good definition?
147157 return if isactive (arg1 (a)) == isactive (arg2 (a))
148158 similar (arg1 (a), elt) ⊗ similar (arg2 (a), elt)
@@ -152,7 +162,7 @@ function Base.similar(a::KroneckerArray, elt::Type)
152162 maybe_promot_eltype (arg1 (a), elt) ⊗ similar (arg2 (a), elt)
153163 end
154164end
155- function Base. similar (a:: KroneckerArray )
165+ function Base. similar (a:: AbstractKroneckerArray )
156166 # TODO : Is this a good definition?
157167 return if isactive (arg1 (a)) == isactive (arg2 (a))
158168 similar (arg1 (a)) ⊗ similar (arg2 (a))
@@ -174,16 +184,18 @@ function Base.similar(
174184end
175185
176186function Base. similar (
177- arrayt :: Type{<:KroneckerArray{<:Any, <:Any, A1, A2} } ,
187+ :: Type{ArrayT } ,
178188 axs:: Tuple {
179189 CartesianProductUnitRange{<: Integer }, Vararg{CartesianProductUnitRange{<: Integer }},
180190 },
181- ) where {A1, A2}
191+ ) where {ArrayT <: AbstractKroneckerArray }
192+ A1, A2 = arg1type (ArrayT), arg2type (ArrayT)
182193 return similar (A1, map (arg1, axs)) ⊗ similar (A2, map (arg2, axs))
183194end
184195function Base. similar (
185- :: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} , sz:: Tuple{Int, Vararg{Int}}
186- ) where {A1, A2}
196+ :: Type{ArrayT} , sz:: Tuple{Int, Vararg{Int}}
197+ ) where {ArrayT <: AbstractKroneckerArray }
198+ A1, A2 = arg1type (ArrayT), arg2type (ArrayT)
187199 return similar (promote_type (A1, A2), sz)
188200end
189201
@@ -196,15 +208,15 @@ function Base.similar(
196208 return similar (arrayt, map (arg1, axs)) ⊗ similar (arrayt, map (arg2, axs))
197209end
198210
199- function Base. permutedims (a:: KroneckerArray , perm)
211+ function Base. permutedims (a:: AbstractKroneckerArray , perm)
200212 return permutedims (arg1 (a), perm) ⊗ permutedims (arg2 (a), perm)
201213end
202214using DerivableInterfaces: DerivableInterfaces, permuteddims
203- function DerivableInterfaces. permuteddims (a:: KroneckerArray , perm)
215+ function DerivableInterfaces. permuteddims (a:: AbstractKroneckerArray , perm)
204216 return permuteddims (arg1 (a), perm) ⊗ permuteddims (arg2 (a), perm)
205217end
206218
207- function Base. permutedims! (dest:: KroneckerArray , src:: KroneckerArray , perm)
219+ function Base. permutedims! (dest:: AbstractKroneckerArray , src:: AbstractKroneckerArray , perm)
208220 return mutate_active_args! (
209221 (dest, src) -> permutedims! (dest, src, perm), Base. Fix2 (permutedims, perm), dest, src
210222 )
@@ -235,9 +247,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2)
235247kron_nd (a1:: AbstractVector , a2:: AbstractVector ) = kron (a1, a2)
236248
237249# Eagerly collect arguments to make more general on GPU.
238- Base. collect (a:: KroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
250+ Base. collect (a:: AbstractKroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
251+ Base. collect (T:: Type , a:: AbstractKroneckerArray ) = kron_nd (collect (T, arg1 (a)), collect (T, arg2 (a)))
239252
240- function Base. zero (a:: KroneckerArray )
253+ function Base. zero (a:: AbstractKroneckerArray )
241254 return if isactive (arg1 (a)) == isactive (arg2 (a))
242255 # TODO : Maybe this should zero both arguments?
243256 # This is how `a * false` would behave.
@@ -250,35 +263,28 @@ function Base.zero(a::KroneckerArray)
250263end
251264
252265using DerivableInterfaces: DerivableInterfaces, zero!
253- function DerivableInterfaces. zero! (a:: KroneckerArray )
266+ function DerivableInterfaces. zero! (a:: AbstractKroneckerArray )
254267 (isactive (arg1 (a)) || isactive (arg2 (a))) ||
255268 error (" Can't mutate immutable KroneckerArray." )
256269 isactive (arg1 (a)) && zero! (arg1 (a))
257270 isactive (arg2 (a)) && zero! (arg2 (a))
258271 return a
259272end
260273
261- function Base. Array {T, N} (a:: KroneckerArray {S, N} ) where {T, S, N}
262- return convert (Array{T, N}, collect (a))
274+ function Base. Array {T, N} (a:: AbstractKroneckerArray {S, N} ) where {T, S, N}
275+ return convert (Array{T, N}, collect (T, a))
263276end
264277
265- function Base. size (a:: KroneckerArray )
266- return ntuple (dim -> size (arg1 (a), dim) * size (arg2 (a), dim), ndims (a))
267- end
278+ Base. size (a:: AbstractKroneckerArray ) = size (arg1 (a)) .* size (arg2 (a))
268279
269- function Base. axes (a:: KroneckerArray )
280+ function Base. axes (a:: AbstractKroneckerArray )
270281 return ntuple (ndims (a)) do dim
271282 return CartesianProductUnitRange (
272283 axes (arg1 (a), dim) × axes (arg2 (a), dim), Base. OneTo (size (a, dim))
273284 )
274285 end
275286end
276287
277- arguments (a:: KroneckerArray ) = (arg1 (a), arg2 (a))
278- arguments (a:: KroneckerArray , n:: Int ) = arguments (a)[n]
279- argument_types (a:: KroneckerArray ) = argument_types (typeof (a))
280- argument_types (:: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} ) where {A1, A2} = (A1, A2)
281-
282288function Base. print_array (io:: IO , a:: KroneckerArray )
283289 Base. print_array (io, arg1 (a))
284290 println (io, " \n ⊗" )
@@ -312,45 +318,48 @@ end
312318
313319# Indexing logic.
314320function Base. to_indices (
315- a:: KroneckerArray , inds, I:: Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
321+ a:: AbstractKroneckerArray , inds, I:: Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
316322 )
317323 I1 = to_indices (arg1 (a), arg1 .(inds), arg1 .(I))
318324 I2 = to_indices (arg2 (a), arg2 .(inds), arg2 .(I))
319325 return I1 .× I2
320326end
321327
322328function Base. getindex (
323- a:: KroneckerArray {<:Any, N} , I:: Vararg{Union{CartesianPair, CartesianProduct}, N}
329+ a:: AbstractKroneckerArray {<:Any, N} , I:: Vararg{Union{CartesianPair, CartesianProduct}, N}
324330 ) where {N}
325331 I′ = to_indices (a, I)
326332 return arg1 (a)[arg1 .(I′)... ] ⊗ arg2 (a)[arg2 .(I′)... ]
327333end
328334# Fix ambigiuity error.
329- Base. getindex (a:: KroneckerArray {<:Any, 0} ) = arg1 (a)[] * arg2 (a)[]
335+ Base. getindex (a:: AbstractKroneckerArray {<:Any, 0} ) = arg1 (a)[] * arg2 (a)[]
330336
331337arg1 (:: Colon ) = (:)
332338arg2 (:: Colon ) = (:)
333339arg1 (:: Base.Slice ) = (:)
334340arg2 (:: Base.Slice ) = (:)
335341function Base. view (
336- a:: KroneckerArray {<:Any, N} ,
342+ a:: AbstractKroneckerArray {<:Any, N} ,
337343 I:: Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N} ,
338344 ) where {N}
339345 return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
340346end
341- function Base. view (a:: KroneckerArray {<:Any, N} , I:: Vararg{CartesianPair, N} ) where {N}
347+ function Base. view (a:: AbstractKroneckerArray {<:Any, N} , I:: Vararg{CartesianPair, N} ) where {N}
342348 return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
343349end
344350# Fix ambigiuity error.
345- Base. view (a:: KroneckerArray {<:Any, 0} ) = view (arg1 (a)) ⊗ view (arg2 (a))
351+ Base. view (a:: AbstractKroneckerArray {<:Any, 0} ) = view (arg1 (a)) ⊗ view (arg2 (a))
346352
347- function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
353+ function Base.:(== )(a:: AbstractKroneckerArray , b:: AbstractKroneckerArray )
348354 return arg1 (a) == arg1 (b) && arg2 (a) == arg2 (b)
349355end
350- function Base. isapprox (a:: KroneckerArray , b:: KroneckerArray ; kwargs... )
356+
357+ # TODO : this definition doesn't fully retain the original meaning:
358+ # ‖a - b‖ < atol could be true even if the following check isn't
359+ function Base. isapprox (a:: AbstractKroneckerArray , b:: AbstractKroneckerArray ; kwargs... )
351360 return isapprox (arg1 (a), arg1 (b); kwargs... ) && isapprox (arg2 (a), arg2 (b); kwargs... )
352361end
353- function Base. iszero (a:: KroneckerArray )
362+ function Base. iszero (a:: AbstractKroneckerArray )
354363 return iszero (arg1 (a)) || iszero (arg2 (a))
355364end
356365function Base. isreal (a:: KroneckerArray )
@@ -362,17 +371,17 @@ function DiagonalArrays.diagonal(a::KroneckerArray)
362371 return diagonal (arg1 (a)) ⊗ diagonal (arg2 (a))
363372end
364373
365- Base. real (a:: KroneckerArray {<:Real} ) = a
366- function Base. real (a:: KroneckerArray )
374+ Base. real (a:: AbstractKroneckerArray {<:Real} ) = a
375+ function Base. real (a:: AbstractKroneckerArray )
367376 if iszero (imag (arg1 (a))) || iszero (imag (arg2 (a)))
368377 return real (arg1 (a)) ⊗ real (arg2 (a))
369378 elseif iszero (real (arg1 (a))) || iszero (real (arg2 (a)))
370379 return - (imag (arg1 (a)) ⊗ imag (arg2 (a)))
371380 end
372381 return real (arg1 (a)) ⊗ real (arg2 (a)) - imag (arg1 (a)) ⊗ imag (arg2 (a))
373382end
374- Base. imag (a:: KroneckerArray {<:Real} ) = zero (a)
375- function Base. imag (a:: KroneckerArray )
383+ Base. imag (a:: AbstractKroneckerArray {<:Real} ) = zero (a)
384+ function Base. imag (a:: AbstractKroneckerArray )
376385 if iszero (imag (arg1 (a))) || iszero (real (arg2 (a)))
377386 return real (arg1 (a)) ⊗ imag (arg2 (a))
378387 elseif iszero (real (arg1 (a))) || iszero (imag (arg2 (a)))
@@ -383,14 +392,14 @@ end
383392
384393for f in [:transpose , :adjoint , :inv ]
385394 @eval begin
386- function Base. $f (a:: KroneckerArray )
395+ function Base. $f (a:: AbstractKroneckerArray )
387396 return $ f (arg1 (a)) ⊗ $ f (arg2 (a))
388397 end
389398 end
390399end
391400
392401function Base. reshape (
393- a:: KroneckerArray , ax:: Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
402+ a:: AbstractKroneckerArray , ax:: Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
394403 )
395404 return reshape (arg1 (a), map (arg1, ax)) ⊗ reshape (arg2 (a), map (arg2, ax))
396405end
410419function KroneckerStyle {N, A1, A2} (v:: Val{M} ) where {N, A1, A2, M}
411420 return KroneckerStyle {M, typeof(A1)(v), typeof(A2)(v)} ()
412421end
413- function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any, N, A1, A2}} ) where {N, A1, A2 }
414- return KroneckerStyle {N } (BroadcastStyle (A1) , BroadcastStyle (A2 ))
422+ function Base. BroadcastStyle (:: Type{T} ) where {T <: AbstractKroneckerArray }
423+ return KroneckerStyle {ndims(T) } (BroadcastStyle (arg1type (T)) , BroadcastStyle (arg2type (T) ))
415424end
416425function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
417426 style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
@@ -430,10 +439,10 @@ function Base.similar(
430439 return a ⊗ b
431440end
432441
433- function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray ... )
442+ function Base. map (f, a1:: AbstractKroneckerArray , a_rest:: AbstractKroneckerArray ... )
434443 return Broadcast. broadcast_preserving_zero_d (f, a1, a_rest... )
435444end
436- function Base. map! (f, dest:: KroneckerArray , a1:: KroneckerArray , a_rest:: KroneckerArray ... )
445+ function Base. map! (f, dest:: AbstractKroneckerArray , a1:: AbstractKroneckerArray , a_rest:: AbstractKroneckerArray ... )
437446 dest .= f .(a1, a_rest... )
438447 return dest
439448end
465474function Base. copy (a:: Summed{<:KroneckerStyle} )
466475 return copy (KroneckerBroadcast (a))
467476end
468- function Base. copyto! (dest:: KroneckerArray , a:: Summed{<:KroneckerStyle} )
477+ function Base. copyto! (dest:: AbstractKroneckerArray , a:: Summed{<:KroneckerStyle} )
469478 return copyto! (dest, KroneckerBroadcast (a))
470479end
471480
0 commit comments