@@ -8,8 +8,12 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
88StructuredMatrixStyle {T} (:: Val{2} ) where {T} = StructuredMatrixStyle {T} ()
99StructuredMatrixStyle {T} (:: Val{N} ) where {T,N} = Broadcast. DefaultArrayStyle {N} ()
1010
11- const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
12- for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
11+ const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},
12+ LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T},
13+ UpperHessenberg{T}}
14+ for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,
15+ LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular,
16+ UpperHessenberg)
1317 @eval Broadcast. BroadcastStyle (:: Type{<:$ST} ) = $ (StructuredMatrixStyle {ST} ())
1418end
1519
@@ -27,28 +31,46 @@ Broadcast.BroadcastStyle(::StructuredMatrixStyle{Diagonal}, ::StructuredMatrixSt
2731 StructuredMatrixStyle {LowerTriangular} ()
2832Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
2933 StructuredMatrixStyle {UpperTriangular} ()
34+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Diagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
35+ StructuredMatrixStyle {UpperHessenberg} ()
3036
3137Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{Diagonal} ) =
3238 StructuredMatrixStyle {Bidiagonal} ()
3339Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{<:Union{Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3440 StructuredMatrixStyle {Tridiagonal} ()
41+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Bidiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
42+ StructuredMatrixStyle {UpperHessenberg} ()
43+
3544Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3645 StructuredMatrixStyle {Tridiagonal} ()
46+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{SymTridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
47+ StructuredMatrixStyle {UpperHessenberg} ()
3748Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal}} ) =
3849 StructuredMatrixStyle {Tridiagonal} ()
50+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{Tridiagonal} , :: StructuredMatrixStyle{UpperHessenberg} ) =
51+ StructuredMatrixStyle {UpperHessenberg} ()
3952
4053Broadcast. BroadcastStyle (:: StructuredMatrixStyle{LowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
4154 StructuredMatrixStyle {LowerTriangular} ()
4255Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
4356 StructuredMatrixStyle {UpperTriangular} ()
57+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
58+ StructuredMatrixStyle {UpperHessenberg} ()
4459Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitLowerTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,LowerTriangular,UnitLowerTriangular}} ) =
4560 StructuredMatrixStyle {LowerTriangular} ()
4661Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{<:Union{Diagonal,UpperTriangular,UnitUpperTriangular}} ) =
4762 StructuredMatrixStyle {UpperTriangular} ()
63+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UnitUpperTriangular} , :: StructuredMatrixStyle{UpperHessenberg} ) =
64+ StructuredMatrixStyle {UpperHessenberg} ()
65+
66+ function Broadcast. BroadcastStyle (:: StructuredMatrixStyle{UpperHessenberg} ,
67+ :: StructuredMatrixStyle{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,UnitUpperTriangular,UpperTriangular}} )
68+ StructuredMatrixStyle {UpperHessenberg} ()
69+ end
4870
49- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} ) =
71+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} , :: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} ) =
5072 StructuredMatrixStyle {Matrix} ()
51- Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular}} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
73+ Broadcast. BroadcastStyle (:: StructuredMatrixStyle{<:Union{UpperTriangular,UnitUpperTriangular,UpperHessenberg }} , :: StructuredMatrixStyle{<:Union{LowerTriangular,UnitLowerTriangular}} ) =
5274 StructuredMatrixStyle {Matrix} ()
5375
5476# Make sure that `StructuredMatrixStyle{Matrix}` doesn't ever end up falling
@@ -97,7 +119,7 @@ function structured_broadcast_alloc(bc, ::Type{Tridiagonal},
97119 Tridiagonal (Array {ElType} (undef, n1),Array {ElType} (undef, n),Array {ElType} (undef, n1))
98120end
99121function structured_broadcast_alloc (bc, :: Type{T} , :: Type{ElType} ,
100- sz:: NTuple{2,Integer} ) where {ElType,T<: UpperOrLowerTriangular }
122+ sz:: NTuple{2,Integer} ) where {ElType,T<: Union{ UpperOrLowerTriangular, UpperHessenberg} }
101123 T (Array {ElType} (undef, sz))
102124end
103125structured_broadcast_alloc (bc, :: Type{Matrix} , :: Type{ElType} , sz:: NTuple{2,Integer} ) where {ElType} =
@@ -293,6 +315,18 @@ function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
293315 return dest
294316end
295317
318+ function copyto! (dest:: UpperHessenberg , bc:: Broadcasted{<:StructuredMatrixStyle} )
319+ isvalidstructbc (dest, bc) || return copyto! (dest, convert (Broadcasted{Nothing}, bc))
320+ axs = axes (dest)
321+ axes (bc) == axs || Broadcast. throwdm (axes (bc), axs)
322+ for j in axs[2 ]
323+ for i in 1 : min (size (dest. data,1 ), j+ 1 )
324+ @inbounds dest. data[i,j] = bc[CartesianIndex (i, j)]
325+ end
326+ end
327+ return dest
328+ end
329+
296330# We can also implement `map` and its promotion in terms of broadcast with a stricter dimension check
297331function map (f, A:: StructuredMatrix , Bs:: StructuredMatrix... )
298332 sz = size (A)
0 commit comments