Skip to content

Commit 3924de9

Browse files
authored
Prevent this rule from being misapplied to Symbol types. (#151)
* Prevent this rule from being misapplied to `Symbol` types. See SciML/ModelingToolkit.jl#972 * Constrain all the indices in the general case to be integers.
1 parent a9061b8 commit 3924de9

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/zygote.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i)
1+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int)
22
function AbstractVectorOfArray_getindex_adjoint(Δ)
33
Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))]
44
(NoTangent(),Δ′,NoTangent())
55
end
66
VA[i],AbstractVectorOfArray_getindex_adjoint
77
end
88

9-
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i, j...)
9+
function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N}
1010
function AbstractVectorOfArray_getindex_adjoint(Δ)
1111
Δ′ = zero(VA)
12-
Δ′[i,j...] = Δ
13-
(NoTangent(), Δ′, i,map(_ -> NoTangent(), j)...)
12+
Δ′[indices...] = Δ
13+
(NoTangent(), Δ′, indices[1],map(_ -> NoTangent(), indices[2:end])...)
1414
end
15-
VA[i,j...],AbstractVectorOfArray_getindex_adjoint
15+
VA[indices...],AbstractVectorOfArray_getindex_adjoint
1616
end
1717

1818
function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}} = Val{false}) where {S<:Tuple,copy_x}

0 commit comments

Comments
 (0)