Skip to content

Commit 9add5ec

Browse files
committed
Cleanup
1 parent a6f082c commit 9add5ec

File tree

3 files changed

+91
-180
lines changed

3 files changed

+91
-180
lines changed

src/kroneckerarray.jl

Lines changed: 82 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
# Allows customizing for `FillArrays.Eye`.
2-
function _convert(A::Type{<:AbstractArray}, a::AbstractArray)
3-
return convert(A, a)
1+
function unwrap_array(a::AbstractArray)
2+
p = parent(a)
3+
p a && return a
4+
return unwrap_array(p)
45
end
6+
isactive(a::AbstractArray) = ismutable(unwrap_array(a))
7+
58
# Custom `_convert` works around the issue that
69
# `convert(::Type{<:Diagonal}, ::AbstractMatrix)` isn't defined
710
# in Julia v1.10 (https://github.com/JuliaLang/julia/pull/48895,
811
# https://github.com/JuliaLang/julia/pull/52487).
912
# TODO: Delete once we drop support for Julia v1.10.
13+
function _convert(A::Type{<:AbstractArray}, a::AbstractArray)
14+
return convert(A, a)
15+
end
1016
using LinearAlgebra: LinearAlgebra, Diagonal, diag, isdiag
1117
_construct(A::Type{<:Diagonal}, a::AbstractMatrix) = A(diag(a))
1218
function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
@@ -33,38 +39,35 @@ const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerAr
3339
arg1(a::KroneckerArray) = a.a
3440
arg2(a::KroneckerArray) = a.b
3541

42+
function mutate_active_args!(f!, f, dest, src)
43+
(isactive(arg1(dest)) || isactive(arg2(dest))) ||
44+
error("Can't mutate immutable KroneckerArray.")
45+
if isactive(arg1(dest))
46+
f!(arg1(dest), arg1(src))
47+
else
48+
arg1(dest) == f(arg1(src)) || error("Immutable arguments aren't equal.")
49+
end
50+
if isactive(arg2(dest))
51+
f!(arg2(dest), arg2(src))
52+
else
53+
arg2(dest) == f(arg2(src)) || error("Immutable arguments aren't equal.")
54+
end
55+
return dest
56+
end
57+
3658
using Adapt: Adapt, adapt
3759
Adapt.adapt_structure(to, a::KroneckerArray) = adapt(to, arg1(a)) adapt(to, arg2(a))
3860

3961
function Base.copy(a::KroneckerArray)
4062
return copy(arg1(a)) copy(arg2(a))
4163
end
4264

43-
# Allows extra customization, like for `FillArrays.Eye`.
44-
function _copyto!!(dest::AbstractArray{<:Any,N}, src::AbstractArray{<:Any,N}) where {N}
45-
copyto!(dest, src)
46-
return dest
47-
end
48-
using Base.Broadcast: Broadcasted
49-
function _copyto!!(dest::AbstractArray, src::Broadcasted)
50-
copyto!(dest, src)
51-
return dest
52-
end
53-
5465
function Base.copyto!(dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}) where {N}
55-
return copyto!_kronecker(dest, src)
56-
end
57-
function copyto!_kronecker(
58-
dest::KroneckerArray{<:Any,N}, src::KroneckerArray{<:Any,N}
59-
) where {N}
60-
# TODO: Check if neither argument is mutated and if so error.
61-
_copyto!!(arg1(dest), arg1(src))
62-
_copyto!!(arg2(dest), arg2(src))
63-
return dest
66+
return mutate_active_args!(copyto!, copy, dest, src)
6467
end
6568

6669
function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B}
67-
return KroneckerArray(_convert(A, arg1(a)), _convert(B, arg2(a)))
70+
return _convert(A, arg1(a)) _convert(B, arg2(a))
6871
end
6972

7073
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
@@ -124,21 +127,18 @@ function Base.similar(
124127
return similar(promote_type(A, B), sz)
125128
end
126129

127-
function _permutedims!!(dest::AbstractArray, src::AbstractArray, perm)
128-
permutedims!(dest, src, perm)
129-
return dest
130+
function Base.permutedims(a::KroneckerArray, perm)
131+
return permutedims(arg1(a), perm) permutedims(arg2(a), perm)
130132
end
131-
132133
using DerivableInterfaces: DerivableInterfaces, permuteddims
133134
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
134135
return permuteddims(arg1(a), perm) permuteddims(arg2(a), perm)
135136
end
136137

137138
function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
138-
# TODO: Error if neither argument is mutable.
139-
_permutedims!!(arg1(dest), arg1(src), perm)
140-
_permutedims!!(arg2(dest), arg2(src), perm)
141-
return dest
139+
return mutate_active_args!(
140+
(dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src
141+
)
142142
end
143143

144144
function flatten(t::Tuple{Tuple,Tuple,Vararg{Tuple}})
@@ -172,11 +172,10 @@ Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a))
172172

173173
using DerivableInterfaces: DerivableInterfaces, zero!
174174
function DerivableInterfaces.zero!(a::KroneckerArray)
175-
ismut1 = ismutable(arg1(a))
176-
ismut2 = ismutable(arg2(a))
177-
(ismut1 || ismut2) || throw(ArgumentError("Can't zero out immutable KroneckerArray."))
178-
ismut1 && zero!(arg1(a))
179-
ismut2 && zero!(arg2(a))
175+
(isactive(arg1(a)) || isactive(arg2(a))) ||
176+
error("Can't mutate immutable KroneckerArray.")
177+
isactive(arg1(a)) && zero!(arg1(a))
178+
isactive(arg2(a)) && zero!(arg2(a))
180179
return a
181180
end
182181

