Skip to content

Commit 51923a5

Browse files
authored
Forward structure-preserving broadcasting to diag for Diagonal (#1423)
For example, after this PR, ```julia julia> D = Diagonal(1:4) 4×4 Diagonal{Int64, UnitRange{Int64}}: 1 ⋅ ⋅ ⋅ ⋅ 2 ⋅ ⋅ ⋅ ⋅ 3 ⋅ ⋅ ⋅ ⋅ 4 julia> D .* 2 4×4 Diagonal{Int64, StepRangeLen{Int64, Int64, Int64, Int64}}: 2 ⋅ ⋅ ⋅ ⋅ 4 ⋅ ⋅ ⋅ ⋅ 6 ⋅ ⋅ ⋅ ⋅ 8 julia> using SparseArrays julia> D = Diagonal(spzeros(2)) 2×2 Diagonal{Float64, SparseVector{Float64, Int64}}: 0.0 ⋅ ⋅ 0.0 julia> D .* 2 2×2 Diagonal{Float64, SparseVector{Float64, Int64}}: 0.0 ⋅ ⋅ 0.0 julia> using FillArrays julia> D = Diagonal(Fill(3, 2)) 2×2 Diagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 3 ⋅ ⋅ 3 julia> D .* 2 2×2 Diagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 6 ⋅ ⋅ 6 ``` I've not handled the other banded matrix types in this PR, but `Diagonal` is probably used far more commonly in packages anyway.
1 parent 35a4427 commit 51923a5

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

src/structuredbroadcast.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,30 @@ end
230230
# All structured matrices are square, and therefore they only broadcast out if they are size (1, 1)
231231
Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I
232232

233+
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
234+
# We may do this because the indexing within `copyto!` is restricted to the stored indices
235+
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
236+
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
237+
args = map(x -> preprocess_broadcasted(T, x), bc.args)
238+
Broadcast.broadcasted(bc.f, args...)
239+
end
240+
# fallback case that doesn't unwrap at all
241+
_preprocess_broadcasted(::Type, x) = x
242+
243+
_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag
244+
# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers
245+
_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diagview(d)
246+
247+
function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}})
248+
if isstructurepreserving(bc) || fzeropreserving(bc)
249+
# forward the broadcasting operation to the diagonal
250+
bc2 = preprocess_broadcasted(Diagonal, bc)
251+
return Diagonal(copy(bc2))
252+
else
253+
@invoke copy(bc::Broadcasted)
254+
end
255+
end
256+
233257
function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
234258
isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc))
235259
axs = axes(dest)
@@ -291,13 +315,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
291315
return dest
292316
end
293317

294-
# Recursively replace wrapped matrices by their parents to improve broadcasting performance
295-
# We may do this because the indexing within `copyto!` is restricted to the stored indices
296-
preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A)
297-
function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T}
298-
args = map(x -> preprocess_broadcasted(T, x), bc.args)
299-
Broadcast.Broadcasted(bc.f, args, bc.axes)
300-
end
301318
_preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A)
302319
_preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A)
303320
_preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A)

test/structuredbroadcast.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,4 +410,19 @@ end
410410
@test UH2 isa UpperHessenberg
411411
end
412412

413+
@testset "forwarding broadcast to the diag for a Diagonal" begin
414+
D = Diagonal(1:4)
415+
D2 = D .* 2
416+
@test D2 isa Diagonal{Int, <:AbstractRange{Int}}
417+
418+
# test for wrappers that opt into Diagonal-like broadcasting
419+
U = UpperTriangular(D)
420+
bc = Broadcast.broadcasted(+, D, U)
421+
bcD = Broadcast.broadcasted(+, D, D)
422+
S = typeof(Broadcast.BroadcastStyle(typeof(bcD)))
423+
bc2 = convert(Broadcast.Broadcasted{S}, bc)
424+
@test copy(bc2) == copy(bc) == copy(bcD)
425+
@test copy(bc2) isa Diagonal
426+
end
427+
413428
end

0 commit comments

Comments
 (0)