Skip to content

Commit 157bd94

Browse files
committed
Minor rules/values update, and working example
1 parent cf75525 commit 157bd94

File tree

4 files changed

+51
-6
lines changed

4 files changed

+51
-6
lines changed

src/interval/interval.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function var_names(::IntervalTransform, s::Term{Real, Base.ImmutableDict{DataTyp
1717
sL = genvar(Symbol(string(get_name(s))*"_lo"), arg_list)
1818
sU = genvar(Symbol(string(get_name(s))*"_hi"), arg_list)
1919
end
20-
return sL, sU
20+
return Symbolics.value(sL), Symbolics.value(sU)
2121
end
2222
function var_names(::IntervalTransform, s::Real)
2323
return s, s
@@ -43,12 +43,12 @@ function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like
4343

4444
sL = s.f(var_lo)
4545
sU = s.f(var_hi)
46-
return sL, sU
46+
return Symbolics.value(sL), Symbolics.value(sU)
4747
end
4848
function var_names(::IntervalTransform, s::Sym) #The parameters
4949
sL = genparam(Symbol(string(get_name(s))*"_lo"))
5050
sU = genparam(Symbol(string(get_name(s))*"_hi"))
51-
return sL, sU
51+
return Symbolics.value(sL), Symbolics.value(sU)
5252
end
5353

5454
function translate_initial_conditions(::IntervalTransform, prob::ODESystem, new_eqs::Vector{Equation})

src/relaxation/relaxation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function var_names(::McCormickTransform, s::Term{Real, Base.ImmutableDict{DataTy
1313
scv = genvar(Symbol(string(get_name(s))*"_cv"), arg_list)
1414
scc = genvar(Symbol(string(get_name(s))*"_cc"), arg_list)
1515
end
16-
return scv, scc
16+
return Symbolics.value(scv), Symbolics.value(scc)
1717
end
1818
function var_names(::McCormickTransform, s::Real)
1919
return s, s
@@ -39,12 +39,12 @@ function var_names(::McCormickTransform, s::Term{Real, Nothing}) #Any terms like
3939

4040
scv = s.f(var_cv)
4141
scc = s.f(var_cc)
42-
return scv, scc
42+
return Symbolics.value(scv), Symbolics.value(scc)
4343
end
4444
function var_names(::McCormickTransform, s::Sym) #The parameters
4545
scv = genparam(Symbol(string(get_name(s))*"_cv"))
4646
scc = genparam(Symbol(string(get_name(s))*"_cc"))
47-
return scv, scc
47+
return Symbolics.value(scv), Symbolics.value(scc)
4848
end
4949

5050
function translate_initial_conditions(::McCormickTransform, prob::ODESystem, new_eqs::Vector{Equation})

src/relaxation/rules.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ function transform_rule(::McCormickTransform, ::typeof(*), zL, zU, zcv, zcc, xL,
142142

143143
end
144144

145+
function transform_rule(::McCormickTransform, ::typeof(min), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
146+
rcv = Equation(zcv, min(xcv, ycv))
147+
rcc = Equation(zcc, min(xcc, ycc))
148+
return rcv, rcc
149+
end
150+
151+
function transform_rule(::McCormickTransform, ::typeof(max), zL, zU, zcv, zcc, xL, xU, xcv, xcc, yL, yU, ycv, ycc)
152+
rcv = Equation(zcv, min(xcv, ycv))
153+
rcc = Equation(zcc, min(xcc, ycc))
154+
return rcv, rcc
155+
end
145156

146157
#=
147158
TODO: Add other operators. It's probably helpful to break the McCormick overload and McCormick + Interval Outputs

works.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
# This case WORKED at one point, so it should work whenever you're testing something.
3+
using ModelingToolkit, OrdinaryDiffEq
4+
using DifferentialEquations: solve
5+
using DiffEqGPU
6+
using BenchmarkTools
7+
@parameters p[1:2] t
8+
@variables x[1:2](t)
9+
D = Differential(t)
10+
x0 = [1.0; 0.0]
11+
tspan = (0.0, 35.0)
12+
p_start = [0.020; 0.025]
13+
eqns = [D(x[1]) ~ -p[1]*x[1] + p[2]*x[2],
14+
D(x[2]) ~ p[1]*x[1] - p[2]*x[2]]
15+
16+
@named test_works = ODESystem(eqns, t, x, p, default_u0=Dict( x[i] .=> x0[i] for i in 1:2), default_p = Dict( p[i] .=> p_start[i] for i in 1:2))
17+
tested_MC_works = apply_transform(McCormickIntervalTransform(), test_works) #This one breaks
18+
set_bounds(tested_MC_works, p[1], (0.015, 0.025))
19+
ode_problem_works = eval(ODEProblemExpr(structural_simplify(tested_MC_works), tested_MC_works.defaults, tspan))
20+
21+
# Can create u's and p's to interpolate into the prob_func definition like this:
22+
p_order = tested_MC_works.ps.value
23+
u_order = structural_simplify(tested_MC_works).states.value
24+
u_list = []
25+
p_list = []
26+
for i = 1:11
27+
push!(u_list, [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0+(i-1)*0.1])
28+
push!(p_list, [0.01, 0.01, 0.01, 0.01, 0.01+(i-1)*0.001, 0.01-(i-1)*0.001, 0.01, 0.01])
29+
end
30+
prob_func_works = (prob, i, repeat) -> remake(prob, u0=:($u_list)[i], p = :($p_list)[i])
31+
32+
# And this ensemble problem works with EnsembleGPUArray, currently.
33+
ensemble_works = EnsembleProblem(ode_problem_works, prob_func=prob_func_works)
34+
@benchmark solve(ensemble_works, Tsit5(), EnsembleGPUArray(), trajectories=11)

0 commit comments

Comments
 (0)