Skip to content

Commit bb8bc18

Browse files
Merge pull request #1069 from SciML/dual_func
Handle dual detection on SciMLFunctions
2 parents 2cb1d9e + 5b0785a commit bb8bc18

File tree

6 files changed

+26
-13
lines changed

6 files changed

+26
-13
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
125125
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
126126
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
127127
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
128+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
128129
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
129130
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
130131
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
131132

132133
[targets]
133-
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]
134+
test = ["Distributed", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "SymbolicIndexingInterface", "ForwardDiff", "SparseArrays", "InteractiveUtils", "Plots", "Pkg", "Random", "StaticArrays", "SafeTestsets", "Statistics", "Test", "Distributions", "Aqua", "BenchmarkTools"]

src/forwarddiff.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,15 @@ const FORWARDDIFF_AUTODETECTION_FAILURE_MESSAGE = """
155155
end
156156
```
157157
158+
To opt a type out of the dual checking, define an overload
159+
that returns Any. For example:
160+
161+
```julia
162+
function DiffEqBase.anyeltypedual(::YourType, ::Type{Val{counter}}) where {counter}
163+
Any
164+
end
165+
```
166+
158167
If you have defined this on a common type which should
159168
be more generally supported, please open a pull request
160169
adding this dispatch. If you need help defining this dispatch,
@@ -338,6 +347,8 @@ function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {coun
338347
anyeltypedual(values(x))
339348
end
340349

350+
DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} = Any
351+
341352
@inline promote_u0(::Nothing, p, t0) = nothing
342353

343354
@inline function promote_u0(u0, p, t0)

test/downstream/default_linsolve_structure.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,23 @@ jp_diag = Diagonal(zeros(2))
66
fun = ODEFunction(f; jac = jac, jac_prototype = jp_diag)
77
prob = ODEProblem(fun, ones(2), (1.0, 10.0))
88
sol = solve(prob, Rosenbrock23())
9-
@test sol[end] [10.0, 10.0]
9+
@test sol.u[end] [10.0, 10.0]
1010
@test length(sol) < 60
1111

1212
sol = solve(prob, Rosenbrock23(autodiff = false))
13-
@test sol[end] [10.0, 10.0]
13+
@test sol.u[end] [10.0, 10.0]
1414
@test length(sol) < 60
1515

1616
jp = Tridiagonal(jp_diag)
1717
fun = ODEFunction(f; jac = jac, jac_prototype = jp)
1818
prob = ODEProblem(fun, ones(2), (1.0, 10.0))
1919

2020
sol = solve(prob, Rosenbrock23())
21-
@test sol[end] [10.0, 10.0]
21+
@test sol.u[end] [10.0, 10.0]
2222
@test length(sol) < 60
2323

2424
sol = solve(prob, Rosenbrock23(autodiff = false))
25-
@test sol[end] [10.0, 10.0]
25+
@test sol.u[end] [10.0, 10.0]
2626
@test length(sol) < 60
2727

2828
#=
@@ -43,7 +43,7 @@ sol = solve(prob,Rosenbrock23())
4343
fun = ODEFunction(f; jac = jac, jac_prototype = jp)
4444
prob = ODEProblem(fun, ones(2), (1.0, 10.0))
4545
sol = solve(prob, Rosenbrock23(autodiff = false))
46-
@test sol[end] [10.0, 10.0]
46+
@test sol.u[end] [10.0, 10.0]
4747
@test length(sol) < 60
4848
end
4949

@@ -52,6 +52,6 @@ end
5252
fun = ODEFunction(f; jac = jac, jac_prototype = jp)
5353
prob = ODEProblem(fun, ones(2), (1.0, 10.0))
5454
sol = solve(prob, Rosenbrock23(autodiff = false))
55-
@test sol[end] [10.0, 10.0]
55+
@test sol.u[end] [10.0, 10.0]
5656
@test length(sol) < 60
5757
end

test/downstream/ensemble_analysis.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ vecarr = timeseries_steps_quantile(sim, 0.5)
2525
m_series, v_series = timeseries_steps_meanvar(sim)
2626
summ = EnsembleSummary(sim)
2727

28-
m4, v4 = m_series[3], v_series[3]
28+
m4, v4 = m_series.u[3], v_series.u[3]
2929
covar_mat = timeseries_steps_meancov(sim)[3, 3]
3030
@test m m4
3131
@test v v4
@@ -54,7 +54,7 @@ m_series = timeseries_point_median(sim, 0:(1 // 2^(3)):1)
5454
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2^(3)):1)
5555
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2^(3)):1)
5656
summ = EnsembleSummary(sim, 0:(1 // 2^(3)):1)
57-
m5, v5 = m_series[5], v_series[5]
57+
m5, v5 = m_series.u[5], v_series.u[5]
5858
@test m m5
5959
@test v v5
6060
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2^(3)):1, 0:(1 // 2^(3)):1)[5, 5]
@@ -83,7 +83,7 @@ vecarr = timeseries_steps_quantile(sim, 0.5)
8383
m_series, v_series = timeseries_steps_meanvar(sim)
8484
summ = EnsembleSummary(sim)
8585

86-
m4, v4 = m_series[3], v_series[3]
86+
m4, v4 = m_series.u[3], v_series.u[3]
8787
covar_mat = timeseries_steps_meancov(sim)[3, 3]
8888
@test m m4
8989
@test v v4
@@ -112,7 +112,7 @@ m_series = timeseries_point_median(sim, 0:(1 // 2^(3)):1)
112112
m_series = timeseries_point_quantile(sim, 0.5, 0:(1 // 2^(3)):1)
113113
m_series, v_series = timeseries_point_meanvar(sim, 0:(1 // 2^(3)):1)
114114
summ = EnsembleSummary(sim, 0:(1 // 2^(3)):1)
115-
m5, v5 = m_series[5], v_series[5]
115+
m5, v5 = m_series.u[5], v_series.u[5]
116116
@test m m5
117117
@test v v5
118118
m6, m7, v6 = timeseries_point_meancov(sim, 0:(1 // 2^(3)):1, 0:(1 // 2^(3)):1)[5, 5]

test/downstream/tables.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
using OrdinaryDiffEq, DataFrames, Test
1+
using OrdinaryDiffEq, DataFrames, Test, SymbolicIndexingInterface
22
f_2dlinear = (du, u, p, t) -> du .= 1.01u;
33
prob = ODEProblem(f_2dlinear, rand(2, 2), (0.0, 1.0));
44
sol1 = solve(prob, Euler(); dt = 1 // 2^(4));
55
df = DataFrame(sol1)
66
@test names(df) == ["timestamp", "value1", "value2", "value3", "value4"]
77

8-
prob = ODEProblem(ODEFunction(f_2dlinear, syms = [:a, :b, :c, :d]), rand(2, 2), (0.0, 1.0));
8+
prob = ODEProblem(ODEFunction(f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), rand(2, 2), (0.0, 1.0));
99
sol2 = solve(prob, Euler(); dt = 1 // 2^(4));
1010
df = DataFrame(sol2)
1111
@test names(df) == ["timestamp", "a", "b", "c", "d"]

test/forwarddiff_dual_detection.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,4 @@ prob = ODEProblem{false}(f, u0, tspan)
347347
foo = SciMLBase.build_solution(
348348
prob, DiffEqBase.InternalEuler.FwdEulerAlg(), [u0, u0], [0.0, 1.0])
349349
DiffEqBase.anyeltypedual((; x = foo))
350+
DiffEqBase.anyeltypedual((; x = foo, y = prob.f))

0 commit comments

Comments
 (0)