@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88StructuredMatrixStyle {T} (:: Val{2} ) where {T} = StructuredMatrixStyle {T} ()
99StructuredMatrixStyle {T} (:: Val{N} ) where {T,N} = Broadcast. DefaultArrayStyle {N} ()
1010
11- const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12- for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11+ const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12+ LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13+ UpperHessenberg{T}}
14+ for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15+ LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16+ UpperHessenberg)
1317 @eval Broadcast. BroadcastStyle (:: Type{<:$ST} ) = $ (StructuredMatrixStyle {ST} ())
1418end
1519
@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
2731 StructuredMatrixStyle {LowerTriangular} ()
2832Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
2933 StructuredMatrixStyle {UpperTriangular} ()
34+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
35+ StructuredMatrixStyle {UpperHessenberg} ()
3036
3137Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{Diagonal} ) =
3238 StructuredMatrixStyle {Bidiagonal} ()
3339Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3440 StructuredMatrixStyle {Tridiagonal} ()
41+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
42+ StructuredMatrixStyle {UpperHessenberg} ()
43+
3544Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3645 StructuredMatrixStyle {Tridiagonal} ()
46+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
47+ StructuredMatrixStyle {UpperHessenberg} ()
3748Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3849 StructuredMatrixStyle {Tridiagonal} ()
50+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
51+ StructuredMatrixStyle {UpperHessenberg} ()
3952
4053Broadcast. BroadcastStyle (:: StructuredMatrixStyle{LowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
4154 StructuredMatrixStyle {LowerTriangular} ()
4255Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
4356 StructuredMatrixStyle {UpperTriangular} ()
57+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
58+ StructuredMatrixStyle {UpperHessenberg} ()
4459Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitLowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
4560 StructuredMatrixStyle {LowerTriangular} ()
4661Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
4762 StructuredMatrixStyle {UpperTriangular} ()
63+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
64+ StructuredMatrixStyle {UpperHessenberg} ()
65+
66+ function Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperHessenberg} ,
67+ :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}} )
68+ StructuredMatrixStyle {UpperHessenberg} ()
69+ end
4870
49- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
71+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} ) =
5072 StructuredMatrixStyle {Matrix} ()
51- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
73+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
5274 StructuredMatrixStyle {Matrix} ()
5375
5476# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -95,6 +117,8 @@ structured_broadcast_alloc(bc, ::Type{UnitLowerTriangular}, ::Type{ElType}, n) w
95117 UnitLowerTriangular (Array {ElType} (undef, n, n))
96118structured_broadcast_alloc (bc, :: Type{UnitUpperTriangular} , :: Type{ElType} , n) where {ElType} =
97119 UnitUpperTriangular (Array {ElType} (undef, n, n))
120+ structured_broadcast_alloc (bc, :: Type{UpperHessenberg} , :: Type{ElType} , n) where {ElType} =
121+ UpperHessenberg (Array {ElType} (undef, n, n))
98122structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , n) where {ElType} =
99123 Array {ElType} (undef, n, n)
100124
@@ -288,6 +312,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
288312 return dest
289313end
290314
315+ function copyto! (dest:: UpperHessenberg , bc:: Broadcasted{<:StructuredMatrixStyle} )
316+ isvalidstructbc (dest, bc) || return copyto! (dest, convert (Broadcasted{Nothing}, bc))
317+ axs = axes (dest)
318+ axes (bc) == axs || Broadcast. throwdm (axes (bc), axs)
319+ for j in axs[2 ]
320+ for i in 1 : min (size (dest. data,1 ), j+ 1 )
321+ @inbounds dest. data[i,j] = bc[CartesianIndex (i, j)]
322+ end
323+ end
324+ return dest
325+ end
326+
291327# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
292328function map (f, A:: StructuredMatrix , Bs:: StructuredMatrix... )
293329 sz = size (A)
0 commit comments