Skip to content

Commit ed2ea13

Browse files
Merge pull request #325 from AayushSabharwal/as/inference
fix: mapreduce type stability, VoA broadcast adjoints
2 parents 137a0b5 + 1f9b577 commit ed2ea13

File tree

6 files changed

+215
-14
lines changed

6 files changed

+215
-14
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ end
100100
end
101101
end
102102

103+
@adjoint function Base.copy(u::VectorOfArray)
104+
copy(u),
105+
y -> (copy(y),)
106+
end
107+
103108
@adjoint function DiffEqArray(u, t)
104109
DiffEqArray(u, t),
105110
y -> begin
@@ -117,19 +122,122 @@ end
117122
A.x, literal_ArrayPartition_x_adjoint
118123
end
119124

120-
@adjoint function Array(VA::AbstractVectorOfArray)
125+
@adjoint function Base.Array(VA::AbstractVectorOfArray)
121126
Array(VA),
122127
y -> (Array(y),)
123128
end
124129

130+
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
131+
view(A, I...),
132+
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
133+
end
125134

126135
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
127136

128-
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
137+
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
129138
arr = reshape(x, p.sz)
130139
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
131140
end
132141

142+
@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})
143+
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
144+
end
145+
@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray)
146+
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
147+
end
148+
149+
_minus(Δ) = .-Δ
150+
_minus(::Nothing) = nothing
151+
152+
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
153+
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
154+
end
155+
@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
156+
(
157+
x.*y,
158+
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
159+
)
160+
end
161+
@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
162+
res = x ./ y
163+
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
164+
end
165+
@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray)
166+
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
167+
end
168+
@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray)
169+
(
170+
x.*y,
171+
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
172+
)
173+
end
174+
@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray)
175+
res = x ./ y
176+
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
177+
end
178+
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)
179+
.-x, Δ -> (nothing, _minus(Δ))
180+
end
181+
182+
@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p
183+
y = Base.literal_pow.(^, x, exp)
184+
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
185+
end
186+
187+
@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, Δ -> (nothing, Δ)
188+
189+
@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)
190+
y = tanh.(x)
191+
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
192+
end
193+
194+
@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) =
195+
conj.(x), z̄ -> (nothing, conj.(z̄))
196+
197+
@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) =
198+
real.(x), z̄ -> (nothing, real.(z̄))
199+
200+
@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) =
201+
imag.(x), z̄ -> (nothing, im .* real.(z̄))
202+
203+
@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) =
204+
abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
205+
206+
@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool)
207+
y = b === false ? a : a .+ b
208+
y, Δ -> (nothing, Δ, nothing)
209+
end
210+
@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number})
211+
y = b === false ? a : b .+ a
212+
y, Δ -> (nothing, nothing, Δ)
213+
end
214+
215+
@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool)
216+
y = b === false ? a : a .- b
217+
y, Δ -> (nothing, Δ, nothing)
218+
end
219+
@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number})
220+
b .- a, Δ -> (nothing, nothing, .-Δ)
221+
end
222+
223+
@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool)
224+
if b === false
225+
zero(a), Δ -> (nothing, zero(Δ), nothing)
226+
else
227+
a, Δ -> (nothing, Δ, nothing)
228+
end
229+
end
230+
@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number})
231+
if b === false
232+
zero(a), Δ -> (nothing, nothing, zero(Δ))
233+
else
234+
a, Δ -> (nothing, nothing, Δ)
235+
end
236+
end
237+
238+
@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number} =
239+
T.(x), ȳ -> (nothing, Zygote._project(x, ȳ),)
240+
133241
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
134242
N = ndims(x̄)
135243
if length(x) == length(x̄)

src/array_partition.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
165165
## Iterable Collection Constructs
166166

167167
Base.map(f, A::ArrayPartition) = ArrayPartition(map(x -> map(f, x), A.x))
168-
function Base.mapreduce(f, op, A::ArrayPartition)
169-
mapreduce(f, op, (mapreduce(f, op, x) for x in A.x))
168+
function Base.mapreduce(f, op, A::ArrayPartition{T}; kwargs...) where {T}
169+
mapreduce(f, op, (i for i in A); kwargs...)
170170
end
171171
Base.filter(f, A::ArrayPartition) = ArrayPartition(map(x -> filter(f, x), A.x))
172172
Base.any(f, A::ArrayPartition) = any(f, (any(f, x) for x in A.x))

