@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494494end 
495495
496496#  broadcast
497- import  Base. Broadcast:  BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497+ import  Base. Broadcast:  BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498+ using  Base. Broadcast:  combine_styles
498499
499500struct  StructArrayStyle{S, N} <:  AbstractArrayStyle{N}  end 
500501
@@ -524,6 +525,82 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524525
525526BroadcastStyle (:: Type{SA} ) where  {SA<: StructArray } =  StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
526527
528+ """ 
529+     always_struct_broadcast(style::BroadcastStyle) 
530+ 
531+ Check if `style` supports struct-broadcast natively, which means: 
532+ 1) `Base.copy` is not overloaded. 
533+ 2) `Base.similar` is defined. 
534+ 3) `Base.copyto!` supports `StructArray`s as broadcasted arguments. 
535+ 
536+ If any of the above conditions are not met, then this function should 
537+ not be overloaded. 
538+ In that case, try to overload [`try_struct_copy`](@ref) to support out-of-place 
539+ struct-broadcast. 
540+ """ 
541+ always_struct_broadcast (:: Any ) =  false 
542+ always_struct_broadcast (:: DefaultArrayStyle ) =  true 
543+ always_struct_broadcast (:: ArrayConflict ) =  true 
544+ 
545+ """ 
546+     try_struct_copy(bc::Broadcasted) 
547+ 
548+ Entry for non-native outplace struct-broadcast. 
549+ 
550+ See also [`always_struct_broadcast`](@ref). 
551+ """ 
552+ try_struct_copy (bc:: Broadcasted ) =  copy (bc)
553+ 
554+ function  Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where  {S, N}
555+     if  always_struct_broadcast (S ())
556+         return  invoke (copy, Tuple{Broadcasted}, bc)
557+     else 
558+         return  try_struct_copy (replace_structarray (bc))
559+     end 
560+ end 
561+ 
562+ """ 
563+     replace_structarray(bc::Broadcasted) 
564+ 
565+ An internal function transforms the `Broadcasted` with `StructArray` into 
566+ an equivalent one without it. This is not a must if the root `BroadcastStyle` 
567+ supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,  
568+ e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. 
569+ """ 
570+ function  replace_structarray (bc:: Broadcasted{Style} ) where  {Style}
571+     args =  replace_structarray_args (bc. args)
572+     Style′ =  parent_style (Style ())
573+     return  Broadcasted {Style′} (bc. f, args, bc. axes)
574+ end 
575+ function  replace_structarray (A:: StructArray )
576+     f =  Instantiator (eltype (A))
577+     args =  Tuple (components (A))
578+     Style =  typeof (combine_styles (args... ))
579+     return  Broadcasted {Style} (f, args, axes (A))
580+ end 
581+ replace_structarray (@nospecialize (A)) =  A
582+ 
583+ replace_structarray_args (args:: Tuple ) =  (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584+ replace_structarray_args (:: Tuple{} ) =  ()
585+ 
586+ parent_style (@nospecialize (x)) =  typeof (x)
587+ parent_style (:: StructArrayStyle{S, N} ) where  {S, N} =  S
588+ parent_style (:: StructArrayStyle{S, N} ) where  {N, S<: AbstractArrayStyle{N} } =  S
589+ parent_style (:: StructArrayStyle{S, N} ) where  {S<: AbstractArrayStyle{Any} , N} =  S
590+ parent_style (:: StructArrayStyle{S, N} ) where  {S<: AbstractArrayStyle , N} =  typeof (S (Val (N)))
591+ 
592+ #  `instantiate` and `_axes` might be overloaded for static axes.
593+ function  Broadcast. instantiate (bc:: Broadcasted{Style} ) where  {Style <:  StructArrayStyle }
594+     Style′ =  parent_style (Style ())
595+     bc′ =  Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596+     return  convert (Broadcasted{Style}, bc′)
597+ end 
598+ 
599+ function  Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where  {Style <:  StructArrayStyle }
600+     Style′ =  parent_style (Style ())
601+     return  Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
602+ end 
603+ 
527604#  Here we use `similar` defined for `S` to build the dest Array.
528605function  Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where  {S, N, ElType}
529606    bc′ =  convert (Broadcasted{S}, bc)
@@ -532,12 +609,22 @@ end
532609
533610#  Unwrapper to recover the behaviour defined by parent style.
534611@inline  function  Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where  {S, N}
535-     return  copyto! (dest, convert (Broadcasted{S}, bc))
612+     bc′ =  always_struct_broadcast (S ()) ?  convert (Broadcasted{S}, bc) :  replace_structarray (bc)
613+     return  copyto! (dest, bc′)
536614end 
537615
538616@inline  function  Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where  {S}
539-     return  Broadcast. materialize! (S (), dest, bc)
617+     bc′ =  always_struct_broadcast (S ()) ?  bc :  replace_structarray (bc)
618+     return  Broadcast. materialize! (S (), dest, bc′)
540619end 
541620
542621#  for aliasing analysis during broadcast
622+ function  Broadcast. broadcast_unalias (dest:: StructArray , src:: AbstractArray )
623+     if  dest ===  src ||  any (Base. Fix2 (=== , src), components (dest))
624+         return  src
625+     else 
626+         return  Base. unalias (dest, src)
627+     end 
628+ end 
629+ 
543630Base. dataids (u:: StructArray ) =  mapreduce (Base. dataids, (a, b) ->  (a... , b... ), values (components (u)), init= ())
0 commit comments