@@ -497,33 +497,53 @@ end
497497import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498498using Base. Broadcast: combine_styles
499499
500- struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500+ @static if fieldcount (Base. Broadcast. Broadcasted) == 4
501+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
502+ style:: S
503+ StructArrayStyle {N} (style) where {N} = new {N, typeof(style)} (style)
504+ end
505+ StructArrayStyle {N} (style:: StructArrayStyle ) where {N} = StructArrayStyle {N} (style. style)
506+ parent_style (s:: BroadcastStyle ) = s
507+ parent_style (s:: StructArrayStyle ) = s. style
508+ style (bc:: Broadcasted ) = bc. style
509+ const broadcasted = Broadcasted
510+ else
511+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
512+ StructArrayStyle {N} (style) where {N} = new {N, typeof(style)} ()
513+ end
514+ StructArrayStyle {N} (style:: StructArrayStyle{M, S} ) where {N, M, S} = StructArrayStyle {N} (S ())
515+ parent_style (s:: BroadcastStyle ) = s
516+ parent_style (:: StructArrayStyle{N, S} ) = S ()
517+ style (:: Broadcasted{Style} ) where {Style} = Style ()
518+ broadcasted (s, f, args, axes) = Broadcasted {typeof(s)} (f, args, axes)
519+ end
520+ StructArrayStyle {N, S} () where {N, S} = StructArrayStyle {N} (S ())
521+ parent_style (bc:: Broadcasted ) = parent_style (style (bc))
522+ ofstyle (s, bc:: Broadcasted ) = broadcasted (s, bc. f, bc. args, bc. axes)
501523
502524# Here we define the dimension tracking behavior of StructArrayStyle
503- function StructArrayStyle {S, M } (:: Val{N} ) where {S, M, N}
525+ function StructArrayStyle {M, S } (:: Val{N} ) where {S, M, N}
504526 T = S <: AbstractArrayStyle{M} ? typeof (S (Val {N} ())) : S
505- return StructArrayStyle {T, N } ()
527+ return StructArrayStyle {N, T } ()
506528end
507529
508530# StructArrayStyle is a wrapped style.
509531# Here we try our best to resolve style conflict.
510- function BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{S, N } ) where {S, N, M}
532+ function BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{N, S } ) where {S, N, M}
511533 N′ = M === Any || N === Any ? Any : max (M, N)
512- S′ = Broadcast. result_style (S (), b)
513- return S′ isa StructArrayStyle ? typeof (S′)(Val {N′} ()) : StructArrayStyle {typeof(S′), N′} ()
534+ return StructArrayStyle {N′} (Broadcast. result_style (parent_style (a), b))
514535end
515536BroadcastStyle (:: StructArrayStyle , :: DefaultArrayStyle ) = Unknown ()
516537
517538@inline combine_style_types (:: Type{A} , args... ) where {A<: AbstractArray } =
518539 combine_style_types (BroadcastStyle (A), args... )
519540@inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where {A<: AbstractArray } =
520541 combine_style_types (Broadcast. result_style (s, BroadcastStyle (A)), args... )
521- combine_style_types (:: StructArrayStyle{S} ) where {S} = S () # avoid nested StructArrayStyle
522542combine_style_types (s:: BroadcastStyle ) = s
523543
524544Base. @pure cst (:: Type{SA} ) where {SA} = combine_style_types (array_types (SA). parameters... )
525545
526- BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
546+ BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {ndims(SA)} (cst (SA) )
527547
528548"""
529549 always_struct_broadcast(style::BroadcastStyle)
@@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
551571"""
552572try_struct_copy (bc:: Broadcasted ) = copy (bc)
553573
554- function Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
555- if always_struct_broadcast (S ( ))
574+ function Base. copy (bc:: Broadcasted{<: StructArrayStyle} )
575+ if always_struct_broadcast (parent_style (bc ))
556576 return invoke (copy, Tuple{Broadcasted}, bc)
557577 else
558578 return try_struct_copy (replace_structarray (bc))
@@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567587supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568588e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569589"""
570- function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
590+ function replace_structarray (bc:: Broadcasted )
571591 args = replace_structarray_args (bc. args)
572- Style′ = parent_style (Style () )
573- return Broadcasted {Style′} ( bc. f, args, bc. axes)
592+ style = parent_style (bc )
593+ return broadcasted (style, bc. f, args, bc. axes)
574594end
575595function replace_structarray (A:: StructArray )
576596 f = Instantiator (eltype (A))
577597 args = Tuple (components (A))
578- Style = typeof ( combine_styles (args... ) )
579- return Broadcasted {Style} ( f, args, axes (A))
598+ style = combine_styles (args... )
599+ return broadcasted (style, f, args, axes (A))
580600end
581601replace_structarray (@nospecialize (A)) = A
582602
583603replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584604replace_structarray_args (:: Tuple{} ) = ()
585605
586- parent_style (@nospecialize (x)) = typeof (x)
587- parent_style (:: StructArrayStyle{S, N} ) where {S, N} = S
588- parent_style (:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
589- parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
590- parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof (S (Val (N)))
591-
592606# `instantiate` and `_axes` might be overloaded for static axes.
593- function Broadcast. instantiate (bc:: Broadcasted{Style} ) where {Style <: StructArrayStyle }
594- Style′ = parent_style (Style ())
595- bc′ = Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596- return convert (Broadcasted{Style}, bc′)
607+ function Broadcast. instantiate (bc:: Broadcasted{<:StructArrayStyle} )
608+ bc′ = Broadcast. instantiate (ofstyle (parent_style (bc), bc))
609+ return ofstyle (style (bc), bc′)
597610end
598611
599- function Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where {Style <: StructArrayStyle }
600- Style′ = parent_style (Style ())
601- return Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
612+ function Broadcast. _axes (bc:: Broadcasted{<:StructArrayStyle} , :: Nothing )
613+ return Broadcast. _axes (ofstyle (parent_style (bc), bc), nothing )
602614end
603615
604616# Here we use `similar` defined for `S` to build the dest Array.
605- function Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
606- bc′ = convert (Broadcasted{S} , bc)
617+ function Base. similar (bc:: Broadcasted{<: StructArrayStyle} , :: Type{ElType} ) where {ElType}
618+ bc′ = ofstyle ( parent_style (bc) , bc)
607619 return isnonemptystructtype (ElType) ? buildfromschema (T -> similar (bc′, T), ElType) : similar (bc′, ElType)
608620end
609621
610622# Unwrapper to recover the behaviour defined by parent style.
611- @inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
612- bc′ = always_struct_broadcast (S ()) ? convert (Broadcasted{S}, bc) : replace_structarray (bc)
623+ @inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:StructArrayStyle} )
624+ ps = parent_style (bc)
625+ bc′ = always_struct_broadcast (ps) ? ofstyle (ps, bc) : replace_structarray (bc)
613626 return copyto! (dest, bc′)
614627end
615628
616- @inline function Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
617- bc′ = always_struct_broadcast (S ()) ? bc : replace_structarray (bc)
618- return Broadcast. materialize! (S (), dest, bc′)
629+ @inline function Broadcast. materialize! (s:: StructArrayStyle , dest, bc:: Broadcasted )
630+ ps = parent_style (s)
631+ bc′ = always_struct_broadcast (ps) ? bc : replace_structarray (bc)
632+ return Broadcast. materialize! (ps, dest, bc′)
619633end
620634
621635# for aliasing analysis during broadcast
0 commit comments