src/vector_of_array.jl

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,16 @@ function VectorOfArray(vec::AbstractVector{T}, ::NTuple{N}) where {T, N}
133133
VectorOfArray{eltype(T), N, typeof(vec)}(vec)
134134
end
135135
# Assume that the first element is representative of all other elements
136-
VectorOfArray(vec::AbstractVector) = VectorOfArray(vec, (size(vec[1])..., length(vec)))
136+
function VectorOfArray(vec::AbstractVector)
137+
T = eltype(vec[1])
138+
N = ndims(vec[1])
139+
if all(x isa Union{<:AbstractArray, <:AbstractVectorOfArray} for x in vec)
140+
A = Vector{Union{typeof.(vec)...}}
141+
else
142+
A = typeof(vec)
143+
end
144+
VectorOfArray{T, N + 1, A}(vec)
145+
end
137146
function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray{T, N}}
138147
VectorOfArray{T, N + 1, typeof(vec)}(vec)
139148
end
@@ -482,21 +491,30 @@ function Base.append!(VA::AbstractVectorOfArray{T, N},
482491
return VA
483492
end
484493

494+
function Base.stack(VA::AbstractVectorOfArray; dims = :)
495+
stack(VA.u; dims)
496+
end
497+
485498
# AbstractArray methods
486499
function Base.view(A::AbstractVectorOfArray, I::Vararg{Any,M}) where {M}
487500
@inline
488501
J = map(i->Base.unalias(A,i), to_indices(A, I))
489502
@boundscheck checkbounds(A, J...)
490503
SubArray(IndexStyle(A), A, J, Base.index_dimsum(J...))
491504
end
505+
function Base.SubArray(parent::AbstractVectorOfArray, indices::Tuple)
506+
@inline
507+
SubArray(IndexStyle(Base.viewindexing(indices), IndexStyle(parent)), parent, Base.ensure_indexable(indices), Base.index_dimsum(indices...))
508+
end
509+
Base.isassigned(VA::AbstractVectorOfArray, idxs...) = checkbounds(Bool, VA, idxs...)
492510
Base.check_parent_index_match(::RecursiveArrayTools.AbstractVectorOfArray{T,N}, ::NTuple{N,Bool}) where {T,N} = nothing
493511
Base.ndims(::AbstractVectorOfArray{T, N}) where {T, N} = N
494512
function Base.checkbounds(::Type{Bool}, VA::AbstractVectorOfArray, idx...)
495513
if checkbounds(Bool, VA.u, last(idx))
496514
if last(idx) isa Integer
497-
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)))
515+
return all(checkbounds.(Bool, (VA.u[last(idx)],), Base.front(idx)...))
498516
else
499-
return all(checkbounds.(Bool, VA.u[last(idx)], Base.front(idx)))
517+
return all(checkbounds.(Bool, VA.u[last(idx)], tuple.(Base.front(idx))...))
500518
end
501519
end
502520
return false
@@ -595,10 +613,14 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
595613
end
596614

597615
# statistics
598-
@inline Base.sum(f, VA::AbstractVectorOfArray) = sum(f, Array(VA))
599-
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(Array(VA); kwargs...)
600-
@inline Base.prod(f, VA::AbstractVectorOfArray) = prod(f, Array(VA))
601-
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(Array(VA); kwargs...)
616+
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(identity, VA; kwargs...)
617+
@inline function Base.sum(f, VA::AbstractVectorOfArray; kwargs...)
618+
mapreduce(f, Base.add_sum, VA; kwargs...)
619+
end
620+
@inline Base.prod(VA::AbstractVectorOfArray; kwargs...) = prod(identity, VA; kwargs...)
621+
@inline function Base.prod(f, VA::AbstractVectorOfArray; kwargs...)
622+
mapreduce(f, Base.mul_prod, VA; kwargs...)
623+
end
602624

603625
@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
604626
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
@@ -638,8 +660,12 @@ end
638660
end
639661

640662
Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u)
641-
function Base.mapreduce(f, op, A::AbstractVectorOfArray)
642-
mapreduce(f, op, (mapreduce(f, op, x) for x in A.u))
663+
664+
function Base.mapreduce(f, op, A::AbstractVectorOfArray; kwargs...)
665+
mapreduce(f, op, view(A, ntuple(_ -> :, ndims(A))...); kwargs...)
666+
end
667+
function Base.mapreduce(f, op, A::AbstractVectorOfArray{T,1,<:AbstractVector{T}}; kwargs...) where {T}
668+
mapreduce(f, op, A.u; kwargs...)
643669
end
644670

