Skip to content

Commit d20f52c

Browse files
committed
sub constants in function gen
1 parent 6cc9816 commit d20f52c

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
174174
check_operator_variables(eqs, Differential)
175175
check_lhs(eqs, Differential, Set(dvs))
176176
end
177+
178+
# substitute constants in
179+
eqs = map(subs_constants, eqs)
180+
177181
# substitute x(t) by just x
178182
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
179183
[eq.rhs for eq in eqs]

test/odesystem.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,3 +1264,47 @@ end
12641264
fn2, = ModelingToolkit.generate_function(sys2; expression = Val{false})
12651265
@test_nowarn fn2(ones(4), 2ones(6), 4.0)
12661266
end
1267+
1268+
# https://github.com/SciML/ModelingToolkit.jl/issues/2969
1269+
@testset "Constant substitution" begin
1270+
make_model = function (c_a, c_b; name = nothing)
1271+
@mtkmodel ModelA begin
1272+
@constants begin
1273+
a = c_a
1274+
end
1275+
@variables begin
1276+
x(t)
1277+
end
1278+
@equations begin
1279+
D(x) ~ -a * x
1280+
end
1281+
end
1282+
1283+
@mtkmodel ModelB begin
1284+
@constants begin
1285+
b = c_b
1286+
end
1287+
@variables begin
1288+
y(t)
1289+
end
1290+
@components begin
1291+
modela = ModelA()
1292+
end
1293+
@equations begin
1294+
D(y) ~ -b * y
1295+
end
1296+
end
1297+
return ModelB(; name = name)
1298+
end
1299+
c_a, c_b = 1.234, 5.578
1300+
@named sys = make_model(c_a, c_b)
1301+
sys = complete(sys)
1302+
1303+
u0 = [sys.y => -1.0, sys.modela.x => -1.0]
1304+
p = defaults(sys)
1305+
prob = ODEProblem(sys, u0, (0.0, 1.0), p)
1306+
1307+
# evaluate
1308+
u0_v, p_v, _ = ModelingToolkit.get_u0_p(sys, u0, p)
1309+
@test prob.f(u0_v, p_v, 0.0) == [c_b, c_a]
1310+
end

0 commit comments

Comments
 (0)