|
208 | 208 | # All structured matrices are square, and therefore they only broadcast out if they are size (1, 1) |
209 | 209 | Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I |
210 | 210 |
|
| 211 | +# Recursively replace wrapped matrices by their parents to improve broadcasting performance |
| 212 | +# We may do this because the indexing within `copyto!` is restricted to the stored indices |
| 213 | +preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) |
| 214 | +function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} |
| 215 | + args = map(x -> preprocess_broadcasted(T, x), bc.args) |
| 216 | + Broadcast.broadcasted(bc.f, args...) |
| 217 | +end |
| 218 | +# fallback case that doesn't unwrap at all |
| 219 | +_preprocess_broadcasted(::Type, x) = x |
| 220 | + |
| 221 | +_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag |
| 222 | +# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers |
| 223 | +_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diag(d) |
| 224 | + |
| 225 | +function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}}) |
| 226 | + if isstructurepreserving(bc) || fzeropreserving(bc) |
| 227 | + # forward the broadcasting operation to the diagonal |
| 228 | + bc2 = preprocess_broadcasted(Diagonal, bc) |
| 229 | + return Diagonal(copy(bc2)) |
| 230 | + else |
| 231 | + @invoke copy(bc::Broadcasted) |
| 232 | + end |
| 233 | +end |
| 234 | + |
211 | 235 | function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle}) |
212 | 236 | isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) |
213 | 237 | axs = axes(dest) |
@@ -269,13 +293,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) |
269 | 293 | return dest |
270 | 294 | end |
271 | 295 |
|
272 | | -# Recursively replace wrapped matrices by their parents to improve broadcasting performance |
273 | | -# We may do this because the indexing within `copyto!` is restricted to the stored indices |
274 | | -preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) |
275 | | -function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} |
276 | | - args = map(x -> preprocess_broadcasted(T, x), bc.args) |
277 | | - Broadcast.Broadcasted(bc.f, args, bc.axes) |
278 | | -end |
279 | 296 | _preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) |
280 | 297 | _preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) |
281 | 298 |
|
|
0 commit comments