Skip to content

Commit 9ee2af8

Browse files
fix: define Zygote.@adjoint to fall back to ChainRulesCore.rrule
1 parent 4853750 commit 9ee2af8

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions"
5050
SciMLBasePyCallExt = "PyCall"
5151
SciMLBasePythonCallExt = "PythonCall"
5252
SciMLBaseRCallExt = "RCall"
53-
SciMLBaseZygoteExt = "Zygote"
53+
SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]
5454

5555
[compat]
5656
ADTypes = "0.2.5,1.0.0"

ext/SciMLBaseZygoteExt.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SciMLBaseZygoteExt
33
using Zygote
44
using Zygote: @adjoint, pullback
55
import Zygote: literal_getproperty
6+
import ChainRulesCore
67
using SciMLBase
78
using SciMLBase: ODESolution, remake,
89
getobserved, build_solution, EnsembleSolution,
@@ -40,6 +41,12 @@ import SciMLStructures
4041
VA[i, j], ODESolution_getindex_pullback
4142
end
4243

44+
struct ZygoteConfig <: ChainRulesCore.RuleConfig{ChainRulesCore.HasReverseMode} end
45+
46+
@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer)
47+
Zygote.ChainRulesCore.rrule(ZygoteConfig(), getindex, VA, sym, j)
48+
end
49+
4350
@adjoint function EnsembleSolution(sim, time, converged, stats)
4451
out = EnsembleSolution(sim, time, converged, stats)
4552
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}

0 commit comments

Comments
 (0)