Skip to content

Commit 3b829f1

Browse files
Merge pull request #1439 from frankschae/Girsanov_trafo
Girsanov transformation for variance reduction of expected values
2 parents fd22563 + a7cee1c commit 3b829f1

File tree

5 files changed

+182
-3
lines changed

5 files changed

+182
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
9191
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
9292
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9393
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
94+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9495
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
9596
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
9697
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
9798
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
9899
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
99100

100101
[targets]
101-
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]
102+
test = ["AmplNLWriter", "BenchmarkTools", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials"]

docs/src/systems/SDESystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ sde = SDESystem(ode, noiseeqs)
2424
```@docs
2525
structural_simplify
2626
alias_elimination
27+
Girsanov_transform
2728
```
2829

2930
## Analyses

src/parameters.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function isparameter(x)
88
if x isa Symbolic && Symbolics.getparent(x, false) !== false
99
p = Symbolics.getparent(x)
1010
isparameter(p) ||
11-
(hasmetadata(p, Symbolics.VariableSource) &&
12-
getmetadata(p, Symbolics.VariableSource)[1] == :parameters)
11+
(hasmetadata(p, Symbolics.VariableSource) &&
12+
getmetadata(p, Symbolics.VariableSource)[1] == :parameters)
1313
elseif istree(x) && operation(x) isa Symbolic
1414
getmetadata(x, MTKParameterCtx, false) ||
1515
isparameter(operation(x))

src/systems/diffeqs/sdesystem.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,114 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
205205
name = name, checks = false)
206206
end
207207

