Skip to content

Commit 4a6e57a

Browse files
committed
Add canonicalexpr to generate safe expressions for evaluations
1 parent b2a5f9d commit 4a6e57a

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

src/direct.jl

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,21 +216,38 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
216216
end
217217

218218
function toexpr(O)
219-
!istree(O) && return O
220-
if isa(operation(O), Differential)
221-
return :(derivative($(toexpr(arguments(O)[1])),$(toexpr(operation(O).x))))
222-
elseif isa(operation(O), Sym)
223-
isempty(arguments(O)) && return operation(O).name
224-
return Expr(:call, toexpr(operation(O)), toexpr.(arguments(O))...)
219+
canonical, O = canonicalexpr(O)
220+
canonical && return O
221+
222+
op = operation(O)
223+
args = arguments(O)
224+
if op isa Differential
225+
return :(derivative($(toexpr(args[1])),$(toexpr(op.x))))
226+
elseif op isa Sym
227+
isempty(args) && return nameof(op)
228+
return Expr(:call, toexpr(op), toexpr.(args)...)
225229
end
226-
if operation(O) === (^)
227-
if length(arguments(O)) > 1 && arguments(O)[2] isa Number && arguments(O)[2] < 0
228-
return Expr(:call, ^, Expr(:call, inv, toexpr(arguments(O)[1])), -(arguments(O)[2]))
230+
return Expr(:call, op, toexpr.(args)...)
231+
end
232+
toexpr(s::Sym) = nameof(s)
233+
234+
"""
235+
canonicalexpr(O) -> (canonical::Bool, expr)
236+
237+
Canonicalize `O`. Return `canonical` if `expr` is valid code to generate.
238+
"""
239+
function canonicalexpr(O)
240+
!istree(O) && return true, O
241+
op = operation(O)
242+
args = arguments(O)
243+
if op === (^)
244+
if length(args) == 2 && args[2] isa Number && args[2] < 0
245+
expr = Expr(:call, ^, Expr(:call, inv, toexpr(args[1])), -args[2])
246+
return true, expr
229247
end
230248
end
231-
return Expr(:call, operation(O), toexpr.(arguments(O))...)
249+
return false, O
232250
end
233-
toexpr(s::Sym) = nameof(s)
234251

235252
function toexpr(eq::Equation)
236253
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))

src/utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,15 @@ function states_to_sym(states::Set)
163163
if O isa Equation
164164
Expr(:(=), _states_to_sym(O.lhs), _states_to_sym(O.rhs))
165165
elseif istree(O)
166-
if isa(operation(O), Sym)
166+
op = operation(O)
167+
args = arguments(O)
168+
if op isa Sym
167169
O in states && return tosymbol(O)
168170
# dependent variables
169-
return build_expr(:call, Any[operation(O).name; _states_to_sym.(arguments(O))])
171+
return build_expr(:call, Any[nameof(op); _states_to_sym.(args)])
170172
else
171-
return build_expr(:call, Any[operation(O); _states_to_sym.(arguments(O))])
173+
canonical, O = canonicalexpr(O)
174+
return canonical ? O : build_expr(:call, Any[op; _states_to_sym.(args)])
172175
end
173176
elseif O isa Num
174177
return _states_to_sym(value(O))

test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using SafeTestsets, Test
22

3-
#=
43
@safetestset "Parsing Test" begin include("variable_parsing.jl") end
54
@safetestset "Differentiation Test" begin include("derivatives.jl") end
65
@safetestset "Simplify Test" begin include("simplify.jl") end
@@ -33,5 +32,4 @@ using SafeTestsets, Test
3332
@safetestset "Variable Utils Test" begin include("variable_utils.jl") end
3433
println("Last test requires gcc available in the path!")
3534
@safetestset "C Compilation Test" begin include("ccompile.jl") end
36-
=#
3735
@safetestset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)