Skip to content

Commit 9130dbc

Browse files
committed
Continued updates to apply_transform
1 parent 38e1fba commit 9130dbc

File tree

4 files changed

+147
-40
lines changed

4 files changed

+147
-40
lines changed

src/interval/interval.jl

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,76 @@ Rules for constructing interval bounding expressions
55
# Structure used to indicate an overload with intervals is preferable
66
struct IntervalTransform <: AbstractTransform end
77

8-
# Creates names for interval state variables
9-
function var_names(::IntervalTransform, s::String)
10-
sL = Symbol(s*"_lo")
11-
sU = Symbol(s*"_hi")
12-
sL, sU
8+
# Creates names for interval state variables [DEPRECATED]
9+
# function var_names(::IntervalTransform, s::Num)
10+
# if s.val.metadata.value[1]==:variables
11+
# arg_list = Symbol[]
12+
# for i in s.val.arguments
13+
# push!(arg_list, get_name(i))
14+
# end
15+
# sL = genvar(Symbol(string(s.val.f)*"_lo"), arg_list)
16+
# sU = genvar(Symbol(string(s.val.f)*"_hi"), arg_list)
17+
# elseif s.val.metadata.parent.value[1]==:parameters
18+
# sL = genparam(Symbol(string(s.val.name)*"_lo"))
19+
# sU = genparam(Symbol(string(s.val.name)*"_hi"))
20+
# else
21+
# error("Not a variable or a parameter, check type of s")
22+
# end
23+
# return sL, sU
24+
# end
25+
26+
function var_names(::IntervalTransform, s::Term{Real, Base.ImmutableDict{DataType, Any}}) #The variables
27+
arg_list = Symbol[]
28+
for i in s.arguments
29+
push!(arg_list, get_name(i))
30+
end
31+
sL = genvar(Symbol(string(s.f)*"_lo"), arg_list)
32+
sU = genvar(Symbol(string(s.f)*"_hi"), arg_list)
33+
return sL, sU
34+
end
35+
function var_names(::IntervalTransform, s::Term{Real, Nothing}) #Any terms like "Differential"
36+
if length(s.arguments)>1
37+
error("Multiple arguments not supported.")
38+
end
39+
if typeof(s.arguments[1])<:Term #then it has args
40+
args = Symbol[]
41+
for i in s.arguments[1].arguments
42+
push!(args, get_name(i))
43+
end
44+
var = get_name(s.arguments[1])
45+
var_lo = genvar(Symbol(string(var)*"_lo"), args)
46+
var_hi = genvar(Symbol(string(var)*"_hi"), args)
47+
elseif typeof(s.arguments[1])<:Sym #Then it has no args
48+
var_lo = genparam(Symbol(string(s.arguments[1].name)*"_lo"))
49+
var_hi = genparam(Symbol(string(s.arguments[1].name)*"_hi"))
50+
else
51+
error("Type of argument invalid")
52+
end
53+
54+
sL = s.f(var_lo)
55+
sU = s.f(var_hi)
56+
return sL, sU
57+
end
58+
function var_names(::IntervalTransform, s::Sym) #The parameters
59+
sL = genparam(Symbol(string(s.name)*"_lo"))
60+
sU = genparam(Symbol(string(s.name)*"_hi"))
61+
return sL, sU
1362
end
1463

15-
function var_names(::IntervalTransform, s::Number)
16-
sL = s
17-
sU = s
18-
sL, sU
64+
65+
66+
# function var_names(::IntervalTransform, s::Number)
67+
# sL = s
68+
# sU = s
69+
# sL, sU
70+
# end
71+
72+
# Helper functions for navigating SymbolicUtils structures
73+
function get_name(s::Term)
74+
return s.f.name
75+
end
76+
function get_name(s::Sym)
77+
return s.name
1978
end
2079

2180
include(joinpath(@__DIR__, "rules.jl"))

