Skip to content

Commit 41130c8

Browse files
committed
Fix latexify
1 parent 131d1da commit 41130c8

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

src/direct.jl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,18 +231,36 @@ function sparsehessian(O, vars::AbstractVector; simplify=false)
231231
return H
232232
end
233233

234-
"""
235-
canonicalexpr(O) -> (canonical::Bool, expr)
234+
# `_toexpr` is only used for latexify
235+
function _toexpr(O; canonicalize=true)
236+
if canonicalize
237+
canonical, O = canonicalexpr(O)
238+
canonical && return O
239+
else
240+
!istree(O) && return O
241+
end
242+
243+
op = operation(O)
244+
args = arguments(O)
245+
if op isa Differential
246+
ex = _toexpr(args[1]; canonicalize=canonicalize)
247+
wrt = _toexpr(op.x; canonicalize=canonicalize)
248+
return :(_derivative($ex, $wrt))
249+
elseif op isa Sym
250+
isempty(args) && return nameof(op)
251+
return Expr(:call, _toexpr(op; canonicalize=canonicalize), _toexpr(args; canonicalize=canonicalize)...)
252+
end
253+
return Expr(:call, op, _toexpr(args; canonicalize=canonicalize)...)
254+
end
255+
_toexpr(s::Sym; kw...) = nameof(s)
236256

237-
Canonicalize `O`. Return `canonical` if `expr` is valid code to generate.
238-
"""
239257
function canonicalexpr(O)
240258
!istree(O) && return true, O
241259
op = operation(O)
242260
args = arguments(O)
243261
if op === (^)
244262
if length(args) == 2 && args[2] isa Number && args[2] < 0
245-
ex = toexpr(args[1])
263+
ex = _toexpr(args[1])
246264
if args[2] == -1
247265
expr = Expr(:call, inv, ex)
248266
else
@@ -254,11 +272,15 @@ function canonicalexpr(O)
254272
return false, O
255273
end
256274

257-
function toexpr(eq::Equation; kw...)
258-
Expr(:(=), toexpr(eq.lhs; kw...), toexpr(eq.rhs; kw...))
259-
end
275+
for fun in [:toexpr, :_toexpr]
276+
@eval begin
277+
function $fun(eq::Equation; kw...)
278+
Expr(:(=), $fun(eq.lhs; kw...), $fun(eq.rhs; kw...))
279+
end
260280

261-
toexpr(eqs::AbstractArray; kw...) = map(eq->toexpr(eq; kw...), eqs)
262-
toexpr(x::Integer; kw...) = x
263-
toexpr(x::AbstractFloat; kw...) = x
264-
toexpr(x::Num; kw...) = toexpr(value(x); kw...)
281+
$fun(eqs::AbstractArray; kw...) = map(eq->$fun(eq; kw...), eqs)
282+
$fun(x::Integer; kw...) = x
283+
$fun(x::AbstractFloat; kw...) = x
284+
$fun(x::Num; kw...) = $fun(value(x); kw...)
285+
end
286+
end

src/latexify_recipes.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ prettify_expr(expr::Expr) = Expr(expr.head, prettify_expr.(expr.args)...)
1111
# that latexify can deal with
1212

1313
rhs = getfield.(eqs, :rhs)
14-
rhs = prettify_expr.(toexpr(rhs; canonicalize=false))
14+
rhs = prettify_expr.(_toexpr(rhs; canonicalize=false))
1515
rhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in rhs]
1616
rhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in rhs]
1717
rhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in rhs]
1818

1919
lhs = getfield.(eqs, :lhs)
20-
lhs = prettify_expr.(toexpr(lhs; canonicalize=false))
20+
lhs = prettify_expr.(_toexpr(lhs; canonicalize=false))
2121
lhs = [postwalk(x -> x isa Expr && length(x.args) == 1 ? x.args[1] : x, eq) for eq in lhs]
2222
lhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative && length(x.args[2].args) == 2 ? :($(Symbol(:d, x.args[2]))/($(Symbol(:d, x.args[2].args[2])))) : x, eq) for eq in lhs]
2323
lhs = [postwalk(x -> x isa Expr && x.args[1] == :_derivative ? "\\frac{d\\left($(Latexify.latexraw(x.args[2]))\\right)}{d$(Latexify.latexraw(x.args[3]))}" : x, eq) for eq in lhs]

0 commit comments

Comments
 (0)