Skip to content

Commit 7620bdd

Browse files
authored
Merge branch 'SciML:master' into func-affect
2 parents 6205992 + b77c4aa commit 7620bdd

File tree

5 files changed

+234
-2
lines changed

5 files changed

+234
-2
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
9191
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9292
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9393
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
94+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9495
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9596
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
9697
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
9798
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9899
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99100

100101
[targets]
101-
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
102+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

docs/src/basics/FAQ.md

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,58 @@ ERROR: TypeError: non-boolean (Num) used in boolean context
4141

4242
then it's likely you are trying to trace through a function which cannot be
4343
directly represented in Julia symbols. The techniques to handle this problem,
44-
such as `@register_symbolic`, are described in detail
44+
such as `@register_symbolic`, are described in detail
4545
[in the Symbolics.jl documentation](https://symbolics.juliasymbolics.org/dev/manual/faq/#Transforming-my-function-to-a-symbolic-equation-has-failed.-What-do-I-do?-1).
46+
47+
## Using ModelingToolkit with Optimization / Automatic Differentiation
48+
49+
If you are using ModelingToolkit inside of a loss function and are having issues with
50+
mixing MTK with automatic differentiation, getting performance, etc... don't! Instead, use
51+
MTK outside of the loss function to generate the code, and then use the generated code
52+
inside of the loss function.
53+
54+
For example, let's say you were building ODEProblems in the loss function like:
55+
56+
```julia
57+
function loss(p)
58+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
59+
sol = solve(prob, Tsit5())
60+
sum(abs2,sol)
61+
end
62+
```
63+
64+
Since `ODEProblem` on a MTK `sys` will have to generate code, this will be slower than
65+
caching the generated code, and will required automatic differentiation to go through the
66+
code generation process itself. All of this is unnecessary. Instead, generate the problem
67+
once outside of the loss function, and remake the prob inside of the loss function:
68+
69+
```julia
70+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
71+
function loss(p)
72+
remake(prob,p = ...)
73+
sol = solve(prob, Tsit5())
74+
sum(abs2,sol)
75+
end
76+
```
77+
78+
Now, one has to be careful with `remake` to ensure that the parameters are in the right
79+
order. One can use the previously mentioned indexing functionality to generate index
80+
maps for reordering `p` like:
81+
82+
```julia
83+
p = @parameters x y z
84+
idxs = ModelingToolkit.varmap_to_vars([p[1] => 1, p[2] => 2, p[3] => 3], p)
85+
p[idxs]
86+
```
87+
88+
Using this, the fixed index map can be used in the loss function. This would look like:
89+
90+
```julia
91+
prob = ODEProblem(sys, [], [p1 => p[1], p2 => p[2]])
92+
idxs = Int.(ModelingToolkit.varmap_to_vars([p1 => 1, p2 => 2], p))
93+
function loss(p)
94+
remake(prob,p = p[idxs])
95+
sol = solve(prob, Tsit5())
96+
sum(abs2,sol)
97+
end
98+
```

docs/src/systems/SDESystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ sde = SDESystem(ode, noiseeqs)
2424
```@docs
2525
structural_simplify
2626
alias_elimination
27+
Girsanov_transform
2728
```
2829

2930
## Analyses

src/systems/diffeqs/sdesystem.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,114 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
205205
name = name, checks = false)
206206
end
207207

