Skip to content

Commit e0bb5d2

Browse files
author
fchen121
committed
Implement and fix change of variable for ODE
1 parent 24ebd99 commit e0bb5d2

File tree

3 files changed

+82
-55
lines changed

3 files changed

+82
-55
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
4444
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4545
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4646
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
47+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4748
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
49+
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
4850
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4951
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
5052
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -59,9 +61,11 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
5961
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
6062
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
6163
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
64+
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
6265
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
6366
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
6467
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
68+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6569
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
6670
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
6771
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
@@ -159,6 +163,7 @@ StochasticDiffEq = "6.72.1"
159163
SymbolicIndexingInterface = "0.3.39"
160164
SymbolicUtils = "3.26.1"
161165
Symbolics = "6.40"
166+
Test = "1.11.0"
162167
URIs = "1"
163168
UnPack = "0.1, 1.0"
164169
Unitful = "1.1"

src/systems/diffeqs/basic_transformations.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ new_sol = solve(new_prob, Tsit5())
9494
```
9595
9696
"""
97-
function changeofvariables(sys::System, forward_subs, backward_subs; simplify=false, t0=missing)
98-
t = independent_variable(sys)
97+
function changeofvariables(sys::System, iv, forward_subs, backward_subs; simplify=true, t0=missing)
98+
t = iv
9999

100100
old_vars = first.(backward_subs)
101101
new_vars = last.(forward_subs)
102-
kept_vars = setdiff(states(sys), old_vars)
103-
rhs = [eq.rhs for eq in equations(sys)]
102+
# kept_vars = setdiff(states(sys), old_vars)
103+
# rhs = [eq.rhs for eq in equations(sys)]
104104

105105
# use: dz/dt = ∂z/∂x dx/dt + ∂z/∂t
106106
dzdt = Symbolics.derivative( first.(forward_subs), t )
@@ -120,36 +120,44 @@ function changeofvariables(sys::System, forward_subs, backward_subs; simplify=fa
120120
defs = get_defaults(sys)
121121
new_defs = Dict()
122122
for f_sub in forward_subs
123-
#TODO call value(...)?
124123
ex = substitute(first(f_sub), defs)
125124
if !ismissing(t0)
126125
ex = substitute(ex, t => t0)
127126
end
128127
new_defs[last(f_sub)] = ex
129128
end
130-
return ODESystem(new_eqs;
129+
for para in parameters(sys)
130+
if haskey(defs, para)
131+
new_defs[para] = defs[para]
132+
end
133+
end
134+
@named new_sys = System(new_eqs, t;
131135
defaults=new_defs,
132136
observed=vcat(observed(sys),first.(backward_subs) .~ last.(backward_subs))
133137
)
138+
if simplify
139+
return mtkcompile(new_sys)
140+
end
141+
return new_sys
134142
end
135143

136-
function change_of_variable_SDE(sys::System, forward_subs, backward_subs, iv; simplify=false, t0=missing)
137-
t = independent_variable(sys)
144+
function change_of_variable_SDE(sys::System, iv, nvs, forward_subs, backward_subs; simplify=false, t0=missing)
145+
t = iv
138146

139147
old_vars = first.(backward_subs)
140148
new_vars = last.(forward_subs)
141149

142150
# use: f = Y(t, X)
143151
# use: dY = (∂f/∂t + μ∂f/∂x + (1/2)*σ^2*∂2f/∂x2)dt + σ∂f/∂xdW
144152
old_eqs = equations(sys)
145-
old_noise = get_noiseeqs(sys)
153+
old_noise = ModelingToolkit.get_noise_eqs(sys)
146154
∂f∂t = Symbolics.derivative( first.(forward_subs), t )
147-
∂f∂x = Symbolics.derivative( first.(forward_subs), old_vars )
155+
∂f∂x = [Symbolics.derivative( first(f_sub), old_var )]
148156
∂2f∂x2 = Symbolics.derivative( ∂f∂x, old_vars )
149157
new_eqs = Equation[]
150158

151-
for (new_var, eq, noise, first, second, third) in zip(new_vars, old_eqs, old_noise, ∂f∂t, ∂f∂x, ∂2f∂x2)
152-
ex = first + eq.rhs * second + 1/2 * noise^2 * third
159+
for (new_var, eq, noise, nv, first, second, third) in zip(new_vars, old_eqs, old_noise, nvs, ∂f∂t, ∂f∂x, ∂2f∂x2)
160+
ex = first + eq.rhs * second + 1/2 * noise^2 * third + noise*second*nv
153161
for eqs in old_eqs
154162
ex = substitute(ex, eqs.lhs => eqs.rhs)
155163
end
@@ -160,11 +168,11 @@ function change_of_variable_SDE(sys::System, forward_subs, backward_subs, iv; si
160168
end
161169
push!(new_eqs, Differential(t)(new_var) ~ ex)
162170
end
163-
new_noise = [noise * div for (noise, div) in zip(old_noise, ∂f∂x)]
164171

165-
return SDESystem(new_eqs, new_noise;
172+
@named new_sys = System(new_eqs;
166173
observed=vcat(observed(sys),first.(backward_subs) .~ last.(backward_subs))
167174
)
175+
return new_sys
168176
end
169177

170178
"""

test/changeofvariables.jl

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,110 @@
1-
using ModelingToolkit, OrdinaryDiffEq
1+
using ModelingToolkit, OrdinaryDiffEq, StochasticDiffEq
22
using Test, LinearAlgebra
33

44

55
# Change of variables: z = log(x)
66
# (this implies that x = exp(z) is automatically non-negative)
7-
8-
@parameters t α
7+
@independent_variables t
8+
# @variables z(t)[1:2, 1:2]
9+
# D = Differential(t)
10+
# eqs = [D(D(z)) ~ ones(2, 2)]
11+
# @mtkcompile sys = System(eqs, t)
12+
# @test_nowarn ODEProblem(sys, [z => zeros(2, 2), D(z) => ones(2, 2)], (0.0, 10.0))
13+
14+
@parameters α
915
@variables x(t)
1016
D = Differential(t)
1117
eqs = [D(x) ~ α*x]
1218

1319
tspan = (0., 1.)
14-
u0 = [x => 1.0]
15-
p ==> -0.5]
20+
def = [x => 1.0, α => -0.5]
1621

17-
sys = ODESystem(eqs; defaults=u0)
18-
prob = ODEProblem(sys, [], tspan, p)
22+
@mtkcompile sys = System(eqs, t;defaults=def)
23+
prob = ODEProblem(sys, [], tspan)
1924
sol = solve(prob, Tsit5())
2025

2126
@variables z(t)
2227
forward_subs = [log(x) => z]
2328
backward_subs = [x => exp(z)]
24-
new_sys = changeofvariables(sys, forward_subs, backward_subs)
29+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs)
2530
@test equations(new_sys)[1] == (D(z) ~ α)
2631

27-
new_prob = ODEProblem(new_sys, [], tspan, p)
32+
new_prob = ODEProblem(new_sys, [], tspan)
2833
new_sol = solve(new_prob, Tsit5())
2934

3035
@test isapprox(new_sol[x][end], sol[x][end], atol=1e-4)
3136

3237

3338

3439
# Riccati equation
35-
@parameters t α
40+
@parameters α
3641
@variables x(t)
3742
D = Differential(t)
3843
eqs = [D(x) ~ t^2 + α - x^2]
39-
sys = ODESystem(eqs, defaults=[x=>1.])
44+
def = [x=>1., α => 1.]
45+
@mtkcompile sys = System(eqs, t; defaults=def)
4046

4147
@variables z(t)
4248
forward_subs = [t + α/(x+t) => z ]
4349
backward_subs = [ x => α/(z-t) - t]
4450

45-
new_sys = changeofvariables(sys, forward_subs, backward_subs; simplify=true, t0=0.)
51+
new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true, t0=0.)
4652
# output should be equivalent to
4753
# t^2 + α - z^2 + 2 (but this simplification is not found automatically)
4854

4955
tspan = (0., 1.)
50-
p ==> 1.]
51-
prob = ODEProblem(sys,[],tspan,p)
52-
new_prob = ODEProblem(new_sys,[],tspan,p)
56+
prob = ODEProblem(sys,[],tspan)
57+
new_prob = ODEProblem(new_sys,[],tspan)
5358

5459
sol = solve(prob, Tsit5())
5560
new_sol = solve(new_prob, Tsit5())
5661

5762
@test isapprox(sol[x][end], new_sol[x][end], rtol=1e-4)
5863

5964

60-
# Linear transformation to diagonal system
61-
@parameters t
62-
@variables x[1:3](t)
63-
D = Differential(t)
64-
A = [0. -1. 0.; -0.5 0.5 0.; 0. 0. -1.]
65-
eqs = D.(x) .~ A*x
65+
# # Linear transformation to diagonal system
66+
# @variables x(t)[1:3]
67+
# D = Differential(t)
68+
# A = [0. -1. 0.; -0.5 0.5 0.; 0. 0. -1.]
69+
# right = A.*transpose(x)
70+
# eqs = [D(x[1]) ~ sum(right[1, 1:3]), D(x[2]) ~ sum(right[2, 1:3]), D(x[3]) ~ sum(right[3, 1:3])]
6671

67-
tspan = (0., 10.)
68-
u0 = x .=> [1.0, 2.0, -1.0]
72+
# tspan = (0., 10.)
73+
# u0 = [x[1] => 1.0, x[2] => 2.0, x[3] => -1.0]
6974

70-
sys = ODESystem(eqs; defaults=u0)
71-
prob = ODEProblem(sys,[],tspan)
72-
sol = solve(prob, Tsit5())
75+
# @mtkcompile sys = System(eqs, t; defaults=u0)
76+
# prob = ODEProblem(sys,[],tspan)
77+
# sol = solve(prob, Tsit5())
7378

74-
T = eigen(A).vectors
79+
# T = eigen(A).vectors
7580

76-
@variables z[1:3](t)
77-
forward_subs = T \ x .=> z
78-
backward_subs = x .=> T*z
81+
# @variables z(t)[1:3]
82+
# forward_subs = T \ x .=> z
83+
# backward_subs = x .=> T*z
7984

80-
new_sys = changeofvariables(sys, forward_subs, backward_subs; simplify=true)
85+
# new_sys = changeofvariables(sys, t, forward_subs, backward_subs; simplify=true)
8186

82-
new_prob = ODEProblem(new_sys, [], tspan, p)
83-
new_sol = solve(new_prob, Tsit5())
87+
# new_prob = ODEProblem(new_sys, [], tspan)
88+
# new_sol = solve(new_prob, Tsit5())
8489

85-
# test RHS
86-
new_rhs = [eq.rhs for eq in equations(new_sys)]
87-
new_A = Symbolics.value.(Symbolics.jacobian(new_rhs, z))
88-
@test isapprox(diagm(eigen(A).values), new_A, rtol = 1e-10)
89-
@test isapprox( new_sol[x[1],end], sol[x[1],end], rtol=1e-4)
90+
# # test RHS
91+
# new_rhs = [eq.rhs for eq in equations(new_sys)]
92+
# new_A = Symbolics.value.(Symbolics.jacobian(new_rhs, z))
93+
# @test isapprox(diagm(eigen(A).values), new_A, rtol = 1e-10)
94+
# @test isapprox( new_sol[x[1],end], sol[x[1],end], rtol=1e-4)
9095

9196
# Change of variables for sde
92-
@Browian B
93-
@parameters μ σ
94-
@variables x(t) y(t)
97+
# @independent_variables t
98+
# @brownian B
99+
# @parameters μ σ
100+
# @variables x(t) y(t)
101+
# D = Differential(t)
102+
# eqs = [D(x) ~ μ*x + σ*x*B]
103+
104+
# def = [x=>0., μ => 2., σ=>1.]
105+
# @mtkcompile sys = System(eqs, t; defaults=def)
106+
# forward_subs = [log(x) => y]
107+
# backward_subs = [x => exp(y)]
108+
# new_sys = change_of_variable_SDE(sys, t, [B], forward_subs, backward_subs)
95109

96110

0 commit comments

Comments
 (0)