Skip to content

Commit 8140b62

Browse files
author
fchen121
committed
Combined change of variable function for ODE and SDE
1 parent 9a53121 commit 8140b62

File tree

2 files changed

+92
-127
lines changed

2 files changed

+92
-127
lines changed

src/systems/diffeqs/basic_transformations.jl

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -94,55 +94,13 @@ new_sol = solve(new_prob, Tsit5())
9494
```
9595
9696
"""
97-
function changeofvariables(sys::System, iv, forward_subs, backward_subs; simplify=true, t0=missing)
98-
t = iv
99-
100-
old_vars = first.(backward_subs)
101-
new_vars = last.(forward_subs)
102-
# kept_vars = setdiff(states(sys), old_vars)
103-
# rhs = [eq.rhs for eq in equations(sys)]
104-
105-
# use: dz/dt = ∂z/∂x dx/dt + ∂z/∂t
106-
dzdt = Symbolics.derivative( first.(forward_subs), t )
107-
new_eqs = Equation[]
108-
for (new_var, ex) in zip(new_vars, dzdt)
109-
for ode_eq in equations(sys)
110-
ex = substitute(ex, ode_eq.lhs => ode_eq.rhs)
111-
end
112-
ex = substitute(ex, Dict(forward_subs))
113-
ex = substitute(ex, Dict(backward_subs))
114-
if simplify
115-
ex = Symbolics.simplify(ex, expand=true)
116-
end
117-
push!(new_eqs, Differential(t)(new_var) ~ ex)
118-
end
119-
120-
defs = get_defaults(sys)
121-
new_defs = Dict()
122-
for f_sub in forward_subs
123-
ex = substitute(first(f_sub), defs)
124-
if !ismissing(t0)
125-
ex = substitute(ex, t => t0)
126-
end
127-
new_defs[last(f_sub)] = ex
128-
end
129-
for para in parameters(sys)
130-
if haskey(defs, para)
131-
new_defs[para] = defs[para]
132-
end
133-
end
134-
@named new_sys = System(vcat(new_eqs, first.(backward_subs) .~ last.(backward_subs)), t;
135-
defaults=new_defs,
136-
observed=observed(sys)
137-
)
138-
if simplify
139-
return mtkcompile(new_sys)
97+
function changeofvariables(
98+
sys::System, iv, forward_subs, backward_subs;
99+
simplify=true, t0=missing, isSDE=false
100+
)
101+
if !iscomplete(sys)
102+
sys = mtkcompile(sys)
140103
end
141-
return new_sys
142-
end
143-
144-
function change_of_variable_SDE(sys::System, iv, forward_subs, backward_subs; simplify=true, t0=missing)
145-
sys = mtkcompile(sys)
146104
t = iv
147105

148106
old_vars = first.(backward_subs)
@@ -152,10 +110,15 @@ function change_of_variable_SDE(sys::System, iv, forward_subs, backward_subs; si
152110
# use: dY = (∂f/∂t + μ∂f/∂x + (1/2)*σ^2*∂2f/∂x2)dt + σ∂f/∂xdW
153111
old_eqs = equations(sys)
154112
neqs = get_noise_eqs(sys)
155-
neqs = [neqs[i,:] for i in 1:size(neqs,1)]
113+
if neqs !== nothing
114+
isSDE = true
115+
neqs = [neqs[i,:] for i in 1:size(neqs,1)]
156116

157-
brownvars = map([Symbol(:B, :_, i) for i in 1:length(neqs[1])]) do name
158-
unwrap(only(@brownian $name))
117+
brownvars = map([Symbol(:B, :_, i) for i in 1:length(neqs[1])]) do name
118+
unwrap(only(@brownian $name))
119+
end
120+
else
121+
neqs = ones(1, length(old_eqs))
159122
end
160123

161124
# df/dt = ∂f/∂x dx/dt + ∂f/∂t
@@ -168,8 +131,10 @@ function change_of_variable_SDE(sys::System, iv, forward_subs, backward_subs; si
168131
for (eqs, neq) in zip(old_eqs, neqs)
169132
if occursin(value(eqs.lhs), value(ex))
170133
ex = substitute(ex, eqs.lhs => eqs.rhs)
171-
for (noise, B) in zip(neq, brownvars)
172-
ex = ex + 1/2 * noise^2 * second + noise*first*B
134+
if isSDE
135+
for (noise, B) in zip(neq, brownvars)
136+
ex = ex + 1/2 * noise^2 * second + noise*first*B
137+
end
173138
end
174139
end
175140
end

test/changeofvariables.jl

Lines changed: 74 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,101 +4,101 @@ using Test, LinearAlgebra
44

55
# Change of variables: z = log(x)
66
# (this implies that x = exp(z) is automatically non-negative)
7-
# @independent_variables t
8-
# @variables z(t)[1:2, 1:2]
9-
# D = Differential(t)
10-
# eqs = [D(D(z)) ~ ones(2, 2)]
11-
# @mtkcompile sys = System(eqs, t)
12-
# @test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))
7+
@independent_variables t
8+
@variables z(t)[1:2, 1:2]
9+
D = Differential(t)
10+
eqs = [D(D(z)) ~ ones(2, 2)]
11+
@mtkcompile sys = System(eqs, t)
12+
@test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))
1313

14-
# @parameters α
15-
# @variables x(t)
16-
# D = Differential(t)
17-
# eqs = [D(x) ~ α*x]
14+
@parameters α
15+
@variables x(t)
16+
D = Differential(t)
17+
eqs = [D(x) ~ α*x]
1818

19-
# tspan = (0., 1.)
20-
# def = [x => 1.0, α => -0.5]
19+
tspan = (0., 1.)
20+
def = [x => 1.0, α => -0.5]
2121

22-
# @mtkcompile sys = System(eqs, t;defaults=def)
23-
# prob = ODEProblem(sys, [], tspan)
24-
# sol = solve(prob, Tsit5())
22+
@mtkcompile sys = System(eqs, t;defaults=def)
23+
prob = ODEProblem(sys, [], tspan)
24+
sol = solve(prob, Tsit5())
2525

26-
# @variables z(t)
27-
# forward_subs = [log(x) => z]
28-
# backward_subs = [x => exp(z)]
29-
# new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
30-
# @test equations(new_sys)[1] == (D(z) ~ α)
26+
@variables z(t)
27+
forward_subs = [log(x) => z]
28+
backward_subs = [x => exp(z)]
29+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
30+
@test equations(new_sys)[1] == (D(z) ~ α)
3131

32-
# new_prob = ODEProblem(new_sys, [], tspan)
33-
# new_sol = solve(new_prob, Tsit5())
32+
new_prob = ODEProblem(new_sys, [], tspan)
33+
new_sol = solve(new_prob, Tsit5())
3434

35-
# @test isapprox(new_sol[x][end], sol[x][end], atol=1e-4)
35+
@test isapprox(new_sol[x][end], sol[x][end], atol=1e-4)
3636

3737

3838

39-
# # Riccati equation
40-
# @parameters α
41-
# @variables x(t)
42-
# D = Differential(t)
43-
# eqs = [D(x) ~ t^2 + α - x^2]
44-
# def = [x=>1., α => 1.]
45-
# @mtkcompile sys = System(eqs, t; defaults=def)
39+
# Riccati equation
40+
@parameters α
41+
@variables x(t)
42+
D = Differential(t)
43+
eqs = [D(x) ~ t^2 + α - x^2]
44+
def = [x=>1., α => 1.]
45+
@mtkcompile sys = System(eqs, t; defaults=def)
4646

47-
# @variables z(t)
48-
# forward_subs = [t + α/(x+t) => z ]
49-
# backward_subs = [ x => α/(z-t) - t]
47+
@variables z(t)
48+
forward_subs = [t + α/(x+t) => z ]
49+
backward_subs = [ x => α/(z-t) - t]
5050

51-
# new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true, t0=0.)
52-
# # output should be equivalent to
53-
# # t^2 + α - z^2 + 2 (but this simplification is not found automatically)
51+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true, t0=0.)
52+
# output should be equivalent to
53+
# t^2 + α - z^2 + 2 (but this simplification is not found automatically)
5454

55-
# tspan = (0., 1.)
56-
# prob = ODEProblem(sys,[],tspan)
57-
# new_prob = ODEProblem(new_sys,[],tspan)
55+
tspan = (0., 1.)
56+
prob = ODEProblem(sys,[],tspan)
57+
new_prob = ODEProblem(new_sys,[],tspan)
5858

59-
# sol = solve(prob, Tsit5())
60-
# new_sol = solve(new_prob, Tsit5())
59+
sol = solve(prob, Tsit5())
60+
new_sol = solve(new_prob, Tsit5())
6161

62-
# @test isapprox(sol[x][end], new_sol[x][end], rtol=1e-4)
62+
@test isapprox(sol[x][end], new_sol[x][end], rtol=1e-4)
6363

6464

65-
# # Linear transformation to diagonal system
66-
# @independent_variables t
67-
# @variables x(t)[1:3]
68-
# x = reshape(x, 3, 1)
69-
# D = Differential(t)
70-
# A = [0. -1. 0.; -0.5 0.5 0.; 0. 0. -1.]
71-
# right = A*x
72-
# eqs = vec(D.(x) .~ right)
65+
# Linear transformation to diagonal system
66+
@independent_variables t
67+
@variables x(t)[1:3]
68+
x = reshape(x, 3, 1)
69+
D = Differential(t)
70+
A = [0. -1. 0.; -0.5 0.5 0.; 0. 0. -1.]
71+
right = A*x
72+
eqs = vec(D.(x) .~ right)
7373

74-
# tspan = (0., 10.)
75-
# u0 = [x[1] => 1.0, x[2] => 2.0, x[3] => -1.0]
74+
tspan = (0., 10.)
75+
u0 = [x[1] => 1.0, x[2] => 2.0, x[3] => -1.0]
7676

77-
# @mtkcompile sys = System(eqs, t; defaults=u0)
78-
# prob = ODEProblem(sys,[],tspan)
79-
# sol = solve(prob, Tsit5())
77+
@mtkcompile sys = System(eqs, t; defaults=u0)
78+
prob = ODEProblem(sys,[],tspan)
79+
sol = solve(prob, Tsit5())
8080

81-
# T = eigen(A).vectors
82-
# T_inv = inv(T)
81+
T = eigen(A).vectors
82+
T_inv = inv(T)
8383

84-
# @variables z(t)[1:3]
85-
# z = reshape(z, 3, 1)
86-
# forward_subs = vec(T_inv*x .=> z)
87-
# backward_subs = vec(x .=> T*z)
84+
@variables z(t)[1:3]
85+
z = reshape(z, 3, 1)
86+
forward_subs = vec(T_inv*x .=> z)
87+
backward_subs = vec(x .=> T*z)
8888

89-
# new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true)
89+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true)
9090

91-
# new_prob = ODEProblem(new_sys, [], tspan)
92-
# new_sol = solve(new_prob, Tsit5())
91+
new_prob = ODEProblem(new_sys, [], tspan)
92+
new_sol = solve(new_prob, Tsit5())
9393

94-
# # test RHS
95-
# new_rhs = [eq.rhs for eq in equations(new_sys)]
96-
# new_A = Symbolics.value.(Symbolics.jacobian(new_rhs, z))
97-
# A = diagm(eigen(A).values)
98-
# A = sortslices(A, dims=1)
99-
# new_A = sortslices(new_A, dims=1)
100-
# @test isapprox(A, new_A, rtol = 1e-10)
101-
# @test isapprox( new_sol[x[1],end], sol[x[1],end], rtol=1e-4)
94+
# test RHS
95+
new_rhs = [eq.rhs for eq in equations(new_sys)]
96+
new_A = Symbolics.value.(Symbolics.jacobian(new_rhs, z))
97+
A = diagm(eigen(A).values)
98+
A = sortslices(A, dims=1)
99+
new_A = sortslices(new_A, dims=1)
100+
@test isapprox(A, new_A, rtol = 1e-10)
101+
@test isapprox( new_sol[x[1],end], sol[x[1],end], rtol=1e-4)
102102

103103
# Change of variables for sde
104104
noise_eqs = ModelingToolkit.get_noise_eqs
@@ -115,7 +115,7 @@ def = [x=>0., μ => 2., σ=>1.]
115115
@mtkcompile sys = System(eqs, t; defaults=def)
116116
forward_subs = [log(x) => y]
117117
backward_subs = [x => exp(y)]
118-
new_sys = change_of_variable_SDE(sys, t, forward_subs, backward_subs)
118+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
119119
@test equations(new_sys)[1] == (D(y) ~ μ - 1/2*σ^2)
120120
@test noise_eqs(new_sys)[1] === value(σ)
121121

@@ -130,7 +130,7 @@ def = [x=>0., y=> 0., u=>0., μ => 2., σ=>1., α=>3.]
130130
@mtkcompile sys = System(eqs, t; defaults=def)
131131
forward_subs = [log(x) => z, y^2 => w, log(u) => v]
132132
backward_subs = [x => exp(z), y => w^.5, u => exp(v)]
133-
new_sys = change_of_variable_SDE(sys, t, forward_subs, backward_subs)
133+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
134134
@test equations(new_sys)[1] == (D(z) ~ μ - 1/2*σ^2)
135135
@test equations(new_sys)[2] == (D(w) ~ α^2)
136136
@test equations(new_sys)[3] == (D(v) ~ μ - 1/2*^2 + σ^2))

0 commit comments

Comments
 (0)