@@ -61,8 +61,8 @@ Broadcast.BroadcastStyle(T::StructuredMatrixStyle{Matrix}, ::StructuredMatrixSty
6161Broadcast. BroadcastStyle (:: StructuredMatrixStyle , :: StructuredMatrixStyle ) = DefaultArrayStyle {2} ()
6262
6363# And a definition akin to similar using the structured type:
64- structured_broadcast_alloc (bc, :: Type{Diagonal} , :: Type{ElType} , n ) where {ElType} =
65- Diagonal (Array {ElType} (undef, n ))
64+ structured_broadcast_alloc (bc, :: Type{Diagonal} , :: Type{ElType} , sz :: NTuple{2,Integer} ) where {ElType} =
65+ Diagonal (Array {ElType} (undef, sz[ 1 ] ))
6666# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
6767# system will return Tridiagonal when there's more than one Bidiagonal, but when
6868# there's only one, we need to make figure out upper or lower
@@ -75,28 +75,33 @@ find_uplo(a::Bidiagonal) = a.uplo
7575find_uplo (a) = nothing
7676find_uplo (bc:: Broadcasted ) = mapfoldl (find_uplo, merge_uplos, Broadcast. cat_nested (bc), init= nothing )
7777
78- function structured_broadcast_alloc (bc, :: Type{Bidiagonal} , :: Type{ElType} , n) where {ElType}
78+ function structured_broadcast_alloc (bc, :: Type{Bidiagonal} ,
79+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
80+ n = sz[1 ]
7981 uplo = n > 0 ? find_uplo (bc) : ' U'
8082 n1 = max (n - 1 , 0 )
8183 if count_structedmatrix (Bidiagonal, bc) > 1 && uplo == ' T'
8284 return Tridiagonal (Array {ElType} (undef, n1), Array {ElType} (undef, n), Array {ElType} (undef, n1))
8385 end
8486 return Bidiagonal (Array {ElType} (undef, n),Array {ElType} (undef, n1), uplo)
8587end
86- structured_broadcast_alloc (bc, :: Type{SymTridiagonal} , :: Type{ElType} , n) where {ElType} =
88+ function structured_broadcast_alloc (bc, :: Type{SymTridiagonal} ,
89+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
90+ n = sz[1 ]
8791 SymTridiagonal (Array {ElType} (undef, n),Array {ElType} (undef, max (0 ,n- 1 )))
88- structured_broadcast_alloc (bc, :: Type{Tridiagonal} , :: Type{ElType} , n) where {ElType} =
89- Tridiagonal (Array {ElType} (undef, max (0 ,n- 1 )),Array {ElType} (undef, n),Array {ElType} (undef, max (0 ,n- 1 )))
90- structured_broadcast_alloc (bc, :: Type{LowerTriangular} , :: Type{ElType} , n) where {ElType} =
91- LowerTriangular (Array {ElType} (undef, n, n))
92- structured_broadcast_alloc (bc, :: Type{UpperTriangular} , :: Type{ElType} , n) where {ElType} =
93- UpperTriangular (Array {ElType} (undef, n, n))
94- structured_broadcast_alloc (bc, :: Type{UnitLowerTriangular} , :: Type{ElType} , n) where {ElType} =
95- UnitLowerTriangular (Array {ElType} (undef, n, n))
96- structured_broadcast_alloc (bc, :: Type{UnitUpperTriangular} , :: Type{ElType} , n) where {ElType} =
97- UnitUpperTriangular (Array {ElType} (undef, n, n))
98- structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , n) where {ElType} =
99- Array {ElType} (undef, n, n)
92+ end
93+ function structured_broadcast_alloc (bc, :: Type{Tridiagonal} ,
94+ :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType}
95+ n = sz[1 ]
96+ n1 = max (0 ,n- 1 )
97+ Tridiagonal (Array {ElType} (undef, n1),Array {ElType} (undef, n),Array {ElType} (undef, n1))
98+ end
99+ function structured_broadcast_alloc (bc, :: Type{T} , :: Type{ElType} ,
100+ sz:: NTuple{2,Integer} ) where {ElType,T<: UpperOrLowerTriangular }
101+ T (Array {ElType} (undef, sz))
102+ end
103+ structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType} =
104+ Array {ElType} (undef, sz)
100105
101106# A _very_ limited list of structure-preserving functions known at compile-time. This list is
102107# derived from the formerly-implemented `broadcast` methods in 0.6. Note that this must
@@ -172,7 +177,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
172177 inds = axes (bc)
173178 fzerobc = fzeropreserving (bc)
174179 if isstructurepreserving (bc) || (fzerobc && ! (T <: Union{UnitLowerTriangular,UnitUpperTriangular} ))
175- return structured_broadcast_alloc (bc, T, ElType, length ( inds[ 1 ] ))
180+ return structured_broadcast_alloc (bc, T, ElType, map (length, inds))
176181 elseif fzerobc && T <: UnitLowerTriangular
177182 return similar (convert (Broadcasted{StructuredMatrixStyle{LowerTriangular}}, bc), ElType)
178183 elseif fzerobc && T <: UnitUpperTriangular
0 commit comments