1- # Allows customizing for `FillArrays.Eye`.
2- function _convert (A:: Type{<:AbstractArray} , a:: AbstractArray )
3- return convert (A, a)
1+ function unwrap_array (a:: AbstractArray )
2+ p = parent (a)
3+ p ≡ a && return a
4+ return unwrap_array (p)
45end
6+ isactive (a:: AbstractArray ) = ismutable (unwrap_array (a))
7+
58# Custom `_convert` works around the issue that
69# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isn't defined
710# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
811# https://github.com/JuliaLang/julia/pull/52487).
912# TODO : Delete once we drop support for Julia v1.10.
13+ function _convert (A:: Type{<:AbstractArray} , a:: AbstractArray )
14+ return convert (A, a)
15+ end
1016using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
1117_construct (A:: Type{<:Diagonal} , a:: AbstractMatrix ) = A (diag (a))
1218function _convert (A:: Type{<:Diagonal} , a:: AbstractMatrix )
@@ -33,38 +39,35 @@ const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerAr
3339arg1 (a:: KroneckerArray ) = a. a
3440arg2 (a:: KroneckerArray ) = a. b
3541
42+ function mutate_active_args! (f!, f, dest, src)
43+ (isactive (arg1 (dest)) || isactive (arg2 (dest))) ||
44+ error (" Can't mutate immutable KroneckerArray." )
45+ if isactive (arg1 (dest))
46+ f! (arg1 (dest), arg1 (src))
47+ else
48+ arg1 (dest) == f (arg1 (src)) || error (" Immutable arguments aren't equal." )
49+ end
50+ if isactive (arg2 (dest))
51+ f! (arg2 (dest), arg2 (src))
52+ else
53+ arg2 (dest) == f (arg2 (src)) || error (" Immutable arguments aren't equal." )
54+ end
55+ return dest
56+ end
57+
3658using Adapt: Adapt, adapt
3759Adapt. adapt_structure (to, a:: KroneckerArray ) = adapt (to, arg1 (a)) ⊗ adapt (to, arg2 (a))
3860
3961function Base. copy (a:: KroneckerArray )
4062 return copy (arg1 (a)) ⊗ copy (arg2 (a))
4163end
4264
43- # Allows extra customization, like for `FillArrays.Eye`.
44- function _copyto!! (dest:: AbstractArray{<:Any,N} , src:: AbstractArray{<:Any,N} ) where {N}
45- copyto! (dest, src)
46- return dest
47- end
48- using Base. Broadcast: Broadcasted
49- function _copyto!! (dest:: AbstractArray , src:: Broadcasted )
50- copyto! (dest, src)
51- return dest
52- end
53-
5465function Base. copyto! (dest:: KroneckerArray{<:Any,N} , src:: KroneckerArray{<:Any,N} ) where {N}
55- return copyto!_kronecker (dest, src)
56- end
57- function copyto!_kronecker (
58- dest:: KroneckerArray{<:Any,N} , src:: KroneckerArray{<:Any,N}
59- ) where {N}
60- # TODO : Check if neither argument is mutated and if so error.
61- _copyto!! (arg1 (dest), arg1 (src))
62- _copyto!! (arg2 (dest), arg2 (src))
63- return dest
66+ return mutate_active_args! (copyto!, copy, dest, src)
6467end
6568
6669function Base. convert (:: Type{KroneckerArray{T,N,A,B}} , a:: KroneckerArray ) where {T,N,A,B}
67- return KroneckerArray ( _convert (A, arg1 (a)), _convert (B, arg2 (a) ))
70+ return _convert (A, arg1 (a)) ⊗ _convert (B, arg2 (a))
6871end
6972
7073# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
@@ -124,21 +127,18 @@ function Base.similar(
124127 return similar (promote_type (A, B), sz)
125128end
126129
127- function _permutedims!! (dest:: AbstractArray , src:: AbstractArray , perm)
128- permutedims! (dest, src, perm)
129- return dest
130+ function Base. permutedims (a:: KroneckerArray , perm)
131+ return permutedims (arg1 (a), perm) ⊗ permutedims (arg2 (a), perm)
130132end
131-
132133using DerivableInterfaces: DerivableInterfaces, permuteddims
133134function DerivableInterfaces. permuteddims (a:: KroneckerArray , perm)
134135 return permuteddims (arg1 (a), perm) ⊗ permuteddims (arg2 (a), perm)
135136end
136137
137138function Base. permutedims! (dest:: KroneckerArray , src:: KroneckerArray , perm)
138- # TODO : Error if neither argument is mutable.
139- _permutedims!! (arg1 (dest), arg1 (src), perm)
140- _permutedims!! (arg2 (dest), arg2 (src), perm)
141- return dest
139+ return mutate_active_args! (
140+ (dest, src) -> permutedims! (dest, src, perm), Base. Fix2 (permutedims, perm), dest, src
141+ )
142142end
143143
144144function flatten (t:: Tuple{Tuple,Tuple,Vararg{Tuple}} )
@@ -172,11 +172,10 @@ Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a))
172172
173173using DerivableInterfaces: DerivableInterfaces, zero!
174174function DerivableInterfaces. zero! (a:: KroneckerArray )
175- ismut1 = ismutable (arg1 (a))
176- ismut2 = ismutable (arg2 (a))
177- (ismut1 || ismut2) || throw (ArgumentError (" Can't zero out immutable KroneckerArray." ))
178- ismut1 && zero! (arg1 (a))
179- ismut2 && zero! (arg2 (a))
175+ (isactive (arg1 (a)) || isactive (arg2 (a))) ||
176+ error (" Can't mutate immutable KroneckerArray." )
177+ isactive (arg1 (a)) && zero! (arg1 (a))
178+ isactive (arg2 (a)) && zero! (arg2 (a))
180179 return a
181180end
182181
@@ -293,7 +292,7 @@ function Base.real(a::KroneckerArray)
293292 if iszero (imag (arg1 (a))) || iszero (imag (arg2 (a)))
294293 return real (arg1 (a)) ⊗ real (arg2 (a))
295294 elseif iszero (real (arg1 (a))) || iszero (real (arg2 (a)))
296- return - imag (arg1 (a)) ⊗ imag (arg2 (a))
295+ return - ( imag (arg1 (a)) ⊗ imag (arg2 (a) ))
297296 end
298297 return real (arg1 (a)) ⊗ real (arg2 (a)) - imag (arg1 (a)) ⊗ imag (arg2 (a))
299298end
@@ -321,9 +320,6 @@ function Base.reshape(
321320 return reshape (arg1 (a), map (arg1, ax)) ⊗ reshape (arg2 (a), map (arg2, ax))
322321end
323322
324- # Allows for customizations for FillArrays.
325- _BroadcastStyle (x) = BroadcastStyle (x)
326-
327323using Base. Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
328324struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
329325arg1 (:: Type{<:KroneckerStyle{<:Any,A}} ) where {A} = A
@@ -340,7 +336,7 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
340336 return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
341337end
342338function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
343- return KroneckerStyle {N} (_BroadcastStyle (A), _BroadcastStyle (B))
339+ return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
344340end
345341function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
346342 style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
@@ -366,23 +362,34 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke
366362end
367363
368364using MapBroadcast: MapBroadcast, LinearCombination, Summed
369- function Base. copyto! (dest:: KroneckerArray , a:: Summed{<:KroneckerStyle} )
370- dest1 = arg1 (dest)
371- dest2 = arg2 (dest)
365+ function KroneckerBroadcast (a:: Summed{<:KroneckerStyle} )
372366 f = LinearCombination (a)
373367 args = MapBroadcast. arguments (a)
374368 arg1s = arg1 .(args)
375369 arg2s = arg2 .(args)
376- if allequal (arg2s)
377- copyto! (dest2, first (arg2s))
378- dest1 .= f .(arg1s... )
379- elseif allequal (arg1s)
380- copyto! (dest1, first (arg1s))
381- dest2 .= f .(arg2s... )
382- else
370+ arg1_isunique = allequal (arg1s)
371+ arg2_isunique = allequal (arg2s)
372+ (arg1_isunique || arg2_isunique) ||
383373 error (" This operation doesn't preserve the Kronecker structure." )
374+ broadcast_arg = if arg1_isunique && arg2_isunique
375+ isactive (first (arg1s)) ? 1 : 2
376+ elseif arg1_isunique
377+ 2
378+ elseif arg2_isunique
379+ 1
384380 end
385- return dest
381+ return if broadcast_arg == 1
382+ broadcasted (f, arg1s... ) ⊗ first (arg2s)
383+ elseif broadcast_arg == 2
384+ first (arg1s) ⊗ broadcasted (f, arg2s... )
385+ end
386+ end
387+
388+ function Base. copy (a:: Summed{<:KroneckerStyle} )
389+ return copy (KroneckerBroadcast (a))
390+ end
391+ function Base. copyto! (dest:: KroneckerArray , a:: Summed{<:KroneckerStyle} )
392+ return copyto! (dest, KroneckerBroadcast (a))
386393end
387394
388395function Broadcast. broadcasted (:: KroneckerStyle , f, as... )
@@ -424,11 +431,31 @@ function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/),<:N
424431 return broadcasted (style, / , a, f. x)
425432end
426433
434+ # Compatibility with MapBroadcast.jl.
435+ using MapBroadcast: MapBroadcast, MapFunction
436+ function Base. broadcasted (
437+ style:: KroneckerStyle , f:: MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}} , a
438+ )
439+ return broadcasted (style, * , f. args[1 ], a)
440+ end
441+ function Base. broadcasted (
442+ style:: KroneckerStyle , f:: MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}} , a
443+ )
444+ return broadcasted (style, * , a, f. args[2 ])
445+ end
446+ function Base. broadcasted (
447+ style:: KroneckerStyle , f:: MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}} , a
448+ )
449+ return broadcasted (style, / , a, f. args[2 ])
450+ end
427451# Use to determine the element type of KroneckerBroadcasted.
428452_eltype (x) = eltype (x)
429453_eltype (x:: Broadcasted ) = Base. promote_op (x. f, _eltype .(x. args)... )
430454
431455using Base. Broadcast: broadcasted
456+ # Represents broadcast operations that can be applied Kronecker-wise,
457+ # i.e. independently to each argument of the Kronecker product.
458+ # Note that not all broadcast operations can be mapped to this.
432459struct KroneckerBroadcasted{A,B}
433460 a:: A
434461 b:: B
@@ -442,10 +469,8 @@ Broadcast.materialize(a::KroneckerBroadcasted) = copy(a)
442469Broadcast. materialize! (dest, a:: KroneckerBroadcasted ) = copyto! (dest, a)
443470Broadcast. broadcastable (a:: KroneckerBroadcasted ) = a
444471Base. copy (a:: KroneckerBroadcasted ) = copy (arg1 (a)) ⊗ copy (arg2 (a))
445- function Base. copyto! (dest:: KroneckerArray , a:: KroneckerBroadcasted )
446- _copyto!! (arg1 (dest), arg1 (a))
447- _copyto!! (arg2 (dest), arg2 (a))
448- return dest
472+ function Base. copyto! (dest:: KroneckerArray , src:: KroneckerBroadcasted )
473+ return mutate_active_args! (copyto!, copy, dest, src)
449474end
450475function Base. eltype (a:: KroneckerBroadcasted )
451476 a1 = arg1 (a)
@@ -474,21 +499,3 @@ for f in [:identity, :conj]
474499 end
475500 end
476501end
477-
478- # Compatibility with MapBroadcast.jl.
479- using MapBroadcast: MapBroadcast, MapFunction
480- function Base. broadcasted (
481- style:: KroneckerStyle , f:: MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}} , a
482- )
483- return broadcasted (style, * , f. args[1 ], a)
484- end
485- function Base. broadcasted (
486- style:: KroneckerStyle , f:: MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}} , a
487- )
488- return broadcasted (style, * , a, f. args[2 ])
489- end
490- function Base. broadcasted (
491- style:: KroneckerStyle , f:: MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}} , a
492- )
493- return broadcasted (style, / , a, f. args[2 ])
494- end
0 commit comments