Skip to content

Commit d32d6c9

Browse files
committed
Cleaning updates
1 parent 157bd94 commit d32d6c9

File tree

4 files changed

+75
-31
lines changed

4 files changed

+75
-31
lines changed

src/SourceCodeMcCormick.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
module SourceCodeMcCormick
3+
14
using ModelingToolkit
25
using SymbolicUtils.Code
36
# TODO: Need to import Assignment and other stuff probably
@@ -8,18 +11,13 @@ abstract type AbstractTransform end
811
# ADD documentation for generic function here
912
function transform_rule end
1013

11-
# struct AssignmentPair
12-
# l::Assignment
13-
# u::Assignment
14-
# end
1514

16-
# struct AssignmentQuad
17-
# l::Assignment
18-
# u::Assignment
19-
# cv::Assignment
20-
# cc::Assignment
21-
# end
15+
export McCormickIntervalTransform
16+
17+
export apply_transform, extract_terms, genvar, genparam, get_name
2218

2319
include(joinpath(@__DIR__, "interval", "interval.jl"))
2420
include(joinpath(@__DIR__, "relaxation", "relaxation.jl"))
25-
include(joinpath(@__DIR__, "transform", "transform.jl"))
21+
include(joinpath(@__DIR__, "transform", "transform.jl"))
22+
23+
end

src/relaxation/relaxation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ end
105105
line_expr(x, xL, xU, zL, zU) = IfElse.ifelse(zU > zL, (zL*(xU - x) + zU*(x - xL))/(xU - xL), zU)
106106

107107
# A symbolic way of computing the mid of three numbers (returns IfElse block)
108-
mid_expr(x, y, z) = IfElse.ifelse((x < y) && (y < z), y, IfElse.ifelse((z < y) && (y < x), y,
109-
IfElse.ifelse((y < x) && (x < z), x, IfElse.ifelse((z < x) && (x < y), x, z))))
108+
mid_expr(a, b, c) = IfElse.ifelse((a < b) && (b < c), y, IfElse.ifelse((c < b) && (b < a), b,
109+
IfElse.ifelse((b < a) && (a < c), x, IfElse.ifelse((c < a) && (a < b), a, c))))
110110

111111
include(joinpath(@__DIR__, "rules.jl"))

src/transform/transform.jl

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ end
148148
function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransform
149149

150150
# Factorize all model equations to generate a new set of equations
151+
152+
genparam(get_name(prob.iv.val))
153+
151154
equations = Equation[]
152155
for eqn in prob.eqs
153156
current = length(equations)
@@ -178,33 +181,37 @@ function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransfo
178181
end
179182
end
180183

181-
println("")
182-
println("Old equations:")
183-
for i in prob.eqs
184-
println(i)
185-
end
186-
println("")
187-
println("New equations:")
188-
for i in new_equations
189-
println(i)
190-
end
191-
println("")
184+
# println("")
185+
# println("Old equations:")
186+
# for i in prob.eqs
187+
# println(i)
188+
# end
189+
# println("")
190+
# println("New equations:")
191+
# for i in new_equations
192+
# println(i)
193+
# end
194+
# println("")
192195

193196
# Copy model start points to the newly transformed variables
194197
var_defaults, param_defaults = translate_initial_conditions(transform, prob, new_equations)
198+
# for i in param_defaults
199+
# print(i)
200+
# print("\n")
201+
# end
202+
# for i in var_defaults
203+
# print(i)
204+
# print("\n")
205+
# end
206+
195207

196208
# Use the transformed equations and new start points to generate a new ODE system
197-
@named new_sys = ODESystem(new_equations, default_u0=var_defaults, default_p=param_defaults)
198-
209+
# @named new_sys = ODESystem(new_equations, default_u0=var_defaults, default_p=param_defaults)
199210

200-
# Form ODE system from new equations
201-
# CSE - MTK.structural_simplify()
211+
@named new_sys = ODESystem(new_equations, defaults=merge(var_defaults, param_defaults))
202212

203213
#Extract RHS and evaluate at a point, make that a script that tests any functions.
204214
# RHS this function, reference the correct thing, go
205215

206-
# Figure out a way to give the new ODE system the proper parameters, variables, etc.
207-
208-
209216
return new_sys
210217
end

src/transform/utilities.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ function extract_terms(eqs::Vector{Equation})
9090
ModelingToolkit.collect_vars!(allstates, ps, eq.lhs, iv)
9191
ModelingToolkit.collect_vars!(allstates, ps, eq.rhs, iv)
9292
end
93+
9394
return allstates, ps
9495
end
9596

@@ -137,3 +138,41 @@ function set_bounds(sys::ODESystem, terms::Vector{Num}, bounds::Vector{Tuple{Flo
137138
end
138139
return sys
139140
end
141+
142+
function get_cvcc_start_dict(sys::ODESystem, term::Num, start_point::Float64)
143+
base_name = get_name(Symbolics.value(term))
144+
name_cv = String(base_name)*"_"*"cv"
145+
name_cc = String(base_name)*"_"*"cc"
146+
147+
model_terms = Vector{Union{Term,Sym}}()
148+
for i in sys.states
149+
push!(model_terms, Symbolics.value(i))
150+
end
151+
for i in sys.ps
152+
push!(model_terms, Symbolics.value(i))
153+
end
154+
real_cv = nothing
155+
real_cc = nothing
156+
for i in model_terms
157+
if String(get_name(i))==name_cv
158+
real_cv = i
159+
elseif String(get_name(i))==name_cc
160+
real_cc = i
161+
end
162+
end
163+
164+
new_dict = copy(sys.defaults)
165+
if real_cv in keys(new_dict)
166+
delete!(new_dict, real_cv)
167+
new_dict[real_cv] = start_point
168+
else
169+
new_dict[real_cv] = start_point
170+
end
171+
if real_cc in keys(new_dict)
172+
delete!(new_dict, real_cc)
173+
new_dict[real_cc] = start_point
174+
else
175+
new_dict[real_cc] = start_point
176+
end
177+
return new_dict
178+
end

0 commit comments

Comments
 (0)