src/transform/factor.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function isfactor(ex::SymbolicUtils.Add)
2626
(iszero(ex.coeff)) && (length(ex.dict)>2) && return false
2727
for (key, val) in ex.dict
2828
~(isone(val)) && return false
29-
~(typeof(key)<:Term) && return false
29+
~(typeof(key)<:Term) && ~(typeof(key)<:Sym) && return false
3030
end
3131
return true
3232
end
@@ -35,18 +35,18 @@ function isfactor(ex::SymbolicUtils.Mul)
3535
(isone(ex.coeff)) && (length(ex.dict)>2) && return false
3636
for (key, val) in ex.dict
3737
~(isone(val)) && return false
38-
~(typeof(key)<:Term) && return false
38+
~(typeof(key)<:Term) && ~(typeof(key)<:Sym) && return false
3939
end
4040
return true
4141
end
4242
function isfactor(ex::SymbolicUtils.Div)
43-
~(typeof(ex.num)<:Term) && ~(typeof(ex.num)<:Real) && return false
44-
~(typeof(ex.den)<:Term) && ~(typeof(ex.num)<:Real) && return false
43+
~(typeof(ex.num)<:Term) && ~(typeof(key)<:Sym) && ~(typeof(ex.num)<:Real) && return false
44+
~(typeof(ex.den)<:Term) && ~(typeof(key)<:Sym) && ~(typeof(ex.num)<:Real) && return false
4545
return true
4646
end
4747
function isfactor(ex::SymbolicUtils.Pow)
48-
~(typeof(ex.base)<:Term) && ~(typeof(ex.base)<:Real) && return false
49-
~(typeof(ex.exp)<:Term) && ~(typeof(ex.exp)<:Real) && return false
48+
~(typeof(ex.base)<:Term) && ~(typeof(key)<:Sym) && ~(typeof(ex.base)<:Real) && return false
49+
~(typeof(ex.exp)<:Term) && ~(typeof(key)<:Sym) && ~(typeof(ex.exp)<:Real) && return false
5050
return true
5151
end
5252

@@ -186,9 +186,9 @@ function factor!(ex::SymbolicUtils.Mul; assignments = Assignment[])
186186

187187
new_terms = Dict{Any, Number}()
188188
for (key, val) in ex.dict
189-
if (typeof(key)<:Term) && isone(val)
189+
if ((typeof(key)<:Term) && isone(val)) || ((typeof(key)<:Sym) && isone(val))
190190
new_terms[key] = val
191-
elseif (typeof(key)<:Term)
191+
elseif (typeof(key)<:Term) || (typeof(key)<:Sym)
192192
index = findall(x -> isequal(x.rhs,key^val), assignments)
193193
if isempty(index)
194194
newsym = gensym(:aux)

