diff --git a/src/structuredbroadcast.jl b/src/structuredbroadcast.jl index da09a16d..a18e709e 100644 --- a/src/structuredbroadcast.jl +++ b/src/structuredbroadcast.jl @@ -230,6 +230,30 @@ end # All structured matrices are square, and therefore they only broadcast out if they are size (1, 1) Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I +# Recursively replace wrapped matrices by their parents to improve broadcasting performance +# We may do this because the indexing within `copyto!` is restricted to the stored indices +preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) +function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} + args = map(x -> preprocess_broadcasted(T, x), bc.args) + Broadcast.broadcasted(bc.f, args...) +end +# fallback case that doesn't unwrap at all +_preprocess_broadcasted(::Type, x) = x + +_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag +# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers +_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diagview(d) + +function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}}) + if isstructurepreserving(bc) || fzeropreserving(bc) + # forward the broadcasting operation to the diagonal + bc2 = preprocess_broadcasted(Diagonal, bc) + return Diagonal(copy(bc2)) + else + @invoke copy(bc::Broadcasted) + end +end + function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle}) isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) axs = axes(dest) @@ -291,13 +315,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) return dest end -# Recursively replace wrapped matrices by their parents to improve broadcasting performance -# We may do this because the indexing within `copyto!` is restricted to the stored indices -preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) -function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} - args = map(x -> preprocess_broadcasted(T, x), bc.args) - Broadcast.Broadcasted(bc.f, args, bc.axes) -end _preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) _preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) _preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A) diff --git a/test/structuredbroadcast.jl b/test/structuredbroadcast.jl index edd91e40..2333ec23 100644 --- a/test/structuredbroadcast.jl +++ b/test/structuredbroadcast.jl @@ -410,4 +410,19 @@ end @test UH2 isa UpperHessenberg end +@testset "forwarding broadcast to the diag for a Diagonal" begin + D = Diagonal(1:4) + D2 = D .* 2 + @test D2 isa Diagonal{Int, <:AbstractRange{Int}} + + # test for wrappers that opt into Diagonal-like broadcasting + U = UpperTriangular(D) + bc = Broadcast.broadcasted(+, D, U) + bcD = Broadcast.broadcasted(+, D, D) + S = typeof(Broadcast.BroadcastStyle(typeof(bcD))) + bc2 = convert(Broadcast.Broadcasted{S}, bc) + @test copy(bc2) == copy(bc) == copy(bcD) + @test copy(bc2) isa Diagonal +end + end