@@ -293,7 +292,7 @@ function Base.real(a::KroneckerArray)
293292
if iszero(imag(arg1(a))) || iszero(imag(arg2(a)))
294293
return real(arg1(a)) real(arg2(a))
295294
elseif iszero(real(arg1(a))) || iszero(real(arg2(a)))
296-
return -imag(arg1(a)) imag(arg2(a))
295+
return -(imag(arg1(a)) imag(arg2(a)))
297296
end
298297
return real(arg1(a)) real(arg2(a)) - imag(arg1(a)) imag(arg2(a))
299298
end
@@ -321,9 +320,6 @@ function Base.reshape(
321320
return reshape(arg1(a), map(arg1, ax)) reshape(arg2(a), map(arg2, ax))
322321
end
323322

324-
# Allows for customizations for FillArrays.
325-
_BroadcastStyle(x) = BroadcastStyle(x)
326-
327323
using Base.Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
328324
struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
329325
arg1(::Type{<:KroneckerStyle{<:Any,A}}) where {A} = A
@@ -340,7 +336,7 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
340336
return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}()
341337
end
342338
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B}
343-
return KroneckerStyle{N}(_BroadcastStyle(A), _BroadcastStyle(B))
339+
return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B))
344340
end
345341
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
346342
style_a = BroadcastStyle(arg1(style1), arg1(style2))
@@ -366,23 +362,34 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke
366362
end
367363

368364
using MapBroadcast: MapBroadcast, LinearCombination, Summed
369-
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
370-
dest1 = arg1(dest)
371-
dest2 = arg2(dest)
365+
function KroneckerBroadcast(a::Summed{<:KroneckerStyle})
372366
f = LinearCombination(a)
373367
args = MapBroadcast.arguments(a)
374368
arg1s = arg1.(args)
375369
arg2s = arg2.(args)
376-
if allequal(arg2s)
377-
copyto!(dest2, first(arg2s))
378-
dest1 .= f.(arg1s...)
379-
elseif allequal(arg1s)
380-
copyto!(dest1, first(arg1s))
381-
dest2 .= f.(arg2s...)
382-
else
370+
arg1_isunique = allequal(arg1s)
371+
arg2_isunique = allequal(arg2s)
372+
(arg1_isunique || arg2_isunique) ||
383373
error("This operation doesn't preserve the Kronecker structure.")
374+
broadcast_arg = if arg1_isunique && arg2_isunique
375+
isactive(first(arg1s)) ? 1 : 2
376+
elseif arg1_isunique
377+
2
378+
elseif arg2_isunique
379+
1
384380
end
385-
return dest
381+
return if broadcast_arg == 1
382+
broadcasted(f, arg1s...) first(arg2s)
383+
elseif broadcast_arg == 2
384+
first(arg1s) broadcasted(f, arg2s...)
385+
end
386+
end
387+
388+
function Base.copy(a::Summed{<:KroneckerStyle})
389+
return copy(KroneckerBroadcast(a))
390+
end
391+
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
392+
return copyto!(dest, KroneckerBroadcast(a))
386393
end
387394

388395
function Broadcast.broadcasted(::KroneckerStyle, f, as...)
@@ -424,11 +431,31 @@ function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/),<:N
424431
return broadcasted(style, /, a, f.x)
425432
end
426433

434+
# Compatibility with MapBroadcast.jl.
435+
using MapBroadcast: MapBroadcast, MapFunction
436+
function Base.broadcasted(
437+
style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, a
438+
)
439+
return broadcasted(style, *, f.args[1], a)
440+
end
441+
function Base.broadcasted(
442+
style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, a
443+
)
444+
return broadcasted(style, *, a, f.args[2])
445+
end
446+
function Base.broadcasted(
447+
style::KroneckerStyle, f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, a
448+
)
449+
return broadcasted(style, /, a, f.args[2])
450+
end
427451
# Use to determine the element type of KroneckerBroadcasted.
428452
_eltype(x) = eltype(x)
429453
_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...)
430454

431455
using Base.Broadcast: broadcasted
456+
# Represents broadcast operations that can be applied Kronecker-wise,
457+
# i.e. independently to each argument of the Kronecker product.
458+
# Note that not all broadcast operations can be mapped to this.
432459
struct KroneckerBroadcasted{A,B}
433460
a::A
434461
b::B
@@ -442,10 +469,8 @@ Broadcast.materialize(a::KroneckerBroadcasted) = copy(a)
442469
Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
443470
Broadcast.broadcastable(a::KroneckerBroadcasted) = a
444471
Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) copy(arg2(a))
445-
function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted)
446-
_copyto!!(arg1(dest), arg1(a))
447-
_copyto!!(arg2(dest), arg2(a))
448-
return dest
472+
function Base.copyto!(dest::KroneckerArray, src::KroneckerBroadcasted)
473+
return mutate_active_args!(copyto!, copy, dest, src)
449474
end
450475
function Base.eltype(a::KroneckerBroadcasted)
451476
a1 = arg1(a)
@@ -474,21 +499,3 @@ for f in [:identity, :conj]
474499
end
475500
end
476501
end
477-
478-
# Compatibility with MapBroadcast.jl.
479-
using MapBroadcast: MapBroadcast, MapFunction
480-
function Base.broadcasted(
481-
style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, a
482-
)
483-
return broadcasted(style, *, f.args[1], a)
484-
end
485-
function Base.broadcasted(
486-
style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, a
487-
)
488-
return broadcasted(style, *, a, f.args[2])
489-
end
490-
function Base.broadcasted(
491-
style::KroneckerStyle, f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, a
492-
)
493-
return broadcasted(style, /, a, f.args[2])
494-
end

src/linearcombination.jl

Lines changed: 0 additions & 92 deletions
This file was deleted.

0 commit comments

Comments
 (0)