Skip to content

Commit 4473574

Browse files
committed
Add canonicalize kw and fix latexify tests
1 parent 0f9baea commit 4473574

File tree

3 files changed

+33
-18
lines changed

3 files changed

+33
-18
lines changed

src/direct.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -215,21 +215,31 @@ function sparsehessian(O, vars::AbstractVector; simplify = true)
215215
return H
216216
end
217217

218-
function toexpr(O)
219-
canonical, O = canonicalexpr(O)
220-
canonical && return O
218+
"""
219+
toexpr(O::Union{Symbolics,Num,Equation,AbstractArray}; canonicalize=true) -> Expr
220+
221+
Convert `Symbolics` into `Expr`. If `canonicalize`, then we turn exprs like
222+
`x^(-n)` into `inv(x)^n` to avoid type error when evaluating.
223+
"""
224+
function toexpr(O; canonicalize=true)
225+
if canonicalize
226+
canonical, O = canonicalexpr(O)
227+
canonical && return O
228+
else
229+
!istree(O) && return O
230+
end
221231

222232
op = operation(O)
223233
args = arguments(O)
224234
if op isa Differential
225-
return :(derivative($(toexpr(args[1])),$(toexpr(op.x))))
235+
return :(derivative($(toexpr(args[1]; canonicalize=canonicalize)),$(toexpr(op.x; canonicalize=canonicalize))))
226236
elseif op isa Sym
227237
isempty(args) && return nameof(op)
228-
return Expr(:call, toexpr(op), toexpr.(args)...)
238+
return Expr(:call, toexpr(op; canonicalize=canonicalize), toexpr(args; canonicalize=canonicalize)...)
229239
end
230-
return Expr(:call, op, toexpr.(args)...)
240+
return Expr(:call, op, toexpr(args; canonicalize=canonicalize)...)
231241
end
232-
toexpr(s::Sym) = nameof(s)
242+
toexpr(s::Sym; kw...) = nameof(s)
233243

234244
"""
235245
canonicalexpr(O) -> (canonical::Bool, expr)
@@ -242,18 +252,23 @@ function canonicalexpr(O)
242252
args = arguments(O)
243253
if op === (^)
244254
if length(args) == 2 && args[2] isa Number && args[2] < 0
245-
expr = Expr(:call, ^, Expr(:call, inv, toexpr(args[1])), -args[2])
255+
ex = toexpr(args[1])
256+
if args[2] == -1
257+
expr = Expr(:call, inv, ex)
258+
else
259+
expr = Expr(:call, ^, Expr(:call, inv, ex), -args[2])
260+
end
246261
return true, expr
247262
end
248263
end
249264
return false, O
250265
end
251266

252-
function toexpr(eq::Equation)
253-
Expr(:(=), toexpr(eq.lhs), toexpr(eq.rhs))
267+
function toexpr(eq::Equation; kw...)
268+
Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...))
254269
end
255270

256-
toexpr(eq::AbstractArray) = toexpr.(eq)
257-
toexpr(x::Integer) = x
258-
toexpr(x::AbstractFloat) = x
259-
toexpr(x::Num) = toexpr(value(x))
271+
toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs)
272+
toexpr(x::Integer; kw...) = x
273+
toexpr(x::AbstractFloat; kw...) = x
274+
toexpr(x::Num; kw...) = toexpr(value(x); kw...)

src/latexify_recipes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)
1111
# that latexify can deal with
1212

1313
rhs = getfield.(eqs, :rhs)
14-
rhs = prettify_expr.(toexpr.(rhs))
14+
rhs = prettify_expr.(toexpr(rhs; canonicalize=false))
1515
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
1616
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
1717
rhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]
1818

1919
lhs = getfield.(eqs, :lhs)
20-
lhs = prettify_expr.(toexpr.(lhs))
20+
lhs = prettify_expr.(toexpr(lhs; canonicalize=false))
2121
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
2222
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
2323
lhs = [postwalk(x -> x isa Expr && x.args[1] == :derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]

test/latexify.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ eqs = [D(x) ~ σ*(y-x)*D(x-y)/D(z),
3030
# Latexify.@generate_test latexify(eqs)
3131
@test latexify(eqs) == replace(
3232
raw"\begin{align}
33-
\frac{dx(t)}{dt} =& \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \left( \mathrm{inv}\left( \frac{dz(t)}{dt} \right) \right)^{1} \sigma \left( y\left( t \right) -1 x\left( t \right) \right) \\
33+
\frac{dx(t)}{dt} =& \left( y\left( t \right) -1 x\left( t \right) \right) \left( \frac{dz(t)}{dt} \right)^{-1} \sigma \frac{d\left(x\left( t \right) -1 \cdot y\left( t \right)\right)}{dt} \\
3434
0 =& -1 y\left( t \right) + 0.1 x\left( t \right) \sigma \left( -1 z\left( t \right) + \rho \right) \\
3535
\frac{dz(t)}{dt} =& x\left( t \right) \left( y\left( t \right) \right)^{\frac{2}{3}} -1 z\left( t \right) \beta
3636
\end{align}
@@ -71,6 +71,6 @@ eqs = [D(x) ~ (1+cos(t))/(1+2*x)]
7171

7272
@test latexify(eqs) == replace(
7373
raw"\begin{align}
74-
\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( \mathrm{inv}\left( 1 + 2 x\left( t \right) \right) \right)^{1}
74+
\frac{dx(t)}{dt} =& \left( 1 + \cos\left( t \right) \right) \left( 1 + 2 x\left( t \right) \right)^{-1}
7575
\end{align}
7676
", "\r\n"=>"\n")

0 commit comments

Comments
 (0)