Skip to content

Commit b22f74f

Browse files
committed
Move style computation to runtime
1 parent 5fdec8f commit b22f74f

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

src/vector_of_array.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,31 +149,30 @@ end
149149

150150
## broadcasting
151151

152-
struct VectorOfArrayStyle{Style <: Broadcast.BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
153-
VectorOfArrayStyle(::S) where {S} = VectorOfArrayStyle{S}()
154-
VectorOfArrayStyle(::S, ::Val{N}) where {S,N} = VectorOfArrayStyle(S(Val(N)))
155-
VectorOfArrayStyle(::Val{N}) where N = VectorOfArrayStyle{Broadcast.DefaultArrayStyle{N}}()
152+
struct VectorOfArrayStyle <: Broadcast.AbstractArrayStyle{Any} end
153+
VectorOfArrayStyle(::Any) = VectorOfArrayStyle()
154+
VectorOfArrayStyle(::Any, ::Any) = VectorOfArrayStyle()
156155

157156
# promotion rules
158-
@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
159-
VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
160-
end
161-
Broadcast.BroadcastStyle(::VectorOfArrayStyle{Style}, ::Broadcast.DefaultArrayStyle{0}) where Style<:Broadcast.BroadcastStyle = VectorOfArrayStyle{Style}()
162-
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
157+
#@inline function Broadcast.BroadcastStyle(::VectorOfArrayStyle{AStyle}, ::VectorOfArrayStyle{BStyle}) where {AStyle, BStyle}
158+
# VectorOfArrayStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
159+
#end
160+
Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.BroadcastStyle) = VectorOfArrayStyle()
161+
#Broadcast.BroadcastStyle(::VectorOfArrayStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
163162

164163
function Broadcast.BroadcastStyle(::Type{<:AbstractVectorOfArray{T,S}}) where {T, S}
165-
VectorOfArrayStyle(Broadcast.result_style(Broadcast.BroadcastStyle(T)))
164+
VectorOfArrayStyle()
166165
end
167166

168-
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}) where Style
167+
@inline function Base.copy(bc::Broadcast.Broadcasted{VectorOfArrayStyle})
169168
N = narrays(bc)
170169
x = unpack_voa(bc, 1)
171170
VectorOfArray(map(1:N) do i
172171
copy(unpack_voa(bc, i))
173172
end)
174173
end
175174

176-
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}) where Style
175+
@inline function Base.copyto!(dest::AbstractVectorOfArray, bc::Broadcast.Broadcasted{VectorOfArrayStyle})
177176
N = narrays(bc)
178177
@inbounds for i in 1:N
179178
copyto!(dest[i], unpack_voa(bc, i))
@@ -206,7 +205,7 @@ _narrays(args::Tuple{}) = 0
206205

207206
# drop axes because it is easier to recompute
208207
@inline unpack_voa(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
209-
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args_voa(i, bc.args))
208+
@inline unpack_voa(bc::Broadcast.Broadcasted{VectorOfArrayStyle}, i) = Broadcast.Broadcasted(bc.f, unpack_args_voa(i, bc.args))
210209
unpack_voa(x,::Any) = x
211210
unpack_voa(x::AbstractVectorOfArray, i) = x.u[i]
212211
unpack_voa(x::AbstractArray{T,N}, i) where {T,N} = @view x[ntuple(x->Colon(),N-1)...,i]

test/basic_indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,4 @@ w = v .* v
9696
@test w[3] == v[3] .* v[3]
9797
x = copy(v)
9898
x .= v .* v
99-
@test all(x .== w)
99+
@test x.u == w.u

0 commit comments

Comments
 (0)