208+
"""
209+
$(TYPEDSIGNATURES)
210+
211+
Measure transformation method that allows for a reduction in the variance of an estimator `Exp(g(X_t))`.
212+
Input: Original SDE system and symbolic function `u(t,x)` with scalar output that
213+
defines the adjustable parameters `d` in the Girsanov transformation. Optional: initial
214+
condition for `θ0`.
215+
Output: Modified SDESystem with additional component `θ_t` and initial value `θ0`, as well as
216+
the weight `θ_t/θ0` as observed equation, such that the estimator `Exp(g(X_t)θ_t/θ0)`
217+
has a smaller variance.
218+
219+
Reference:
220+
Kloeden, P. E., Platen, E., & Schurz, H. (2012). Numerical solution of SDE through computer
221+
experiments. Springer Science & Business Media.
222+
223+
# Example
224+
225+
```julia
226+
using ModelingToolkit
227+
228+
@parameters α β
229+
@variables t x(t) y(t) z(t)
230+
D = Differential(t)
231+
232+
eqs = [D(x) ~ α*x]
233+
noiseeqs = [β*x]
234+
235+
@named de = SDESystem(eqs,noiseeqs,t,[x],[α,β])
236+
237+
# define u (user choice)
238+
u = x
239+
θ0 = 0.1
240+
g(x) = x[1]^2
241+
demod = ModelingToolkit.Girsanov_transform(de, u; θ0=0.1)
242+
243+
u0modmap = [
244+
x => x0
245+
]
246+
247+
parammap = [
248+
α => 1.5,
249+
β => 1.0
250+
]
251+
252+
probmod = SDEProblem(demod,u0modmap,(0.0,1.0),parammap)
253+
ensemble_probmod = EnsembleProblem(probmod;
254+
output_func = (sol,i) -> (g(sol[x,end])*sol[demod.weight,end],false),
255+
)
256+
257+
simmod = solve(ensemble_probmod,EM(),dt=dt,trajectories=numtraj)
258+
```
259+
260+
"""
261+
function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
262+
name = nameof(sys)
263+
264+
# register new varible θ corresponding to 1D correction process θ(t)
265+
t = get_iv(sys)
266+
D = Differential(t)
267+
@variables θ(t), weight(t)
268+
269+
# determine the adjustable parameters `d` given `u`
270+
# gradient of u with respect to states
271+
grad = Symbolics.gradient(u, states(sys))
272+
273+
noiseeqs = get_noiseeqs(sys)
274+
if typeof(noiseeqs) <: Vector
275+
d = simplify.(-(noiseeqs .* grad) / u)
276+
drift_correction = noiseeqs .* d
277+
else
278+
d = simplify.(-noiseeqs * grad / u)
279+
drift_correction = noiseeqs * d
280+
end
281+
282+
# transformation adds additional state θ: newX = (X,θ)
283+
# drift function for state is modified
284+
# θ has zero drift
285+
deqs = vcat([equations(sys)[i].lhs ~ equations(sys)[i].rhs - drift_correction[i]
286+
for i in eachindex(states(sys))]...)
287+
deqsθ = D(θ) ~ 0
288+
push!(deqs, deqsθ)
289+
290+
# diffusion matrix is of size d x m (d states, m noise), with diagonal noise represented as a d-dimensional vector
291+
# for diagonal noise processes with m>1, the noise process will become non-diagonal; extra state component but no new noise process.
292+
# new diffusion matrix is of size d+1 x M
293+
# diffusion for state is unchanged
294+
295+
noiseqsθ = θ * d
296+
297+
if typeof(noiseeqs) <: Vector
298+
m = size(noiseeqs)
299+
if m == 1
300+
push!(noiseeqs, noiseqsθ)
301+
else
302+
noiseeqs = [Array(Diagonal(noiseeqs)); noiseqsθ']
303+
end
304+
else
305+
noiseeqs = [Array(noiseeqs); noiseqsθ']
306+
end
307+
308+
state = [states(sys); θ]
309+
310+
# return modified SDE System
311+
SDESystem(deqs, noiseeqs, get_iv(sys), state, parameters(sys);
312+
defaults = Dict=> θ0), observed = [weight ~ θ / θ0],
313+
name = name, checks = false)
314+
end
315+
208316
"""
209317
```julia
210318
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;

test/sdesystem.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
22
using StochasticDiffEq, OrdinaryDiffEq, SparseArrays
33
using Random, Test
4+
using Statistics
45

56
# Define some variables
67
@parameters t σ ρ β
@@ -513,3 +514,71 @@ noiseeqs = [0.1 * x]
513514
@test observed(ode) == [weight ~ x * 10]
514515
@test solode[weight] == 10 * solode[x]
515516
end
517+
518+
@testset "Measure Transformation for variance reduction" begin
519+
@parameters α β
520+
@variables t x(t) y(t) z(t)
521+
D = Differential(t)
522+
523+
# Evaluate Exp [(X_T)^2]
524+
# SDE: X_t = x + \int_0^t α X_z dz + \int_0^t b X_z dW_z
525+
eqs = [D(x) ~ α * x]
526+
noiseeqs =* x]
527+
528+
@named de = SDESystem(eqs, noiseeqs, t, [x], [α, β])
529+
530+
g(x) = x[1]^2
531+
dt = 1 // 2^(7)
532+
x0 = 0.1
533+
534+
## Standard approach
535+
# EM with 1`000 trajectories for stepsize 2^-7
536+
u0map = [
537+
x => x0,
538+
]
539+
540+
parammap = [
541+
α => 1.5,
542+
β => 1.0,
543+
]
544+
545+
prob = SDEProblem(de, u0map, (0.0, 1.0), parammap)
546+
547+
function prob_func(prob, i, repeat)
548+
remake(prob, seed = seeds[i])
549+
end
550+
numtraj = Int(1e3)
551+
seed = 100
552+
Random.seed!(seed)
553+
seeds = rand(UInt, numtraj)
554+
555+
ensemble_prob = EnsembleProblem(prob;
556+
output_func = (sol, i) -> (g(sol[end]), false),
557+
prob_func = prob_func)
558+
559+
sim = solve(ensemble_prob, EM(), dt = dt, trajectories = numtraj)
560+
μ = mean(sim)
561+
σ = std(sim) / sqrt(numtraj)
562+
563+
## Variance reduction method
564+
u = x
565+
demod = ModelingToolkit.Girsanov_transform(de, u; θ0 = 0.1)
566+
567+
probmod = SDEProblem(demod, u0map, (0.0, 1.0), parammap)
568+
569+
ensemble_probmod = EnsembleProblem(probmod;
570+
output_func = (sol, i) -> (g(sol[x, end]) *
571+
sol[demod.weight, end],
572+
false),
573+
prob_func = prob_func)
574+
575+
simmod = solve(ensemble_probmod, EM(), dt = dt, trajectories = numtraj)
576+
μmod = mean(simmod)
577+
σmod = std(simmod) / sqrt(numtraj)
578+
579+
display("μ = $(round(μ, digits=2)) ± $(round(σ, digits=2))")
580+
display("μmod = $(round(μmod, digits=2)) ± $(round(σmod, digits=2))")
581+
582+
@test μμmod atol=2σ
583+
@test σ > σmod
584+
end

0 commit comments

Comments
 (0)