Skip to content

Commit 5898006

Browse files
committed
Forward structure-preserving broadcasting to diag for Diagonal
1 parent ed53855 commit 5898006

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

src/structuredbroadcast.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,30 @@ end
208208
# All structured matrices are square, and therefore they only broadcast out if they are size (1, 1)
209209
Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I
210210

211+
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
212+
# We may do this because the indexing within `copyto!` is restricted to the stored indices
213+
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
214+
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
215+
args = map(x -> preprocess_broadcasted(T, x), bc.args)
216+
Broadcast.broadcasted(bc.f, args...)
217+
end
218+
# fallback case that doesn't unwrap at all
219+
_preprocess_broadcasted(::Type, x) = x
220+
221+
_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag
222+
# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers
223+
_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diag(d)
224+
225+
function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}})
226+
if isstructurepreserving(bc) || fzeropreserving(bc)
227+
# forward the broadcasting operation to the diagonal
228+
bc2 = preprocess_broadcasted(Diagonal, bc)
229+
return Diagonal(copy(bc2))
230+
else
231+
@invoke copy(bc::Broadcasted)
232+
end
233+
end
234+
211235
function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
212236
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
213237
axs = axes(dest)
@@ -269,13 +293,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
269293
return dest
270294
end
271295

272-
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
273-
# We may do this because the indexing within `copyto!` is restricted to the stored indices
274-
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
275-
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
276-
args = map(x -> preprocess_broadcasted(T, x), bc.args)
277-
Broadcast.Broadcasted(bc.f, args, bc.axes)
278-
end
279296
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
280297
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
281298

test/structuredbroadcast.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,4 +398,10 @@ end
398398
end
399399
end
400400

401+
@testset "forwarding for Diganoal" begin
402+
D = Diagonal(1:4)
403+
D2 = D .* 2
404+
@test D2 isa Diagonal{Int, <:AbstractRange{Int}}
405+
end
406+
401407
end

0 commit comments

Comments
 (0)