Skip to content

Commit c9ad77c

Browse files
Merge pull request #956 from SciML/ad_piracy
Remove AD piracy functions by moving to SciMLBase
2 parents 8262b58 + 042f08e commit c9ad77c

File tree

8 files changed

+53
-352
lines changed

8 files changed

+53
-352
lines changed

Project.toml

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ version = "6.136.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
10-
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
119
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1210
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1311
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
@@ -25,7 +23,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2523
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2624
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2725
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
28-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2926
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3027
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
3128
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -35,9 +32,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3532
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3633
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
3734
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
38-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3935

4036
[weakdeps]
37+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4138
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4239
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4340
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
@@ -47,7 +44,6 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
4744
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4845
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4946
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
50-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5147

5248
[extensions]
5349
DiffEqBaseDistributionsExt = "Distributions"
@@ -59,7 +55,6 @@ DiffEqBaseMonteCarloMeasurementsExt = "MonteCarloMeasurements"
5955
DiffEqBaseReverseDiffExt = "ReverseDiff"
6056
DiffEqBaseTrackerExt = "Tracker"
6157
DiffEqBaseUnitfulExt = "Unitful"
62-
DiffEqBaseZygoteExt = "Zygote"
6358

6459
[compat]
6560
ArrayInterface = "7"
@@ -73,31 +68,30 @@ FastBroadcast = "0.2"
7368
ForwardDiff = "0.10"
7469
FunctionWrappers = "1.0"
7570
FunctionWrappersWrappers = "0.1"
76-
LinearAlgebra = "1.6"
77-
Logging = "1.6"
78-
Markdown = "1.6"
71+
LinearAlgebra = "1.9"
72+
Logging = "1.9"
73+
Markdown = "1.9"
7974
MuladdMacro = "0.2.1"
8075
Parameters = "0.12.0"
8176
PreallocationTools = "0.4"
8277
PrecompileTools = "1"
83-
Printf = "1.6"
78+
Printf = "1.9"
8479
RecursiveArrayTools = "2"
8580
Reexport = "1.0"
86-
Requires = "1.0"
87-
SciMLBase = "2.4.0"
81+
SciMLBase = "2.7.0"
8882
SciMLOperators = "0.2, 0.3"
8983
Setfield = "0.8, 1"
90-
SparseArrays = "1.6"
84+
SparseArrays = "1.9"
9185
Static = "0.7, 0.8"
9286
StaticArraysCore = "1.4"
9387
Statistics = "1"
9488
Tricks = "0.1.6"
9589
TruncatedStacktraces = "1"
96-
ZygoteRules = "0.2"
97-
julia = "1.6"
90+
julia = "1.9"
9891

9992
[extras]
10093
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
94+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
10195
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
10296
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
10397
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -116,7 +110,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
116110
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
117111
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
118112
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
119-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
120113

121114
[targets]
122115
test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua"]

ext/DiffEqBaseChainRulesCoreExt.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
module DiffEqBaseChainRulesCoreExt
2+
3+
using DiffEqBase
4+
import DiffEqBase: numargs
5+
6+
import ChainRulesCore
7+
import ChainRulesCore: NoTangent
8+
9+
ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), NoTangent()))
10+
ChainRulesCore.@non_differentiable checkkwargs(kwargshandle)
11+
12+
function ChainRulesCore.frule(::typeof(solve_up), prob,
13+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
14+
u0, p, args...;
15+
kwargs...)
16+
_solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
17+
kwargs...)
18+
end
19+
20+
function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem,
21+
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
22+
u0, p, args...;
23+
kwargs...)
24+
_solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
25+
kwargs...)
26+
end
27+
28+
end

ext/DiffEqBaseZygoteExt.jl

Lines changed: 0 additions & 60 deletions
This file was deleted.

src/DiffEqBase.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ if isdefined(Base, :Experimental) &&
33
isdefined(Base.Experimental, Symbol("@max_methods"))
44
@eval Base.Experimental.@max_methods 1
55
end
6-
if !isdefined(Base, :get_extension)
7-
using Requires
8-
end
96

107
import PrecompileTools
118

@@ -28,14 +25,10 @@ PrecompileTools.@recompile_invalidations begin
2825

2926
using Static: reduce_tup
3027

31-
import ChainRulesCore
3228
import RecursiveArrayTools
3329
import SparseArrays
3430
import TruncatedStacktraces
3531

36-
import ChainRulesCore: NoTangent, @non_differentiable
37-
import ZygoteRules
38-
3932
using Setfield
4033

4134
using ForwardDiff
@@ -140,13 +133,10 @@ include("callbacks.jl")
140133
include("common_defaults.jl")
141134
include("solve.jl")
142135
include("internal_euler.jl")
143-
include("init.jl")
144136
include("forwarddiff.jl")
145-
include("chainrules.jl")
146-
147137
include("termination_conditions.jl")
148-
149138
include("norecompile.jl")
139+
150140
# This is only used for oop stiff solvers
151141
default_factorize(A) = lu(A; check = false)
152142

@@ -181,8 +171,4 @@ export NLSolveTerminationMode,
181171

182172
export KeywordArgError, KeywordArgWarn, KeywordArgSilent
183173

184-
if !isdefined(Base, :get_extension)
185-
include("../ext/DiffEqBaseDistributionsExt.jl")
186-
end
187-
188174
end # module

0 commit comments

Comments
 (0)