@@ -445,19 +445,25 @@ end
445445# broadcast
446446import Base. Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
447447
448- struct StructArrayStyle{Style } <: AbstractArrayStyle{Any } end
448+ struct StructArrayStyle{S, N } <: AbstractArrayStyle{N } end
449449
450- @inline combine_style_types (:: Type{A} , args... ) where A<: AbstractArray =
450+ # Here we define the dimension tracking behavior of StructArrayStyle
451+ function StructArrayStyle {S, M} (:: Val{N} ) where {S, M, N}
452+ T = S <: AbstractArrayStyle{M} ? typeof (S (Val (N))) : S
453+ return StructArrayStyle {T, N} ()
454+ end
455+
456+ @inline combine_style_types (:: Type{A} , args... ) where {A<: AbstractArray } =
451457 combine_style_types (BroadcastStyle (A), args... )
452- @inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where A<: AbstractArray =
458+ @inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where { A<: AbstractArray } =
453459 combine_style_types (Broadcast. result_style (s, BroadcastStyle (A)), args... )
454460combine_style_types (s:: BroadcastStyle ) = s
455461
456- Base. @pure cst (:: Type{SA} ) where SA = combine_style_types (array_types (SA). parameters... )
462+ Base. @pure cst (:: Type{SA} ) where {SA} = combine_style_types (array_types (SA). parameters... )
457463
458- BroadcastStyle (:: Type{SA} ) where SA<: StructArray = StructArrayStyle {typeof(cst(SA))} ()
464+ BroadcastStyle (:: Type{SA} ) where { SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA )} ()
459465
460- Base. similar (bc:: Broadcasted{StructArrayStyle{S}} , :: Type{ElType} ) where {S<: DefaultArrayStyle ,N, ElType} =
466+ Base. similar (bc:: Broadcasted{StructArrayStyle{S, N }} , :: Type{ElType} ) where {S<: DefaultArrayStyle , N, ElType} =
461467 isstructtype (ElType) ? similar (StructArray{ElType}, axes (bc)) : similar (Array{ElType}, axes (bc))
462468
463469# for aliasing analysis during broadcast
0 commit comments