208+
"""
209+
$(TYPEDSIGNATURES)
210+
211+
Measure transformation method that allows for a reduction in the variance of an estimator `Exp(g(X_t))`.
212+
Input: Original SDE system and symbolic function `u(t,x)` with scalar output that
213+
defines the adjustable parameters `d` in the Girsanov transformation. Optional: initial
214+
condition for `θ0`.
215+
Output: Modified SDESystem with additional component `θ_t` and initial value `θ0`, as well as
216+
the weight `θ_t/θ0` as observed equation, such that the estimator `Exp(g(X_t)θ_t/θ0)`
217+
has a smaller variance.
218+
219+
Reference:
220+
Kloeden, P. E., Platen, E., & Schurz, H. (2012). Numerical solution of SDE through computer
221+
experiments. Springer Science & Business Media.
222+
223+
# Example
224+
225+
```julia
226+
using ModelingToolkit
227+
228+
@parameters α β
229+
@variables t x(t) y(t) z(t)
230+
D = Differential(t)
231+
232+
eqs = [D(x) ~ α*x]
233+
noiseeqs = [β*x]
234+
235+
@named de = SDESystem(eqs,noiseeqs,t,[x],[α,β])
236+
237+
# define u (user choice)
238+
u = x
239+
θ0 = 0.1
240+
g(x) = x[1]^2
241+
demod = ModelingToolkit.Girsanov_transform(de, u; θ0=0.1)
242+
243+
u0modmap = [
244+
x => x0
245+
]
246+
247+
parammap = [
248+
α => 1.5,
249+
β => 1.0
250+
]
251+
252+
probmod = SDEProblem(demod,u0modmap,(0.0,1.0),parammap)
253+
ensemble_probmod = EnsembleProblem(probmod;
254+
output_func = (sol,i) -> (g(sol[x,end])*sol[demod.weight,end],false),
255+
)
256+
257+
simmod = solve(ensemble_probmod,EM(),dt=dt,trajectories=numtraj)
258+
```
259+
260+
"""
261+
function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
262+
name = nameof(sys)
263+
264+
# register new varible θ corresponding to 1D correction process θ(t)
265+
t = get_iv(sys)
266+
D = Differential(t)
267+
@variables θ(t), weight(t)
268+
269+
# determine the adjustable parameters `d` given `u`
270+
# gradient of u with respect to states
271+
grad = Symbolics.gradient(u, states(sys))
272+
273+
noiseeqs = get_noiseeqs(sys)
274+
if typeof(noiseeqs) <: Vector
275+
d = simplify.(-(noiseeqs .* grad) / u)
276+
drift_correction = noiseeqs .* d
277+
else
278+
d = simplify.(-noiseeqs * grad / u)
279+
drift_correction = noiseeqs * d
280+
end
281+
282+
# transformation adds additional state θ: newX = (X,θ)
283+
# drift function for state is modified
284+
# θ has zero drift
285+
deqs = vcat([equations(sys)[i].lhs ~ equations(sys)[i].rhs - drift_correction[i]
286+
for i in eachindex(states(sys))]...)
287+
deqsθ = D(θ) ~ 0
288+
push!(deqs, deqsθ)
289+
290+
# diffusion matrix is of size d x m (d states, m noise), with diagonal noise represented as a d-dimensional vector
291+
# for diagonal noise processes with m>1, the noise process will become non-diagonal; extra state component but no new noise process.
292+
# new diffusion matrix is of size d+1 x M
293+
# diffusion for state is unchanged
294+
295+
noiseqsθ = θ * d
296+
297+
if typeof(noiseeqs) <: Vector
298+
m = size(noiseeqs)
299+
if m == 1
300+
push!(noiseeqs, noiseqsθ)
301+
else
302+
noiseeqs = [Array(Diagonal(noiseeqs)); noiseqsθ']
303+
end
304+
else
305+
noiseeqs = [Array(noiseeqs); noiseqsθ']
306+
end
307+
308+
state = [states(sys); θ]
309+
310+
# return modified SDE System
311+
SDESystem(deqs, noiseeqs, get_iv(sys), state, parameters(sys);
312+
defaults = Dict=> θ0), observed = [weight ~ θ / θ0],
313+
name = name, checks = false)
314+
end
315+
208316
"""
209317
```julia
210318
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;

test/sdesystem.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, StaticArrays, LinearAlgebra
22
using StochasticDiffEq, OrdinaryDiffEq, SparseArrays
33
using Random, Test
4+
using Statistics
45

56
# Define some variables
67
@parameters t σ ρ β
@@ -513,3 +514,71 @@ noiseeqs = [0.1 * x]
513514
@test observed(ode) == [weight ~ x * 10]
514515
@test solode[weight] == 10 * solode[x]
515516
end
517+
518+
@testset "Measure Transformation for variance reduction" begin
519+
@parameters α β
520+
@variables t x(t) y(t) z(t)
521+
D = Differential(t)
522+
523+
# Evaluate Exp [(X_T)^2]
524+
# SDE: X_t = x + \int_0^t α X_z dz + \int_0^t b X_z dW_z
525+
eqs = [D(x) ~ α * x]
526+
noiseeqs =* x]
527+
528+
@named de = SDESystem(eqs, noiseeqs, t, [x], [α, β])
529+
530+
g(x) = x[1]^2
531+
dt = 1 // 2^(7)
532+
x0 = 0.1
533+
534+
## Standard approach
535+
# EM with 1`000 trajectories for stepsize 2^-7
536+
u0map = [
537+
x => x0,
538+
]
539+
540+
parammap = [
541+
α => 1.5,
542+
β => 1.0,
543+
]
544+
545+
prob = SDEProblem(de, u0map, (0.0, 1.0), parammap)
546+
547+
function prob_func(prob, i, repeat)
548+
remake(prob, seed = seeds[i])
549+
end
550+
numtraj = Int(1e3)
551+
seed = 100
552+
Random.seed!(seed)
553+
seeds = rand(UInt, numtraj)
554+
555+
ensemble_prob = EnsembleProblem(prob;
556+
output_func = (sol, i) -> (g(sol[end]), false),
557+
prob_func = prob_func)
558+
559+
sim = solve(ensemble_prob, EM(), dt = dt, trajectories = numtraj)
560+
μ = mean(sim)
561+
σ = std(sim) / sqrt(numtraj)
562+
563+
## Variance reduction method
564+
u = x
565+
demod = ModelingToolkit.Girsanov_transform(de, u; θ0 = 0.1)
566+
567+
probmod = SDEProblem(demod, u0map, (0.0, 1.0), parammap)
568+
569+
ensemble_probmod = EnsembleProblem(probmod;
570+
output_func = (sol, i) -> (g(sol[x, end]) *
571+
sol[demod.weight, end],
572+
false),
573+
prob_func = prob_func)
574+
575+
simmod = solve(ensemble_probmod, EM(), dt = dt, trajectories = numtraj)
576+
μmod = mean(simmod)
577+
σmod = std(simmod) / sqrt(numtraj)
578+
579+
display("μ = $(round(μ, digits=2)) ± $(round(σ, digits=2))")
580+
display("μmod = $(round(μmod, digits=2)) ± $(round(σmod, digits=2))")
581+
582+
@test μμmod atol=2σ
583+
@test σ > σmod
584+
end

0 commit comments

Comments
 (0)