src/transform/transform.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,50 +145,61 @@ end
145145
```
146146
147147
=#
148-
function apply_transform(t::T, prob::ODESystem) where T<:AbstractTransform
148+
function apply_transform(transform::T, prob::ODESystem) where T<:AbstractTransform
149149

150+
# Factor out the equations
150151
assignments = Assignment[]
151152
for eqn in prob.eqs
152153
# Flesh out the original RHS
153154
current = length(assignments)
154-
factor!(toexpr(eqn.rhs), assignments=assignments)
155+
factor!(eqn.rhs, assignments=assignments)
155156

156157
# If new equations were added, stick on the original LHS to the last point
157158
# in assignments by stealing its RHS from the last item and taking its place
158159
if length(assignments) > current
159-
push!(assignments, Assignment(toexpr(eqn.lhs), assignments[end].rhs))
160+
push!(assignments, Assignment(eqn.lhs, assignments[end].rhs))
160161
deleteat!(assignments, length(assignments)-1)
161162
end
162163
end
163164

164-
# new_assignments = AssignmentPair[]
165+
# Develop equations for the transforms
165166
new_assignments = Assignment[]
166167
for a in assignments
167-
zn = var_names(t, zstr(a)) #LHS
168-
xn = var_names(t, xstr(a)) #RHS(1)
168+
# Get zn, first. Which is the LHS.
169+
zn = var_names(transform, zstr(a))
170+
xn = var_names(transform, xstr(a))
169171

170-
first = zn[1]
171-
second = zn[2]
172-
173-
# Define the zn's as variables
174-
@variables $first $second
175172
if isone(arity(a))
176-
targs = (t, op(a), zn..., xn...)
173+
targs = (transform, op(a), zn..., xn...)
177174
else
178-
targs = (t, op(a), zn..., xn..., var_names(t, ystr(a))...)
175+
targs = (transform, op(a), zn..., xn..., var_names(transform, ystr(a))...)
179176
end
180-
println("targs: $targs")
181177

182178
# push!(new_assignments, transform_rule(targs...))
183179
new = transform_rule(targs...)
184180
push!(new_assignments, new.l)
185181
push!(new_assignments, new.u)
186182
end
187183

184+
# Combine all transforms into a new set of equations and create a new ODE system
185+
new_eqs = Equation[]
186+
for i in new_assignments
187+
push!(new_eqs, Equation(i.lhs, i.rhs))
188+
end
189+
188190
@named new_sys = ODESystem(new_eqs)
191+
println("Completed.")
192+
193+
println(new_sys)
194+
println(typeof(new_sys))
195+
println(new_sys.eqs)
196+
189197
# Form ODE system from new assignments
190198
# CSE - MTK.structural_simplify()
191199

192200
# Figure out a way to give the new ODE system the proper parameters, variables, etc.
193-
return the_new_ODE_System
201+
202+
# println(new_assignments)
203+
204+
return new_sys
194205
end

src/transform/utilities.jl

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,53 @@ arity(a::SymbolicUtils.Mul) = length(a.dict) + (~isone(a.coeff))
88

99
op(a::Expr) = a.args[1]
1010
op(a::Assignment) = op(a.rhs)
11-
op(a::Number) = "const"
12-
op(a::Symbol) = "const"
11+
op(::Number) = "const"
12+
op(::Symbol) = "const"
13+
op(::SymbolicUtils.Add) = +
14+
op(::SymbolicUtils.Mul) = *
15+
op(::SymbolicUtils.Pow) = ^
16+
op(::SymbolicUtils.Div) = /
17+
18+
xstr(a::Assignment) = sub_1(a.rhs)
19+
ystr(a::Assignment) = sub_2(a.rhs)
20+
zstr(a::Assignment) = a.lhs
21+
22+
function sub_1(a::SymbolicUtils.Add)
23+
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
24+
return sorted_dict[1].first
25+
end
26+
function sub_2(a::SymbolicUtils.Add)
27+
~(iszero(a.coeff)) && return a.coeff
28+
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
29+
return sorted_dict[2].first
30+
end
31+
32+
function sub_1(a::SymbolicUtils.Mul)
33+
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
34+
return sorted_dict[1].first
35+
end
36+
function sub_2(a::SymbolicUtils.Mul)
37+
~(isone(a.coeff)) && return a.coeff
38+
sorted_dict = sort(collect(a.dict), by=x->string(x[1]))
39+
return sorted_dict[2].first
40+
end
41+
42+
# xstr(a::Assignment) = string_or_num(a.rhs.args[2])
43+
# ystr(a::Assignment) = string_or_num(a.rhs.args[3])
44+
# zstr(a::Assignment) = string_or_num(a.lhs)
45+
46+
# string_or_num(a::Num) = a
47+
# string_or_num(a::Expr) = string(a)
48+
# string_or_num(a::Symbol) = string(a)
49+
# string_or_num(a::Number) = a
50+
51+
# # A function to identify whether a term... is.... variable or parameter, hm
52+
# function identify(x::Num)
53+
54+
# end
55+
1356

14-
xstr(a::Assignment) = string_or_num(a.rhs.args[2])
15-
ystr(a::Assignment) = string_or_num(a.rhs.args[3])
16-
zstr(a::Assignment) = string_or_num(a.lhs)
1757

18-
string_or_num(a::Expr) = string(a)
19-
string_or_num(a::Symbol) = string(a)
20-
string_or_num(a::Number) = a
2158

2259
# Uses Symbolics functions to generate a variable as a function of the dependent variables of choice (default: t)
2360
function genvar(a::Symbol, b=:t)

0 commit comments

Comments
 (0)