Skip to content

Commit 2d9cb78

Browse files
no ZygoteRules
1 parent 8aa1e1e commit 2d9cb78

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ import RecursiveArrayTools
44

55
if isdefined(Base, :get_extension)
66
using Zygote
7-
using Zygote: ZygoteRules, FullArrays
7+
using Zygote: FullArrays, literal_getproperty, @adjoint
88
else
99
using ..Zygote
10-
using ..Zygote: ZygoteRules, FullArrays
10+
using ..Zygote: FullArrays, literal_getproperty, @adjoint
1111
end
1212

13-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
13+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
1414
function AbstractVectorOfArray_getindex_adjoint(Δ)
1515
Δ′ = [(i == j ? Δ : Fill(zero(eltype(x)), size(x)))
1616
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -19,8 +19,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int)
1919
VA[i], AbstractVectorOfArray_getindex_adjoint
2020
end
2121

22-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
23-
i::Union{BitArray, AbstractArray{Bool}})
22+
@adjoint function getindex(VA::AbstractVectorOfArray,
23+
i::Union{BitArray, AbstractArray{Bool}})
2424
function AbstractVectorOfArray_getindex_adjoint(Δ)
2525
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
2626
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -29,7 +29,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
2929
VA[i], AbstractVectorOfArray_getindex_adjoint
3030
end
3131

32-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
32+
@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArray{Int})
3333
function AbstractVectorOfArray_getindex_adjoint(Δ)
3434
iter = 0
3535
Δ′ = [(j i ? Δ[iter += 1] : Fill(zero(eltype(x)), size(x)))
@@ -39,8 +39,8 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::AbstractArr
3939
VA[i], AbstractVectorOfArray_getindex_adjoint
4040
end
4141

42-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
43-
i::Union{Int, AbstractArray{Int}})
42+
@adjoint function getindex(VA::AbstractVectorOfArray,
43+
i::Union{Int, AbstractArray{Int}})
4444
function AbstractVectorOfArray_getindex_adjoint(Δ)
4545
Δ′ = [(i[j] ? Δ[j] : Fill(zero(eltype(x)), size(x)))
4646
for (x, j) in zip(VA.u, 1:length(VA))]
@@ -49,16 +49,16 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray,
4949
VA[i], AbstractVectorOfArray_getindex_adjoint
5050
end
5151

52-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
52+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Colon)
5353
function AbstractVectorOfArray_getindex_adjoint(Δ)
5454
(VectorOfArray(Δ), nothing)
5555
end
5656
VA[i], AbstractVectorOfArray_getindex_adjoint
5757
end
5858

59-
ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
60-
j::Union{Int, AbstractArray{Int}, CartesianIndex,
61-
Colon, BitArray, AbstractArray{Bool}}...)
59+
@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
60+
j::Union{Int, AbstractArray{Int}, CartesianIndex,
61+
Colon, BitArray, AbstractArray{Bool}}...)
6262
function AbstractVectorOfArray_getindex_adjoint(Δ)
6363
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
6464
Δ′[i, j...] = Δ
@@ -67,12 +67,12 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
6767
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
6868
end
6969

70-
ZygoteRules.@adjoint function ArrayPartition(x::S,
71-
::Type{Val{copy_x}} = Val{false}) where {
72-
S <:
73-
Tuple,
74-
copy_x
75-
}
70+
@adjoint function ArrayPartition(x::S,
71+
::Type{Val{copy_x}} = Val{false}) where {
72+
S <:
73+
Tuple,
74+
copy_x
75+
}
7676
function ArrayPartition_adjoint(_y)
7777
y = Array(_y)
7878
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
@@ -83,23 +83,23 @@ ZygoteRules.@adjoint function ArrayPartition(x::S,
8383
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint
8484
end
8585

86-
ZygoteRules.@adjoint function VectorOfArray(u)
86+
@adjoint function VectorOfArray(u)
8787
VectorOfArray(u),
8888
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
8989
for i in 1:size(y)[end]]),)
9090
end
9191

92-
ZygoteRules.@adjoint function DiffEqArray(u, t)
92+
@adjoint function DiffEqArray(u, t)
9393
DiffEqArray(u, t),
9494
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
9595
t), nothing)
9696
end
9797

98-
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::ArrayPartition, ::Val{:x})
98+
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
9999
function literal_ArrayPartition_x_adjoint(d)
100100
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
101101
end
102102
A.x, literal_ArrayPartition_x_adjoint
103103
end
104104

105-
end
105+
end

src/array_partition.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Base.ones(A::ArrayPartition, dims::NTuple{N, Int}) where {N} = ones(A)
115115

116116
# mutable iff all components of ArrayPartition are mutable
117117
@generated function ArrayInterface.ismutable(::Type{<:ArrayPartition{T, S}}) where {T, S
118-
}
118+
}
119119
res = all(ArrayInterface.ismutable, S.parameters)
120120
return :($res)
121121
end

src/zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfAr
6565
#(::ChainRulesCore.ProjectTo{VectorOfArray})(dx::AbstractArray) = dx
6666

6767
# These rules duplicate the `rrule` methods above, because Zygote looks for an `@adjoint`
68-
# definition first, and finds its own before finding those.
68+
# definition first, and finds its own before finding those.

test/partitions_test.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -191,20 +191,20 @@ up = 2 .* ap .+ 1
191191
@test typeof(ap) == typeof(up)
192192

193193
@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,
194+
2,
195+
false),
196+
([
197+
1,
198+
],
199+
2,
200+
false),
201+
([
202+
1,
203+
],
204+
[
194205
2,
195-
false),
196-
([
197-
1,
198-
],
199-
2,
200-
false),
201-
([
202-
1,
203-
],
204-
[
205-
2,
206-
],
207-
true))
206+
],
207+
true))
208208
@test ArrayInterface.ismutable(ArrayPartition(a, b)) == r
209209
end
210210

0 commit comments

Comments
 (0)