From 905b72fc649e0f20104daef6dcad4ebac53689e5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 14:36:24 -0400 Subject: [PATCH 1/6] Simplify and generalize map and broadcasting --- src/KroneckerArrays.jl | 1 + src/fillarrays/kroneckerarray.jl | 304 +++++++++++------------ src/kroneckerarray.jl | 410 +++++++++++++++---------------- test/Project.toml | 2 + test/test_basics.jl | 8 +- 5 files changed, 352 insertions(+), 373 deletions(-) diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 4552a2f..28bdf3e 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -2,6 +2,7 @@ module KroneckerArrays export ⊗, × +include("linearcombination.jl") include("cartesianproduct.jl") include("kroneckerarray.jl") include("linearalgebra.jl") diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index f8d17b1..3454b7b 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -81,155 +81,155 @@ function DerivableInterfaces.zero!(a::EyeEye) return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) end -function Base.:*(a::Number, b::EyeKronecker) - return b.a ⊗ (a * b.b) -end -function Base.:*(a::Number, b::KroneckerEye) - return (a * b.a) ⊗ b.b -end -function Base.:*(a::Number, b::EyeEye) - return error("Can't multiply `Eye ⊗ Eye` by a number.") -end -function Base.:*(a::EyeKronecker, b::Number) - return a.a ⊗ (a.b * b) -end -function Base.:*(a::KroneckerEye, b::Number) - return (a.a * b) ⊗ a.b -end -function Base.:*(a::EyeEye, b::Number) - return error("Can't multiply `Eye ⊗ Eye` by a number.") -end - -function Base.:-(a::EyeKronecker) - return a.a ⊗ (-a.b) -end -function Base.:-(a::KroneckerEye) - return (-a.a) ⊗ a.b -end -function Base.:-(a::EyeEye) - return error("Can't multiply `Eye ⊗ Eye` by a number.") -end - -for op in (:+, :-) - @eval begin - function Base.$op(a::EyeKronecker, b::EyeKronecker) - if a.a ≠ b.a - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return a.a ⊗ $op(a.b, b.b) - end - function Base.$op(a::KroneckerEye, b::KroneckerEye) - if a.b ≠ b.b - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return $op(a.a, b.a) ⊗ a.b - end - function Base.$op(a::EyeEye, b::EyeEye) - if a.b ≠ b.b - return throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end - return $op(a.a, b.a) ⊗ a.b - end - end -end - -function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker) - map!(f, dest.b, src.b) - return dest -end -function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye) - map!(f, dest.a, src.a) - return dest -end -function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye) - return error("Can't write in-place.") -end -for f in [:+, :-] - @eval begin - function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) - if dest.a ≠ a.a ≠ b.a - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - map!($f, dest.b, a.b, b.b) - return dest - end - function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) - if dest.b ≠ a.b ≠ b.b - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - map!($f, dest.a, a.a, b.a) - return dest - end - function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) - return error("Can't write in-place.") - end - end -end -function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) - map!(f, dest.b, a.b) - return dest -end -function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) - map!(f, dest.a, a.a) - return dest -end -function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) - map!(f, dest.b, a.b) - return dest -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) - map!(f, dest.a, a.a) - return dest -end -function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) - map!(f, dest.b, a.b) - return dest -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) - map!(f, dest.a, a.a) - return dest -end -function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) - return error("Can't write in-place.") -end - -using Base.Broadcast: - AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted - -struct EyeStyle <: AbstractArrayStyle{2} end -EyeStyle(::Val{2}) = EyeStyle() -function _BroadcastStyle(::Type{<:Eye}) - return EyeStyle() -end -Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() -Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 - -function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) - return Eye{elt}(axes(bc)) -end +## function Base.:*(a::Number, b::EyeKronecker) +## return b.a ⊗ (a * b.b) +## end +## function Base.:*(a::Number, b::KroneckerEye) +## return (a * b.a) ⊗ b.b +## end +## function Base.:*(a::Number, b::EyeEye) +## return error("Can't multiply `Eye ⊗ Eye` by a number.") +## end +## function Base.:*(a::EyeKronecker, b::Number) +## return a.a ⊗ (a.b * b) +## end +## function Base.:*(a::KroneckerEye, b::Number) +## return (a.a * b) ⊗ a.b +## end +## function Base.:*(a::EyeEye, b::Number) +## return error("Can't multiply `Eye ⊗ Eye` by a number.") +## end +## +## function Base.:-(a::EyeKronecker) +## return a.a ⊗ (-a.b) +## end +## function Base.:-(a::KroneckerEye) +## return (-a.a) ⊗ a.b +## end +## function Base.:-(a::EyeEye) +## return error("Can't multiply `Eye ⊗ Eye` by a number.") +## end +## +## for op in (:+, :-) +## @eval begin +## function Base.$op(a::EyeKronecker, b::EyeKronecker) +## if a.a ≠ b.a +## return throw( +## ArgumentError( +## "KroneckerArray addition is only supported when the first or secord arguments match.", +## ), +## ) +## end +## return a.a ⊗ $op(a.b, b.b) +## end +## function Base.$op(a::KroneckerEye, b::KroneckerEye) +## if a.b ≠ b.b +## return throw( +## ArgumentError( +## "KroneckerArray addition is only supported when the first or secord arguments match.", +## ), +## ) +## end +## return $op(a.a, b.a) ⊗ a.b +## end +## function Base.$op(a::EyeEye, b::EyeEye) +## if a.b ≠ b.b +## return throw( +## ArgumentError( +## "KroneckerArray addition is only supported when the first or secord arguments match.", +## ), +## ) +## end +## return $op(a.a, b.a) ⊗ a.b +## end +## end +## end +## +## function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker) +## map!(f, dest.b, src.b) +## return dest +## end +## function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye) +## map!(f, dest.a, src.a) +## return dest +## end +## function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye) +## return error("Can't write in-place.") +## end +## for f in [:+, :-] +## @eval begin +## function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) +## if dest.a ≠ a.a ≠ b.a +## throw( +## ArgumentError( +## "KroneckerArray addition is only supported when the first or second arguments match.", +## ), +## ) +## end +## map!($f, dest.b, a.b, b.b) +## return dest +## end +## function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) +## if dest.b ≠ a.b ≠ b.b +## throw( +## ArgumentError( +## "KroneckerArray addition is only supported when the first or second arguments match.", +## ), +## ) +## end +## map!($f, dest.a, a.a, b.a) +## return dest +## end +## function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) +## return error("Can't write in-place.") +## end +## end +## end +## function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) +## map!(f, dest.b, a.b) +## return dest +## end +## function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) +## map!(f, dest.a, a.a) +## return dest +## end +## function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) +## return error("Can't write in-place.") +## end +## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) +## map!(f, dest.b, a.b) +## return dest +## end +## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) +## map!(f, dest.a, a.a) +## return dest +## end +## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) +## return error("Can't write in-place.") +## end +## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) +## map!(f, dest.b, a.b) +## return dest +## end +## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) +## map!(f, dest.a, a.a) +## return dest +## end +## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) +## return error("Can't write in-place.") +## end +## +## using Base.Broadcast: +## AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted +## +## struct EyeStyle <: AbstractArrayStyle{2} end +## EyeStyle(::Val{2}) = EyeStyle() +## function _BroadcastStyle(::Type{<:Eye}) +## return EyeStyle() +## end +## Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() +## Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 +## +## function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) +## return Eye{elt}(axes(bc)) +## end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index b20f9b3..7028aa6 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -117,6 +117,8 @@ kron_nd(a::AbstractVector, b::AbstractVector) = kron(a, b) # Eagerly collect arguments to make more general on GPU. Base.collect(a::KroneckerArray) = kron_nd(collect(a.a), collect(a.b)) +Base.zero(a::KroneckerArray) = zero(arg1(a)) ⊗ zero(arg2(a)) + function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} return convert(Array{T,N}, collect(a)) end @@ -202,43 +204,34 @@ function Base.isreal(a::KroneckerArray) return isreal(a.a) && isreal(a.b) end -for f in [:transpose, :adjoint, :inv] - @eval begin - function Base.$f(a::KroneckerArray) - return $f(a.a) ⊗ $f(a.b) - end - end +using DiagonalArrays: DiagonalArrays, diagonal +function DiagonalArrays.diagonal(a::KroneckerArray) + return diagonal(a.a) ⊗ diagonal(a.b) end -function Base.:*(a::Number, b::KroneckerArray) - return (a * b.a) ⊗ b.b -end -function Base.:*(a::KroneckerArray, b::Number) - return a.a ⊗ (a.b * b) -end -function Base.:/(a::KroneckerArray, b::Number) - return a.a ⊗ (a.b / b) +Base.real(a::KroneckerArray{<:Real}) = a +function Base.real(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(imag(a.b)) + return real(a.a) ⊗ real(a.b) + elseif iszero(real(a.a)) || iszero(real(a.b)) + return -imag(a.a) ⊗ imag(a.b) + end + return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) end -function Base.:-(a::KroneckerArray) - return (-a.a) ⊗ a.b +Base.imag(a::KroneckerArray{<:Real}) = zero(a) +function Base.imag(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(real(a.b)) + return real(a.a) ⊗ imag(a.b) + elseif iszero(real(a.a)) || iszero(imag(a.b)) + return imag(a.a) ⊗ real(a.b) + end + return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) end -for op in (:+, :-) +for f in [:transpose, :adjoint, :inv] @eval begin - function Base.$op(a::KroneckerArray, b::KroneckerArray) - iszero(a) && return $op(b) - iszero(b) && return a - if a.b == b.b - return $op(a.a, b.a) ⊗ a.b - elseif a.a == b.a - return a.a ⊗ $op(a.b, b.b) - else - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or secord arguments match.", - ), - ) - end + function Base.$f(a::KroneckerArray) + return $f(a.a) ⊗ $f(a.b) end end end @@ -271,222 +264,205 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N (style_b isa Broadcast.Unknown) && return Broadcast.Unknown() return KroneckerStyle{N}(style_a, style_b) end -function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B} - ax_a = arg1.(axes(bc)) - ax_b = arg2.(axes(bc)) +function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B} + ax_a = arg1.(ax) + ax_b = arg2.(ax) bc_a = Broadcasted(A, nothing, (), ax_a) bc_b = Broadcasted(B, nothing, (), ax_b) a = similar(bc_a, elt) b = similar(bc_b, elt) return a ⊗ b end -# Fallback definition of broadcasting falls back to `map` but assumes -# inputs have been canonicalized to a map-compatible expression already, -# for example by absorbing scalar arguments into the function. -function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle}) - allequal(axes, bc.args) || throw(ArgumentError("Broadcasted axes must be equal.")) - map!(bc.f, dest, bc.args...) + +function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) + return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...) +end +function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) + dest .= f.(a1, a_rest...) return dest end -# Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the -# function. -function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::Number, b::KroneckerArray) - return broadcasted(style, Base.Fix1(*, a), b) +function Base.copyto!(dest::KroneckerArray, a::Sum{<:KroneckerStyle}) + f = LinearCombination(a) + args = arguments(a) + arg1s = arg1.(args) + arg2s = arg2.(args) + dest1 = arg1(dest) + dest2 = arg2(dest) + if allequal(arg2s) + copyto!(dest2, first(arg2s)) + dest1 .= f.(arg1s...) + elseif allequal(arg1s) + copyto!(dest1, first(arg1s)) + dest2 .= f.(arg2s...) + else + error("This operation doesn't preserve the Kronecker structure.") + end + return dest end -function Base.broadcasted(style::KroneckerStyle, ::typeof(*), a::KroneckerArray, b::Number) - return broadcasted(style, Base.Fix2(*, b), a) + +function Broadcast.broadcasted(::KroneckerStyle, f, as...) + return error("Arbitrary broadcasting not supported for KroneckerArray.") end -function Base.broadcasted(style::KroneckerStyle, ::typeof(/), a::KroneckerArray, b::Number) - return broadcasted(style, Base.Fix2(/, b), a) + +# Linear operations. +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(+), a, b) + return Sum(a) + Sum(b) end -using MapBroadcast: MapBroadcast, MapFunction -function Base.broadcasted( - style::KroneckerStyle, - f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, - a::KroneckerArray, -) - return broadcasted(style, Base.Fix1(*, f.args[1]), a) +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a, b) + return Sum(a) - Sum(b) end -function Base.broadcasted( - style::KroneckerStyle, - f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, - a::KroneckerArray, -) - return broadcasted(style, Base.Fix2(*, f.args[2]), a) +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), c::Number, a) + return c * Sum(a) end -function Base.broadcasted( - style::KroneckerStyle, - f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, - a::KroneckerArray, -) - return broadcasted(style, Base.Fix2(/, f.args[2]), a) +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) + return Sum(a) * c end - -# Simplification rules similar to those for FillArrays.jl: -# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl -using FillArrays: Zeros -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types. - return a -end -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray, -) - # TODO: Promote the element types. - return b -end -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(+), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types and axes. - return b +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) + return Sum(a) / c end -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types. - return a -end -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray, -) - # TODO: Promote the element types. - # TODO: Return `broadcasted(-, b)`. - return -b -end -function Base.broadcasted( - style::KroneckerStyle, - ::typeof(-), - a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, - b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -) - # TODO: Promote the element types and axes. - return b +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(-), a) + return -Sum(a) end -# TODO: Define by converting to a broadcast expession (with MapBroadcast.jl) -# and then constructing the output with `similar`. -function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) - return throw( - ArgumentError( - "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) +# Rewrite rules to canonicalize broadcast expressions. +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix1{typeof(*),<:Number}, a) + return broadcasted(style, *, f.x, a) end -function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) - return throw( - ArgumentError( - "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", - ), - ) +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(*),<:Number}, a) + return broadcasted(style, *, a, f.x) end - -function _map!!(f::F, dest::AbstractArray, srcs::AbstractArray...) where {F} - map!(f, dest, srcs...) - return dest +function Broadcast.broadcasted(style::KroneckerStyle, f::Base.Fix2{typeof(/),<:Number}, a) + return broadcasted(style, /, a, f.x) end -for f in [:identity, :conj] - @eval begin - function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray) - _map!!($f, dest.a, src.a) - _map!!($f, dest.b, src.b) - return dest - end - end -end +# Use to determine the element type of KroneckerBroadcasted. +_eltype(x) = eltype(x) +_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...) -for f in [:+, :-] - @eval begin - function Base.map!( - ::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray - ) - iszero(b) && return map!(identity, dest, a) - iszero(a) && return map!($f, dest, b) - if a.b == b.b - map!($f, dest.a, a.a, b.a) - map!(identity, dest.b, a.b) - return dest - elseif a.a == b.a - map!(identity, dest.a, a.a) - map!($f, dest.b, a.b, b.b) - return dest - else - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - end - end +using Base.Broadcast: broadcasted +struct KroneckerBroadcasted{A<:Broadcasted,B<:Broadcasted} + a::A + b::B end - -function Base.map!( - f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, src::KroneckerArray -) - map!(f, dest.a, src.a) - map!(identity, dest.b, src.b) +arg1(a::KroneckerBroadcasted) = a.a +arg2(a::KroneckerBroadcasted) = a.b +⊗(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b) +Broadcast.materialize(a::KroneckerBroadcasted) = copy(a) +Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a) +Broadcast.broadcastable(a::KroneckerBroadcasted) = a +Base.copy(a::KroneckerBroadcasted) = copy(arg1(a)) ⊗ copy(arg2(a)) +function Base.copyto!(dest::KroneckerArray, a::KroneckerBroadcasted) + copyto!(arg1(dest), copy(arg1(a))) + copyto!(arg2(dest), copy(arg2(a))) return dest end - -for op in [:*, :/] - @eval begin - function Base.map!( - f::Base.Fix2{typeof($op),<:Number}, dest::KroneckerArray, src::KroneckerArray - ) - map!(identity, dest.a, src.a) - map!(f, dest.b, src.b) - return dest - end - end +function Base.eltype(a::KroneckerBroadcasted) + a1 = arg1(a) + a2 = arg2(a) + return Base.promote_op(*, _eltype(a1), _eltype(a2)) end -for f in [:+, :-] - @eval begin - function Base.map!(::typeof($f), dest::KroneckerArray, src::KroneckerArray) - map!($f, dest.a, src.a) - map!(identity, dest.b, src.b) - return dest - end - end +function Base.axes(a::KroneckerBroadcasted) + ax1 = axes(arg1(a)) + ax2 = axes(arg2(a)) + return cartesianrange.(ax1 .× ax2) end -using DiagonalArrays: DiagonalArrays, diagonal -function DiagonalArrays.diagonal(a::KroneckerArray) - return diagonal(a.a) ⊗ diagonal(a.b) +function Base.BroadcastStyle( + ::Type{<:KroneckerBroadcasted{A,B}} +) where {StyleA,StyleB,A<:Broadcasted{StyleA},B<:Broadcasted{StyleB}} + @assert ndims(A) == ndims(B) + N = ndims(A) + return KroneckerStyle{N}(StyleA(), StyleB()) end -function Base.real(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(imag(a.b)) - return real(a.a) ⊗ real(a.b) - elseif iszero(real(a.a)) || iszero(real(a.b)) - return -imag(a.a) ⊗ imag(a.b) - end - return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) -end -function Base.imag(a::KroneckerArray) - if iszero(imag(a.a)) || iszero(real(a.b)) - return real(a.a) ⊗ imag(a.b) - elseif iszero(real(a.a)) || iszero(imag(a.b)) - return imag(a.a) ⊗ real(a.b) +# Operations that preserve the Kronecker structure. +for f in [:identity, :conj] + @eval begin + function Broadcast.broadcasted(::KroneckerStyle, ::typeof($f), a) + return broadcasted($f, arg1(a)) ⊗ broadcasted($f, arg2(a)) + end end - return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) end + +## using MapBroadcast: MapBroadcast, MapFunction +## function Base.broadcasted( +## style::KroneckerStyle, +## f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, +## a::KroneckerArray, +## ) +## return broadcasted(style, Base.Fix1(*, f.args[1]), a) +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, +## a::KroneckerArray, +## ) +## return broadcasted(style, Base.Fix2(*, f.args[2]), a) +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, +## a::KroneckerArray, +## ) +## return broadcasted(style, Base.Fix2(/, f.args[2]), a) +## end +## +## # Simplification rules similar to those for FillArrays.jl: +## # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl +## using FillArrays: Zeros +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(+), +## a::KroneckerArray, +## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## ) +## # TODO: Promote the element types. +## return a +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(+), +## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## b::KroneckerArray, +## ) +## # TODO: Promote the element types. +## return b +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(+), +## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## ) +## # TODO: Promote the element types and axes. +## return b +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(-), +## a::KroneckerArray, +## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## ) +## # TODO: Promote the element types. +## return a +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(-), +## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## b::KroneckerArray, +## ) +## # TODO: Promote the element types. +## # TODO: Return `broadcasted(-, b)`. +## return -b +## end +## function Base.broadcasted( +## style::KroneckerStyle, +## ::typeof(-), +## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +## ) +## # TODO: Promote the element types and axes. +## return b +## end diff --git a/test/Project.toml b/test/Project.toml index ce07896..23b9bdd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" @@ -21,6 +22,7 @@ Aqua = "0.8" BlockArrays = "1.6" BlockSparseArrays = "0.7.19" DerivableInterfaces = "0.5" +DiagonalArrays = "0.3.7" FillArrays = "1" JLArrays = "0.2" KroneckerArrays = "0.1" diff --git a/test/test_basics.jl b/test/test_basics.jl index 47460d3..5684936 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,7 @@ using Adapt: adapt using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted using DerivableInterfaces: zero! +using DiagonalArrays: diagonal using JLArrays: JLArray using KroneckerArrays: KroneckerArrays, @@ -11,7 +12,6 @@ using KroneckerArrays: ×, cartesianproduct, cartesianrange, - diagonal, kron_nd, unproduct using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr @@ -100,7 +100,7 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) # Mapping a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) @test_throws "not supported" map(sin, a) - @test_broken map(Base.Fix1(*, 2), a) + @test collect(map(Base.Fix1(*, 2), a)) ≈ 2 * collect(a) a′ = similar(a) @test_throws "not supported" map!(sin, a′, a) a′ = similar(a) @@ -129,12 +129,12 @@ elts = (Float32, Float64, ComplexF32, ComplexF64) if elt <: Real @test real(a) == a else - @test_throws ArgumentError real(a) + @test_throws ErrorException real(a) end if elt <: Real @test iszero(imag(a)) else - @test_throws ArgumentError imag(a) + @test_throws ErrorException imag(a) end # Adapt From 8d04743b37631e984328a26416b67634bf261d5e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 14:55:41 -0400 Subject: [PATCH 2/6] Fixes for FillArrays --- src/fillarrays/kroneckerarray.jl | 108 +++++++++++++++++++++++++++---- src/kroneckerarray.jl | 63 +----------------- 2 files changed, 96 insertions(+), 75 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 3454b7b..8bb521e 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -81,6 +81,100 @@ function DerivableInterfaces.zero!(a::EyeEye) return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) end +using Base.Broadcast: + AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted + +struct EyeStyle <: AbstractArrayStyle{2} end +EyeStyle(::Val{2}) = EyeStyle() +function _BroadcastStyle(::Type{<:Eye}) + return EyeStyle() +end +Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() +Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 + +function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) + return Eye{elt}(axes(bc)) +end + +function Base.copyto!(dest::EyeKronecker, a::Sum{<:KroneckerStyle{<:Any,EyeStyle()}}) + dest2 = arg2(dest) + f = LinearCombination(a) + args = arguments(a) + arg2s = arg2.(args) + dest2 .= f.(arg2s...) + return dest +end +function Base.copyto!(dest::KroneckerEye, a::Sum{<:KroneckerStyle{<:Any,<:Any,EyeStyle()}}) + dest1 = arg1(dest) + f = LinearCombination(a) + args = arguments(a) + arg1s = arg1.(args) + dest1 .= f.(arg1s...) + return dest +end +function Base.copyto!(dest::EyeEye, a::Sum{<:KroneckerStyle{<:Any,EyeStyle(),EyeStyle()}}) + return error("Can't write in-place to `Eye ⊗ Eye`.") +end + +# Simplification rules similar to those for FillArrays.jl: +# https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl +using FillArrays: Zeros +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types. + return a +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray, +) + # TODO: Promote the element types. + return b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(+), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types and axes. + return b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types. + return a +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray, +) + # TODO: Promote the element types. + # TODO: Return `broadcasted(-, b)`. + return -b +end +function Base.broadcasted( + style::KroneckerStyle, + ::typeof(-), + a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, + b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, +) + # TODO: Promote the element types and axes. + return b +end + ## function Base.:*(a::Number, b::EyeKronecker) ## return b.a ⊗ (a * b.b) ## end @@ -219,17 +313,3 @@ end ## return error("Can't write in-place.") ## end ## -## using Base.Broadcast: -## AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted -## -## struct EyeStyle <: AbstractArrayStyle{2} end -## EyeStyle(::Val{2}) = EyeStyle() -## function _BroadcastStyle(::Type{<:Eye}) -## return EyeStyle() -## end -## Base.BroadcastStyle(style1::EyeStyle, style2::EyeStyle) = EyeStyle() -## Base.BroadcastStyle(style1::EyeStyle, style2::DefaultArrayStyle) = style2 -## -## function Base.similar(bc::Broadcasted{EyeStyle}, elt::Type) -## return Eye{elt}(axes(bc)) -## end diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 7028aa6..65df8ed 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -283,12 +283,12 @@ function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::Kronecke end function Base.copyto!(dest::KroneckerArray, a::Sum{<:KroneckerStyle}) + dest1 = arg1(dest) + dest2 = arg2(dest) f = LinearCombination(a) args = arguments(a) arg1s = arg1.(args) arg2s = arg2.(args) - dest1 = arg1(dest) - dest2 = arg2(dest) if allequal(arg2s) copyto!(dest2, first(arg2s)) dest1 .= f.(arg1s...) @@ -407,62 +407,3 @@ end ## ) ## return broadcasted(style, Base.Fix2(/, f.args[2]), a) ## end -## -## # Simplification rules similar to those for FillArrays.jl: -## # https://github.com/JuliaArrays/FillArrays.jl/blob/v1.13.0/src/fillbroadcast.jl -## using FillArrays: Zeros -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(+), -## a::KroneckerArray, -## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## ) -## # TODO: Promote the element types. -## return a -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(+), -## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## b::KroneckerArray, -## ) -## # TODO: Promote the element types. -## return b -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(+), -## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## ) -## # TODO: Promote the element types and axes. -## return b -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(-), -## a::KroneckerArray, -## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## ) -## # TODO: Promote the element types. -## return a -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(-), -## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## b::KroneckerArray, -## ) -## # TODO: Promote the element types. -## # TODO: Return `broadcasted(-, b)`. -## return -b -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## ::typeof(-), -## a::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## b::KroneckerArray{<:Any,<:Any,<:Zeros,<:Zeros}, -## ) -## # TODO: Promote the element types and axes. -## return b -## end From e363e040a7dffb0839f240a5ac61f89edb54ebd0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 15:16:13 -0400 Subject: [PATCH 3/6] Fix ambiguity error --- src/kroneckerarray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 65df8ed..949a683 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -318,6 +318,10 @@ end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a, c::Number) return Sum(a) * c end +# Fix ambiguity error. +function Broadcast.broadcasted(::KroneckerStyle, ::typeof(*), a::Number, b::Number) + return a * b +end function Broadcast.broadcasted(::KroneckerStyle, ::typeof(/), a, c::Number) return Sum(a) / c end From 6dfe480778f6296c71efea2a48527c4752b3d81b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 15:31:26 -0400 Subject: [PATCH 4/6] Fix tests --- Project.toml | 2 +- src/kroneckerarray.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 4e96312..2921538 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.17" +version = "0.1.18" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 949a683..d0258a2 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -383,8 +383,8 @@ end # Operations that preserve the Kronecker structure. for f in [:identity, :conj] @eval begin - function Broadcast.broadcasted(::KroneckerStyle, ::typeof($f), a) - return broadcasted($f, arg1(a)) ⊗ broadcasted($f, arg2(a)) + function Broadcast.broadcasted(::KroneckerStyle{<:Any,A,B}, ::typeof($f), a) where {A,B} + return broadcasted(A, $f, arg1(a)) ⊗ broadcasted(B, $f, arg2(a)) end end end From 3f162f3588dd1720e83fb39be6da9bc3636a6f7e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 15:36:12 -0400 Subject: [PATCH 5/6] Cleanup --- src/fillarrays/kroneckerarray.jl | 139 ------------------------------- src/kroneckerarray.jl | 39 ++++----- 2 files changed, 17 insertions(+), 161 deletions(-) diff --git a/src/fillarrays/kroneckerarray.jl b/src/fillarrays/kroneckerarray.jl index 8bb521e..f132dbf 100644 --- a/src/fillarrays/kroneckerarray.jl +++ b/src/fillarrays/kroneckerarray.jl @@ -174,142 +174,3 @@ function Base.broadcasted( # TODO: Promote the element types and axes. return b end - -## function Base.:*(a::Number, b::EyeKronecker) -## return b.a ⊗ (a * b.b) -## end -## function Base.:*(a::Number, b::KroneckerEye) -## return (a * b.a) ⊗ b.b -## end -## function Base.:*(a::Number, b::EyeEye) -## return error("Can't multiply `Eye ⊗ Eye` by a number.") -## end -## function Base.:*(a::EyeKronecker, b::Number) -## return a.a ⊗ (a.b * b) -## end -## function Base.:*(a::KroneckerEye, b::Number) -## return (a.a * b) ⊗ a.b -## end -## function Base.:*(a::EyeEye, b::Number) -## return error("Can't multiply `Eye ⊗ Eye` by a number.") -## end -## -## function Base.:-(a::EyeKronecker) -## return a.a ⊗ (-a.b) -## end -## function Base.:-(a::KroneckerEye) -## return (-a.a) ⊗ a.b -## end -## function Base.:-(a::EyeEye) -## return error("Can't multiply `Eye ⊗ Eye` by a number.") -## end -## -## for op in (:+, :-) -## @eval begin -## function Base.$op(a::EyeKronecker, b::EyeKronecker) -## if a.a ≠ b.a -## return throw( -## ArgumentError( -## "KroneckerArray addition is only supported when the first or secord arguments match.", -## ), -## ) -## end -## return a.a ⊗ $op(a.b, b.b) -## end -## function Base.$op(a::KroneckerEye, b::KroneckerEye) -## if a.b ≠ b.b -## return throw( -## ArgumentError( -## "KroneckerArray addition is only supported when the first or secord arguments match.", -## ), -## ) -## end -## return $op(a.a, b.a) ⊗ a.b -## end -## function Base.$op(a::EyeEye, b::EyeEye) -## if a.b ≠ b.b -## return throw( -## ArgumentError( -## "KroneckerArray addition is only supported when the first or secord arguments match.", -## ), -## ) -## end -## return $op(a.a, b.a) ⊗ a.b -## end -## end -## end -## -## function Base.map!(f::typeof(identity), dest::EyeKronecker, src::EyeKronecker) -## map!(f, dest.b, src.b) -## return dest -## end -## function Base.map!(f::typeof(identity), dest::KroneckerEye, src::KroneckerEye) -## map!(f, dest.a, src.a) -## return dest -## end -## function Base.map!(::typeof(identity), dest::EyeEye, src::EyeEye) -## return error("Can't write in-place.") -## end -## for f in [:+, :-] -## @eval begin -## function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) -## if dest.a ≠ a.a ≠ b.a -## throw( -## ArgumentError( -## "KroneckerArray addition is only supported when the first or second arguments match.", -## ), -## ) -## end -## map!($f, dest.b, a.b, b.b) -## return dest -## end -## function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) -## if dest.b ≠ a.b ≠ b.b -## throw( -## ArgumentError( -## "KroneckerArray addition is only supported when the first or second arguments match.", -## ), -## ) -## end -## map!($f, dest.a, a.a, b.a) -## return dest -## end -## function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) -## return error("Can't write in-place.") -## end -## end -## end -## function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) -## map!(f, dest.b, a.b) -## return dest -## end -## function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) -## map!(f, dest.a, a.a) -## return dest -## end -## function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) -## return error("Can't write in-place.") -## end -## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) -## map!(f, dest.b, a.b) -## return dest -## end -## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) -## map!(f, dest.a, a.a) -## return dest -## end -## function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) -## return error("Can't write in-place.") -## end -## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) -## map!(f, dest.b, a.b) -## return dest -## end -## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerEye, a::KroneckerEye) -## map!(f, dest.a, a.a) -## return dest -## end -## function Base.map!(f::Base.Fix2{typeof(*),<:Number}, dest::EyeEye, a::EyeEye) -## return error("Can't write in-place.") -## end -## diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index d0258a2..82b4357 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -389,25 +389,20 @@ for f in [:identity, :conj] end end -## using MapBroadcast: MapBroadcast, MapFunction -## function Base.broadcasted( -## style::KroneckerStyle, -## f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, -## a::KroneckerArray, -## ) -## return broadcasted(style, Base.Fix1(*, f.args[1]), a) -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, -## a::KroneckerArray, -## ) -## return broadcasted(style, Base.Fix2(*, f.args[2]), a) -## end -## function Base.broadcasted( -## style::KroneckerStyle, -## f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, -## a::KroneckerArray, -## ) -## return broadcasted(style, Base.Fix2(/, f.args[2]), a) -## end +# Compatibility with MapBroadcast.jl. +using MapBroadcast: MapBroadcast, MapFunction +function Base.broadcasted( + style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}}, a +) + return broadcasted(style, *, f.args[1], a) +end +function Base.broadcasted( + style::KroneckerStyle, f::MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}}, a +) + return broadcasted(style, *, a, f.args[2]) +end +function Base.broadcasted( + style::KroneckerStyle, f::MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}}, a +) + return broadcasted(style, /, a, f.args[2]) +end From 36f18756c51eb2ee568edd0f22274d4f98d41714 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 21 Jun 2025 16:04:12 -0400 Subject: [PATCH 6/6] Add missing file --- src/linearcombination.jl | 92 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 src/linearcombination.jl diff --git a/src/linearcombination.jl b/src/linearcombination.jl new file mode 100644 index 0000000..352d135 --- /dev/null +++ b/src/linearcombination.jl @@ -0,0 +1,92 @@ +using Base.Broadcast: Broadcasted +struct LinearCombination{C} <: Function + coefficients::C +end +coefficients(a::LinearCombination) = a.coefficients +function (f::LinearCombination)(args...) + return mapreduce(*,+,coefficients(f),args) +end + +struct Sum{Style,C<:Tuple,A<:Tuple} + style::Style + coefficients::C + arguments::A +end +coefficients(a::Sum) = a.coefficients +arguments(a::Sum) = a.arguments +style(a::Sum) = a.style +LinearCombination(a::Sum) = LinearCombination(coefficients(a)) +using Base.Broadcast: combine_axes +Base.axes(a::Sum) = combine_axes(a.arguments...) +function Base.eltype(a::Sum) + cts = typeof.(coefficients(a)) + elts = eltype.(arguments(a)) + ts = map((ct, elt) -> Base.promote_op(*, ct, elt), cts, elts) + return Base.promote_op(+, ts...) +end +using Base.Broadcast: combine_styles +function Sum(coefficients::Tuple, arguments::Tuple) + return Sum(combine_styles(arguments...), coefficients, arguments) +end +Sum(a) = Sum((one(eltype(a)),), (a,)) +function Base.:+(a::Sum, b::Sum) + return Sum((coefficients(a)..., coefficients(b)...), (arguments(a)..., arguments(b)...)) +end +Base.:-(a::Sum, b::Sum) = a + (-b) +Base.:+(a::Sum, b::AbstractArray) = a + Sum(b) +Base.:-(a::Sum, b::AbstractArray) = a - Sum(b) +Base.:+(a::AbstractArray, b::Sum) = Sum(a) + b +Base.:-(a::AbstractArray, b::Sum) = Sum(a) - b +Base.:*(c::Number, a::Sum) = Sum(c .* coefficients(a), arguments(a)) +Base.:*(a::Sum, c::Number) = c * a +Base.:/(a::Sum, c::Number) = Sum(coefficients(a) ./ c, arguments(a)) +Base.:-(a::Sum) = -1 * a + +function Base.copy(a::Sum) + return copyto!(similar(a), a) +end +Base.similar(a::Sum) = similar(a, eltype(a)) +Base.similar(a::Sum, elt::Type) = similar(a, elt, axes(a)) +function Base.copyto!(dest::AbstractArray, a::Sum) + f = LinearCombination(a) + dest .= f.(arguments(a)...) + return dest +end +function Broadcast.Broadcasted(a::Sum) + f = LinearCombination(a) + return Broadcasted(style(a), f, arguments(a), axes(a)) +end +function Base.similar(a::Sum, elt::Type, ax::Tuple) + return similar(Broadcasted(a), elt, ax) +end + +using Base.Broadcast: Broadcast, AbstractArrayStyle, DefaultArrayStyle +Broadcast.materialize(a::Sum) = copy(a) +Broadcast.materialize!(dest, a::Sum) = copyto!(dest, a) +struct SumStyle <: AbstractArrayStyle{Any} end +Broadcast.broadcastable(a::Sum) = a +Broadcast.BroadcastStyle(::Type{<:Sum}) = SumStyle() +Broadcast.BroadcastStyle(style::SumStyle, ::AbstractArrayStyle) = style +# Fix ambiguity error with Base. +Broadcast.BroadcastStyle(style::SumStyle, ::DefaultArrayStyle) = style +function Broadcast.broadcasted(::SumStyle, f, as...) + return error("Arbitrary broadcasting not supported for SumStyle.") +end +function Broadcast.broadcasted(::SumStyle, ::typeof(+), a, b::Sum) + return Sum(a) + b +end +function Broadcast.broadcasted(::SumStyle, ::typeof(+), a::Sum, b) + return a + Sum(b) +end +function Broadcast.broadcasted(::SumStyle, ::typeof(+), a::Sum, b::Sum) + return a + b +end +function Broadcast.broadcasted(::SumStyle, ::typeof(*), c::Number, a) + return c * Sum(a) +end +function Broadcast.broadcasted(::SumStyle, ::typeof(*), c::Number, a::Sum) + return c * a +end +function Broadcast.broadcasted(::SumStyle, ::typeof(/), a::Sum, c::Number) + return Sum(a) / c +end