250250function Base. iszero (a:: KroneckerArray )
251251 return iszero (a. a) || iszero (a. b)
252252end
253+ function Base. isreal (a:: KroneckerArray )
254+ return isreal (a. a) && isreal (a. b)
255+ end
253256function Base. inv (a:: KroneckerArray )
254257 return inv (a. a) ⊗ inv (a. b)
255258end
270273function Base.:* (a:: KroneckerArray , b:: Number )
271274 return a. a ⊗ (a. b * b)
272275end
276+ function Base.:/ (a:: KroneckerArray , b:: Number )
277+ return a * inv (b)
278+ end
273279
274280function Base.:- (a:: KroneckerArray )
275281 return (- a. a) ⊗ a. b
@@ -291,26 +297,79 @@ for op in (:+, :-)
291297 end
292298end
293299
300+ using Base. Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
301+ struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302+ function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
303+ return KroneckerStyle {N,a,b} ()
304+ end
305+ function KroneckerStyle {N,A,B} (v:: Val{M} ) where {N,A,B,M}
306+ return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
307+ end
308+ function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
309+ return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
310+ end
311+ function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
312+ return KroneckerStyle {N} (
313+ BroadcastStyle (style1. a, style2. a), BroadcastStyle (style1. b, style2. b)
314+ )
315+ end
316+ function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type ) where {N,A,B}
317+ ax_a = map (ax -> ax. product. a, axes (bc))
318+ ax_b = map (ax -> ax. product. b, axes (bc))
319+ bc_a = Broadcasted (A, ax_a)
320+ bc_b = Broadcasted (B, ax_b)
321+ a = similar (bc_a, elt)
322+ b = similar (bc_b, elt)
323+ return a ⊗ b
324+ end
325+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:KroneckerStyle} )
326+ return throw (
327+ ArgumentError (
328+ " Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
329+ ),
330+ )
331+ end
332+
333+ function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray... )
334+ return throw (
335+ ArgumentError (
336+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
337+ ),
338+ )
339+ end
340+ function Base. map! (f, dest:: KroneckerArray , a1:: KroneckerArray , a_rest:: KroneckerArray... )
341+ return throw (
342+ ArgumentError (
343+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
344+ ),
345+ )
346+ end
294347function Base. map! (:: typeof (identity), dest:: KroneckerArray , a:: KroneckerArray )
295348 dest. a .= a. a
296349 dest. b .= a. b
297350 return dest
298351end
299- function Base. map! (:: typeof (+ ), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray )
300- if a. b == b. b
301- map! (+ , dest. a, a. a, b. a)
302- dest. b .= a. b
303- elseif a. a == b. a
304- dest. a .= a. a
305- map! (+ , dest. b, a. b, b. b)
306- else
307- throw (
308- ArgumentError (
309- " KroneckerArray addition is only supported when the first or second arguments match." ,
310- ),
352+ for f in [:+ , :- ]
353+ @eval begin
354+ function Base. map! (
355+ :: typeof ($ f), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray
311356 )
357+ if a. b == b. b
358+ map! ($ f, dest. a, a. a, b. a)
359+ dest. b .= a. b
360+ elseif a. a == b. a
361+ dest. a .= a. a
362+ map! ($ f, dest. b, a. b, b. b)
363+ else
364+ throw (
365+ ArgumentError (
366+ " KroneckerArray addition is only supported when the first or second arguments match." ,
367+ ),
368+ )
369+ end
370+ return dest
371+ end
312372 end
313- return dest
314373end
315374function Base. map! (
316375 f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
@@ -326,6 +385,11 @@ function Base.map!(
326385 dest. b .= f. f .(a. b, f. x)
327386 return dest
328387end
388+ function Base. map! (
389+ f:: Base.Fix2{typeof(/),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
390+ )
391+ return map! (Base. Fix2 (* , inv (f. x)), dest, a)
392+ end
329393
330394using LinearAlgebra:
331395 LinearAlgebra,
@@ -343,9 +407,11 @@ using LinearAlgebra:
343407 svd,
344408 svdvals,
345409 tr
346- diagonal (a:: AbstractArray ) = Diagonal (a)
347- function diagonal (a:: KroneckerArray )
348- return Diagonal (a. a) ⊗ Diagonal (a. b)
410+
411+ using DiagonalArrays: DiagonalArrays, diagonal
412+ DiagonalArrays. diagonal (a:: AbstractArray ) = Diagonal (a)
413+ function DiagonalArrays. diagonal (a:: KroneckerArray )
414+ return diagonal (a. a) ⊗ diagonal (a. b)
349415end
350416
351417function Base.:* (a:: KroneckerArray , b:: KroneckerArray )
@@ -506,6 +572,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
506572const KroneckerEye{T,A<: AbstractMatrix{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
507573const EyeEye{T,A<: Eye{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
508574
575+ using DerivableInterfaces: DerivableInterfaces, zero!
576+ function DerivableInterfaces. zero! (a:: EyeKronecker )
577+ zero! (a. b)
578+ return a
579+ end
580+ function DerivableInterfaces. zero! (a:: KroneckerEye )
581+ zero! (a. a)
582+ return a
583+ end
584+ function DerivableInterfaces. zero! (a:: EyeEye )
585+ return throw (ArgumentError (" Can't zero out `Eye ⊗ Eye`." ))
586+ end
587+
509588function Base.:* (a:: Number , b:: EyeKronecker )
510589 return b. a ⊗ (a * b. b)
511590end
@@ -580,29 +659,44 @@ end
580659function Base. map! (:: typeof (identity), dest:: EyeEye , a:: EyeEye )
581660 return error (" Can't write in-place." )
582661end
583- function Base. map! (f:: typeof (+ ), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
584- if dest. a ≠ a. a ≠ b. a
585- throw (
586- ArgumentError (
587- " KroneckerArray addition is only supported when the first or second arguments match." ,
588- ),
589- )
662+ for f in [:+ , :- ]
663+ @eval begin
664+ function Base. map! (:: typeof ($ f), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
665+ if dest. a ≠ a. a ≠ b. a
666+ throw (
667+ ArgumentError (
668+ " KroneckerArray addition is only supported when the first or second arguments match." ,
669+ ),
670+ )
671+ end
672+ map! ($ f, dest. b, a. b, b. b)
673+ return dest
674+ end
675+ function Base. map! (:: typeof ($ f), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
676+ if dest. b ≠ a. b ≠ b. b
677+ throw (
678+ ArgumentError (
679+ " KroneckerArray addition is only supported when the first or second arguments match." ,
680+ ),
681+ )
682+ end
683+ map! ($ f, dest. a, a. a, b. a)
684+ return dest
685+ end
686+ function Base. map! (:: typeof ($ f), dest:: EyeEye , a:: EyeEye , b:: EyeEye )
687+ return error (" Can't write in-place." )
688+ end
590689 end
591- map! (f, dest. b, a. b, b. b)
690+ end
691+ function Base. map! (f:: typeof (- ), dest:: EyeKronecker , a:: EyeKronecker )
692+ map! (f, dest. b, a. b)
592693 return dest
593694end
594- function Base. map! (f:: typeof (+ ), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
595- if dest. b ≠ a. b ≠ b. b
596- throw (
597- ArgumentError (
598- " KroneckerArray addition is only supported when the first or second arguments match." ,
599- ),
600- )
601- end
695+ function Base. map! (f:: typeof (- ), dest:: KroneckerEye , a:: KroneckerEye )
602696 map! (f, dest. a, a. a, b. a)
603697 return dest
604698end
605- function Base. map! (f:: typeof (+ ), dest:: EyeEye , a:: EyeEye , b :: EyeEye )
699+ function Base. map! (f:: typeof (- ), dest:: EyeEye , a:: EyeEye )
606700 return error (" Can't write in-place." )
607701end
608702function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
@@ -812,6 +906,39 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
812906const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
813907const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
814908
909+ # Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
910+ function Base. similar (
911+ arrayt:: Type{<:SquareEyeKronecker{<:Any,<:Any,A}} ,
912+ elt:: Type ,
913+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
914+ ) where {A}
915+ ax_a = map (ax -> ax. product. a, axs)
916+ ax_b = map (ax -> ax. product. b, axs)
917+ eye_ax_a = (only (unique (ax_a)),)
918+ return Eye {elt} (eye_ax_a) ⊗ similar (A, elt, ax_b)
919+ end
920+ function Base. similar (
921+ arrayt:: Type{<:KroneckerSquareEye{<:Any,A}} ,
922+ elt:: Type ,
923+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
924+ ) where {A}
925+ ax_a = map (ax -> ax. product. a, axs)
926+ ax_b = map (ax -> ax. product. b, axs)
927+ eye_ax_b = (only (unique (ax_b)),)
928+ return similar (A, elt, ax_a) ⊗ Eye {elt} (eye_ax_b)
929+ end
930+ function Base. similar (
931+ arrayt:: Type{<:SquareEyeSquareEye} ,
932+ elt:: Type ,
933+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
934+ )
935+ ax_a = map (ax -> ax. product. a, axs)
936+ ax_b = map (ax -> ax. product. b, axs)
937+ eye_ax_a = (only (unique (ax_a)),)
938+ eye_ax_b = (only (unique (ax_b)),)
939+ return Eye {elt} (eye_ax_a) ⊗ Eye {elt} (eye_ax_b)
940+ end
941+
815942struct SquareEyeAlgorithm{KWargs<: NamedTuple } <: AbstractAlgorithm
816943 kwargs:: KWargs
817944end
0 commit comments