Skip to content

Commit 1301e4d

Browse files
committed
print -1X as -X, X + -Y as X - Y etc.
1 parent 65eae5d commit 1301e4d

File tree

1 file changed

+88
-47
lines changed

1 file changed

+88
-47
lines changed

src/types.jl

Lines changed: 88 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -347,58 +347,99 @@ const show_simplified = Ref(false)
347347

348348
Base.show(io::IO, t::Term) = show_term(io, t)
349349

350-
function show_term(io::IO, t)
351-
if get(io, :simplify, show_simplified[])
352-
s = simplify(t)
350+
isnegative(t::Number) = t < 0
351+
function isnegative(t)
352+
if istree(t) && operation(t) === (*)
353+
coeff = first(arguments(t))
354+
return isnegative(coeff)
355+
end
356+
return false
357+
end
353358

354-
Base.print(IOContext(io, :simplify=>false), s)
355-
else
356-
f = operation(t)
357-
args = arguments(t)
358-
fname = nameof(f)
359-
binary = Base.isbinaryoperator(fname)
360-
if binary
361-
get(io, :paren, false) && Base.print(io, "(")
362-
for i = 1:length(args)
363-
length(args) == 1 && Base.print(io, fname)
364-
365-
paren_scalar = args[i] isa Complex || args[i] isa Rational
366-
367-
paren_scalar && Base.print(io, "(")
368-
# Do not put parenthesis if it's a multiplication and not args
369-
# of power
370-
paren = !(istree(args[i]) && operation(args[i]) == (*)) || fname === :^
371-
Base.print(IOContext(io, :paren => paren), args[i])
372-
paren_scalar && Base.print(io, ")")
373-
374-
if i != length(args)
375-
if fname == :*
376-
if i == 1 && args[1] isa Number && !(args[2] isa Number) && !paren_scalar
377-
# skip
378-
# do not show * if it's a scalar times something
379-
else
380-
Base.print(io, "*")
381-
end
382-
else
383-
Base.print(io, fname == :^ ? '^' : " $fname ")
384-
end
385-
end
386-
end
387-
get(io, :paren, false) && Base.print(io, ")")
388-
else
389-
if f isa Sym
390-
Base.print(io, nameof(f))
359+
setargs(t, args) = Term{symtype(t)}(operation(t), args)
360+
cdrargs(args) = setargs(t, cdr(args))
361+
362+
print_arg(io, n::Union{Complex, Rational}) = print(io, "(", n, ")")
363+
print_arg(io, n) = print(io, n)
364+
365+
function show_add(io, args)
366+
negs = filter(isnegative, args)
367+
nnegs = filter(!isnegative, args)
368+
for (i, t) in enumerate(nnegs)
369+
i != 1 && print(io, " + ")
370+
print_arg(io, t)
371+
end
372+
373+
for (i, t) in enumerate(negs)
374+
print(io, " - ")
375+
print_arg(io, -t)
376+
end
377+
end
378+
379+
function show_mul(io, args)
380+
length(args) == 1 && return print_arg(io, args[1])
381+
382+
paren_scalar = args[1] isa Complex || args[1] isa Rational
383+
minus = args[1] isa Number && args[1] == -1
384+
unit = args[1] isa Number && args[1] == 1
385+
nostar = !paren_scalar && args[1] isa Number && !(args[2] isa Number)
386+
for (i, t) in enumerate(args)
387+
if i != 1
388+
if i==2 && nostar
391389
else
392-
Base.show(io, f)
390+
print(io, "*")
393391
end
394-
Base.print(io, "(")
395-
for i=1:length(args)
396-
Base.print(IOContext(io, :paren => false), args[i])
397-
i != length(args) && Base.print(io, ", ")
398-
end
399-
Base.print(io, ")")
392+
end
393+
if i == 1 && minus
394+
print(io, "-")
395+
elseif i == 1 && unit
396+
else
397+
print_arg(io, t)
400398
end
401399
end
400+
end
401+
402+
function show_call(io, f, args)
403+
fname = nameof(f)
404+
binary = Base.isbinaryoperator(fname)
405+
if binary
406+
get(io, :paren, false) && print(io, "(")
407+
for (i, t) in enumerate(args)
408+
i != 1 && print(io, " $fname ")
409+
print_arg(io, t)
410+
end
411+
get(io, :paren, false) && print(io, ")")
412+
else
413+
if f isa Sym
414+
Base.show_unquoted(io, nameof(f))
415+
else
416+
Base.show(io, f)
417+
end
418+
print(io, "(")
419+
for i=1:length(args)
420+
print(IOContext(io, :paren => false), args[i])
421+
i != length(args) && print(io, ", ")
422+
end
423+
print(io, ")")
424+
end
425+
end
426+
427+
function show_term(io::IO, t)
428+
if get(io, :simplify, show_simplified[])
429+
return print(IOContext(io, :simplify=>false), simplify(t))
430+
end
431+
432+
f = operation(t)
433+
args = arguments(t)
434+
435+
if f === (+)
436+
show_add(io, args)
437+
elseif f === (*)
438+
show_mul(io, args)
439+
else
440+
show_call(io, f, args)
441+
end
442+
402443
return nothing
403444
end
404445

0 commit comments

Comments
 (0)