645671
## broadcasting

test/adjoints.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,27 @@ end
3939

4040
function loss7(x)
4141
_x = VectorOfArray([x .* i for i in 1:5])
42-
return sum(abs2, x .- 1)
42+
return sum(abs2, _x .- 1)
43+
end
44+
45+
# use a bunch of broadcasts to test all the adjoints
46+
function loss8(x)
47+
_x = VectorOfArray([x .* i for i in 1:5])
48+
res = copy(_x)
49+
res = res .+ _x
50+
res = res .+ 1
51+
res = res .* _x
52+
res = res .* 2.0
53+
res = res .* res
54+
res = res ./ 2.0
55+
res = res ./ _x
56+
res = 3.0 .- res
57+
res = .-res
58+
res = identity.(Base.literal_pow.(^, res, Val(2)))
59+
res = tanh.(res)
60+
res = res .+ im .* res
61+
res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
62+
return sum(abs2, res)
4363
end
4464

4565
x = float.(6:10)
@@ -51,3 +71,4 @@ loss(x)
5171
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
5272
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
5373
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
74+
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)

test/interface_tests.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,44 @@ push!(testda, [-1, -2, -3, -4])
5757
@test_throws MethodError push!(testda, [-1 -2 -3 -4])
5858
@test_throws MethodError push!(testda, [-1 -2; -3 -4])
5959

60+
# Type inference
61+
@inferred sum(testva)
62+
@inferred sum(VectorOfArray([VectorOfArray([zeros(4,4)])]))
63+
@inferred mapreduce(string, *, testva)
64+
65+
# mapreduce
66+
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
67+
@test mapreduce(x -> string(x) * "q", *, testva) == "1q2q3q4q5q6q7q8q9q"
68+
69+
testvb = VectorOfArray([rand(1:10, 3, 3, 3) for _ in 1:4])
70+
arrvb = Array(testvb)
71+
for i in 1:ndims(arrvb)
72+
@test sum(arrvb; dims=i) == sum(testvb; dims=i)
73+
@test prod(arrvb; dims=i) == prod(testvb; dims=i)
74+
@test mapreduce(string, *, arrvb; dims=i) == mapreduce(string, *, testvb; dims=i)
75+
end
76+
77+
# Test when ndims == 1
78+
testvb = VectorOfArray(collect(1.0:0.1:2.0))
79+
arrvb = Array(testvb)
80+
@test sum(arrvb) == sum(testvb)
81+
@test prod(arrvb) == prod(testvb)
82+
@test mapreduce(string, *, arrvb) == mapreduce(string, *, testvb)
83+
84+
# view
85+
testvc = VectorOfArray([rand(1:10, 3, 3) for _ in 1:3])
86+
arrvc = Array(testvc)
87+
for idxs in [(2, 2, :), (2, :, 2), (:, 2, 2), (:, :, 2), (:, 2, :), (2, : ,:), (:, :, :)]
88+
arr_view = view(arrvc, idxs...)
89+
voa_view = view(testvc, idxs...)
90+
@test size(arr_view) == size(voa_view)
91+
@test all(arr_view .== voa_view)
92+
end
93+
94+
# test stack
95+
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
96+
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]
97+
6098
# convert array from VectorOfArray/DiffEqArray
6199
t = 1:8
62100
recs = [rand(10, 7) for i in 1:8]

test/partitions_test.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
104104
@inferred recursive_one(x)
105105
@inferred recursive_bottom_eltype(x)
106106

107+
# mapreduce
108+
@inferred Union{Int, Float64} sum(x)
109+
@inferred sum(ArrayPartition(ArrayPartition(zeros(4,4))))
110+
@inferred sum(ArrayPartition(ArrayPartition(zeros(4))))
111+
@inferred sum(ArrayPartition(zeros(4,4)))
112+
@inferred mapreduce(string, *, x)
113+
@test mapreduce(i -> string(i) * "q", *, x) == "1q2q3.0q4.0q"
114+
107115
# broadcasting
108116
_scalar_op(y) = y + 1
109117
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:

0 commit comments

Comments
 (0)