Skip to content

Commit a45a3c3

Browse files
committed
small fixes, and formatting update
1 parent 2b08f41 commit a45a3c3

File tree

2 files changed

+80
-80
lines changed

2 files changed

+80
-80
lines changed

src/systems/diffeqs/sdesystem.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ function generate_diffusion_function(sys::SDESystem, dvs = states(sys),
153153
return build_function(get_noiseeqs(sys),
154154
map(x -> time_varying_as_func(value(x), sys), dvs),
155155
map(x -> time_varying_as_func(value(x), sys), ps),
156-
get_iv(sys); kwargs...)
156+
get_iv(sys); kwargs...)
157157
end
158158

159159
"""
@@ -205,7 +205,6 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
205205
name = name, checks = false)
206206
end
207207

208-
209208
"""
210209
$(TYPEDSIGNATURES)
211210
@@ -259,7 +258,7 @@ simmod = solve(ensemble_probmod,EM(),dt=dt,trajectories=numtraj)
259258
```
260259
261260
"""
262-
function Girsanov_transform(sys::SDESystem, u; θ0=1.0)
261+
function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
263262
name = nameof(sys)
264263

265264
# register new varible θ corresponding to 1D correction process θ(t)
@@ -269,48 +268,49 @@ function Girsanov_transform(sys::SDESystem, u; θ0=1.0)
269268

270269
# determine the adjustable parameters `d` given `u`
271270
# gradient of u with respect to states
272-
grad = Symbolics.gradient(u,states(sys))
271+
grad = Symbolics.gradient(u, states(sys))
273272

274273
noiseeqs = get_noiseeqs(sys)
275274
if typeof(noiseeqs) <: Vector
276-
d = simplify.(-(noiseeqs.*grad)/u)
277-
drift_correction = noiseeqs.*d
275+
d = simplify.(-(noiseeqs .* grad) / u)
276+
drift_correction = noiseeqs .* d
278277
else
279-
d = simplify.(-noiseeqs*grad/u)
280-
drift_correction = noiseeqs*d
278+
d = simplify.(-noiseeqs * grad / u)
279+
drift_correction = noiseeqs * d
281280
end
282281

283282
# transformation adds additional state θ: newX = (X,θ)
284283
# drift function for state is modified
285284
# θ has zero drift
286-
deqs = vcat([equations(sys)[i].lhs ~ equations(sys)[i].rhs - drift_correction[i] for i in eachindex(states(sys))]...)
285+
deqs = vcat([equations(sys)[i].lhs ~ equations(sys)[i].rhs - drift_correction[i]
286+
for i in eachindex(states(sys))]...)
287287
deqsθ = D(θ) ~ 0
288-
push!(deqs,deqsθ)
288+
push!(deqs, deqsθ)
289289

290290
# diffusion matrix is of size d x m (d states, m noise), with diagonal noise represented as a d-dimensional vector
291291
# for diagonal noise processes with m>1, the noise process will become non-diagonal; extra state component but no new noise process.
292292
# new diffusion matrix is of size d+1 x M
293293
# diffusion for state is unchanged
294294

295-
noiseqsθ = θ*d
295+
noiseqsθ = θ * d
296296

297297
if typeof(noiseeqs) <: Vector
298298
m = size(noiseeqs)
299299
if m == 1
300-
push!(noiseeqs,noiseqsθ)
300+
push!(noiseeqs, noiseqsθ)
301301
else
302302
noiseeqs = [Array(Diagonal(noiseeqs)); noiseqsθ']
303303
end
304304
else
305305
noiseeqs = [Array(noiseeqs); noiseqsθ']
306306
end
307307

308-
state = [states(sys);θ]
308+
state = [states(sys); θ]
309309

310310
# return modified SDE System
311311
SDESystem(deqs, noiseeqs, get_iv(sys), state, parameters(sys);
312-
defaults = Dict=> θ0), observed = [weight ~ θ/θ0],
313-
name=name, checks=false)
312+
defaults = Dict=> θ0), observed = [weight ~ θ / θ0],
313+
name = name, checks = false)
314314
end
315315

316316
"""

test/sdesystem.jl

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -517,69 +517,69 @@ end
517517

518518

519519
@testset "Measure Transformation for variance reduction" begin
520-
@parameters α β
521-
@variables t x(t) y(t) z(t)
522-
D = Differential(t)
523-
524-
# Evaluate Exp [(X_T)^2]
525-
# SDE: X_t = x + \int_0^t α X_z dz + \int_0^t b X_z dW_z
526-
eqs = [D(x) ~ α*x]
527-
noiseeqs =*x]
528-
529-
@named de = SDESystem(eqs,noiseeqs,t,[x],[α,β])
530-
531-
g(x) = x[2]^2
532-
dt = 1 //2 ^(7)
533-
x0 = 0.1
534-
535-
## Standard approach
536-
# EM with 1`000 trajectories for stepsize 2^-7
537-
u0map = [
538-
x => x0
539-
]
540-
541-
parammap = [
542-
α => 1.5,
543-
β => 1.0
544-
]
545-
546-
prob = SDEProblem(de,u0map,(0.0,1.0),parammap)
547-
548-
function prob_func(prob, i, repeat)
549-
remake(prob,seed=seeds[i])
550-
end
551-
numtraj = Int(1e3)
552-
seed = 100
553-
Random.seed!(seed)
554-
seeds = rand(UInt, numtraj)
555-
556-
ensemble_prob = EnsembleProblem(prob;
557-
output_func = (sol,i) -> (g(sol[end]),false),
558-
prob_func = prob_func
559-
)
560-
561-
sim = solve(ensemble_prob,EM(),dt=dt,trajectories=numtraj)
562-
μ = mean(sim)
563-
σ = std(sim)/sqrt(numtraj)
564-
565-
## Variance reduction method
566-
u = x
567-
demod = ModelingToolkit.Girsanov_transform(de, u; θ0=0.1)
568-
569-
probmod = SDEProblem(demod,u0map,(0.0,1.0),parammap)
570-
571-
ensemble_probmod = EnsembleProblem(probmod;
572-
output_func = (sol,i) -> (g(sol[x,end])*sol[weight,end],false),
573-
prob_func = prob_func
574-
)
575-
576-
simmod = solve(ensemble_probmod,EM(),dt=dt,trajectories=numtraj)
577-
μmod = mean(simmod)
578-
σmod = std(simmod)/sqrt(numtraj)
579-
580-
display("μ = $(round(μ, digits=2)) ± $(round(σ, digits=2))")
581-
display("μmod = $(round(μmod, digits=2)) ± $(round(σmod, digits=2))")
582-
583-
@test μ μmod atol = 2σ
584-
@test σ > σmod
520+
@parameters α β
521+
@variables t x(t) y(t) z(t)
522+
D = Differential(t)
523+
524+
# Evaluate Exp [(X_T)^2]
525+
# SDE: X_t = x + \int_0^t α X_z dz + \int_0^t b X_z dW_z
526+
eqs = [D(x) ~ α * x]
527+
noiseeqs = * x]
528+
529+
@named de = SDESystem(eqs, noiseeqs, t, [x], [α, β])
530+
531+
g(x) = x[1]^2
532+
dt = 1 // 2^(7)
533+
x0 = 0.1
534+
535+
## Standard approach
536+
# EM with 1`000 trajectories for stepsize 2^-7
537+
u0map = [
538+
x => x0,
539+
]
540+
541+
parammap = [
542+
α => 1.5,
543+
β => 1.0,
544+
]
545+
546+
prob = SDEProblem(de, u0map, (0.0, 1.0), parammap)
547+
548+
function prob_func(prob, i, repeat)
549+
remake(prob, seed = seeds[i])
550+
end
551+
numtraj = Int(1e3)
552+
seed = 100
553+
Random.seed!(seed)
554+
seeds = rand(UInt, numtraj)
555+
556+
ensemble_prob = EnsembleProblem(prob;
557+
output_func = (sol, i) -> (g(sol[end]), false),
558+
prob_func = prob_func)
559+
560+
sim = solve(ensemble_prob, EM(), dt = dt, trajectories = numtraj)
561+
μ = mean(sim)
562+
σ = std(sim) / sqrt(numtraj)
563+
564+
## Variance reduction method
565+
u = x
566+
demod = ModelingToolkit.Girsanov_transform(de, u; θ0 = 0.1)
567+
568+
probmod = SDEProblem(demod, u0map, (0.0, 1.0), parammap)
569+
570+
ensemble_probmod = EnsembleProblem(probmod;
571+
output_func = (sol, i) -> (g(sol[x, end]) *
572+
sol[demod.θ, end] /
573+
sol[demod.θ, 1], false),
574+
prob_func = prob_func)
575+
576+
simmod = solve(ensemble_probmod, EM(), dt = dt, trajectories = numtraj)
577+
μmod = mean(simmod)
578+
σmod = std(simmod) / sqrt(numtraj)
579+
580+
display("μ = $(round(μ, digits=2)) ± $(round(σ, digits=2))")
581+
display("μmod = $(round(μmod, digits=2)) ± $(round(σmod, digits=2))")
582+
583+
@test μμmod atol=2σ
584+
@test σ > σmod
585585
end

0 commit comments

Comments
 (0)