Skip to content

Commit 4fe86d3

Browse files
Merge branch 'master' into format
2 parents 44b3dfb + 61d4754 commit 4fe86d3

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <[email protected]> and contributors"]
4-
version = "2.94.1"
4+
version = "2.96.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -85,7 +85,7 @@ RecipesBase = "1.3.4"
8585
RecursiveArrayTools = "3.27.2"
8686
Reexport = "1"
8787
RuntimeGeneratedFunctions = "0.5.12"
88-
SciMLOperators = "0.4.0, 1"
88+
SciMLOperators = "0.4, 1.3"
8989
SciMLStructures = "1.1"
9090
StableRNGs = "1.0"
9191
StaticArrays = "1.7"

ext/SciMLBaseChainRulesCoreExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
6565
ODEProblem(args...; kwargs...), ODEProblemAdjoint
6666
end
6767

68+
function ChainRulesCore.rrule(::Type{
69+
<:ODEProblem{iip, T}}, args...; kwargs...) where {iip, T}
70+
function ODEProblemAdjoint(ȳ)
71+
(NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)
72+
end
73+
74+
ODEProblem(args...; kwargs...), ODEProblemAdjoint
75+
end
76+
6877
function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
6978
function SDEProblemAdjoint(ȳ)
7079
(NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type)

src/scimlfunctions.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2571,19 +2571,23 @@ end
25712571
######### Backwards Compatibility Overloads
25722572

25732573
(f::ODEFunction)(args...) = f.f(args...)
2574-
function (f::ODEFunction)(du, u, p, t)
2575-
if f.f isa AbstractSciMLOperator
2576-
f.f(du, u, u, p, t)
2577-
else
2578-
f.f(du, u, p, t)
2579-
end
2580-
end
2581-
function (f::ODEFunction)(u, p, t)
2582-
if f.f isa AbstractSciMLOperator
2583-
f.f(u, u, p, t)
2584-
else
2585-
f.f(u, p, t)
2586-
end
2574+
2575+
@static if isdefined(SciMLOperators, :isv1)
2576+
function (f::ODEFunction)(du, u, p, t)
2577+
if f.f isa AbstractSciMLOperator
2578+
f.f(du, u, u, p, t)
2579+
else
2580+
f.f(du, u, p, t)
2581+
end
2582+
end
2583+
2584+
function (f::ODEFunction)(u, p, t)
2585+
if f.f isa AbstractSciMLOperator
2586+
f.f(u, u, p, t)
2587+
else
2588+
f.f(u, p, t)
2589+
end
2590+
end
25872591
end
25882592

25892593
(f::NonlinearFunction)(args...) = f.f(args...)
@@ -2779,7 +2783,7 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
27792783
if specialization(f) === NoSpecialize
27802784
ODEFunction{isinplace(f), specialization(f), Any, Any, Any,
27812785
Any, Any, Any, Any, typeof(f.jac_prototype),
2782-
typeof(f.sparsity), Any, Any, Any,
2786+
typeof(f.sparsity), Any, Any, Any, Any,
27832787
Any, typeof(f.colorvec),
27842788
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
27852789
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,

0 commit comments

Comments
 (0)