Skip to content

Commit 45f4efc

Browse files
authored
Merge pull request #185 from JuliaSymbolics/s/printing-update
Some printing improvements
2 parents 65eae5d + 84e5776 commit 45f4efc

File tree

3 files changed

+106
-51
lines changed

3 files changed

+106
-51
lines changed

src/abstractalgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function labels!(dicts, t)
2323
return t
2424
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
2525
tt = arguments(t)
26-
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
26+
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt))
2727
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
2828
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
2929
else

src/types.jl

Lines changed: 99 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -347,58 +347,112 @@ const show_simplified = Ref(false)
347347

348348
Base.show(io::IO, t::Term) = show_term(io, t)
349349

350-
function show_term(io::IO, t)
351-
if get(io, :simplify, show_simplified[])
352-
s = simplify(t)
350+
isnegative(t::Real) = t < 0
351+
function isnegative(t)
352+
if istree(t) && operation(t) === (*)
353+
coeff = first(arguments(t))
354+
return isnegative(coeff)
355+
end
356+
return false
357+
end
358+
359+
setargs(t, args) = Term{symtype(t)}(operation(t), args)
360+
cdrargs(args) = setargs(t, cdr(args))
353361

354-
Base.print(IOContext(io, :simplify=>false), s)
362+
print_arg(io, x::Union{Complex, Rational}) = print(io, "(", x, ")")
363+
print_arg(io, x) = print(io, x)
364+
print_arg(io, f::typeof(^), x) = print_arg(IOContext(io, :paren=>true), x)
365+
function print_arg(io, f, x)
366+
f !== (*) && return print_arg(io, x)
367+
if istree(x) && Base.isbinaryoperator(nameof(operation(x)))
368+
print_arg(IOContext(io, :paren=>true), x)
355369
else
356-
f = operation(t)
357-
args = arguments(t)
358-
fname = nameof(f)
359-
binary = Base.isbinaryoperator(fname)
360-
if binary
361-
get(io, :paren, false) && Base.print(io, "(")
362-
for i = 1:length(args)
363-
length(args) == 1 && Base.print(io, fname)
364-
365-
paren_scalar = args[i] isa Complex || args[i] isa Rational
366-
367-
paren_scalar && Base.print(io, "(")
368-
# Do not put parenthesis if it's a multiplication and not args
369-
# of power
370-
paren = !(istree(args[i]) && operation(args[i]) == (*)) || fname === :^
371-
Base.print(IOContext(io, :paren => paren), args[i])
372-
paren_scalar && Base.print(io, ")")
373-
374-
if i != length(args)
375-
if fname == :*
376-
if i == 1 && args[1] isa Number && !(args[2] isa Number) && !paren_scalar
377-
# skip
378-
# do not show * if it's a scalar times something
379-
else
380-
Base.print(io, "*")
381-
end
382-
else
383-
Base.print(io, fname == :^ ? '^' : " $fname ")
384-
end
385-
end
386-
end
387-
get(io, :paren, false) && Base.print(io, ")")
370+
print_arg(io, x)
371+
end
372+
end
373+
374+
function show_add(io, args)
375+
negs = filter(isnegative, args)
376+
nnegs = filter(!isnegative, args)
377+
for (i, t) in enumerate(nnegs)
378+
i != 1 && print(io, " + ")
379+
print_arg(io, +, t)
380+
end
381+
382+
for (i, t) in enumerate(negs)
383+
if i==1 && isempty(nnegs)
384+
print_arg(io, -, t)
388385
else
389-
if f isa Sym
390-
Base.print(io, nameof(f))
386+
print(io, " - ")
387+
print_arg(io, +, -t)
388+
end
389+
end
390+
end
391+
392+
function show_mul(io, args)
393+
length(args) == 1 && return print_arg(io, *, args[1])
394+
395+
paren_scalar = args[1] isa Complex || args[1] isa Rational
396+
minus = args[1] isa Number && args[1] == -1
397+
unit = args[1] isa Number && args[1] == 1
398+
nostar = !paren_scalar && args[1] isa Number && !(args[2] isa Number)
399+
for (i, t) in enumerate(args)
400+
if i != 1
401+
if i==2 && nostar
391402
else
392-
Base.show(io, f)
393-
end
394-
Base.print(io, "(")
395-
for i=1:length(args)
396-
Base.print(IOContext(io, :paren => false), args[i])
397-
i != length(args) && Base.print(io, ", ")
403+
print(io, "*")
398404
end
399-
Base.print(io, ")")
405+
end
406+
if i == 1 && minus
407+
print(io, "-")
408+
elseif i == 1 && unit
409+
else
410+
print_arg(io, *, t)
400411
end
401412
end
413+
end
414+
415+
function show_call(io, f, args)
416+
fname = nameof(f)
417+
binary = Base.isbinaryoperator(fname)
418+
if binary
419+
for (i, t) in enumerate(args)
420+
i != 1 && print(io, fname == :^ ? fname : " $fname ")
421+
print_arg(io, (^), t)
422+
end
423+
else
424+
if f isa Sym
425+
Base.show_unquoted(io, nameof(f))
426+
else
427+
Base.show(io, f)
428+
end
429+
print(io, "(")
430+
for i=1:length(args)
431+
print(IOContext(io, :paren => false), args[i])
432+
i != length(args) && print(io, ", ")
433+
end
434+
print(io, ")")
435+
end
436+
end
437+
438+
function show_term(io::IO, t)
439+
if get(io, :simplify, show_simplified[])
440+
return print(IOContext(io, :simplify=>false), simplify(t))
441+
end
442+
443+
f = operation(t)
444+
args = arguments(t)
445+
446+
get(io, :paren, false) && print(io, "(")
447+
if f === (+)
448+
show_add(io, args)
449+
elseif f === (*)
450+
show_mul(io, args)
451+
else
452+
show_call(io, f, args)
453+
end
454+
get(io, :paren, false) && print(io, ")")
455+
402456
return nothing
403457
end
404458

test/basics.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,15 @@ end
102102
@testset "printing" begin
103103
@syms a b c
104104
@test repr(a+b) == "a + b"
105-
@test repr(-a) == "-1a"
106-
@test repr(-a + 3) == "3 + -1a"
107-
@test repr(-(a + b)) == "-1a + -1b"
105+
@test repr(-a) == "-a"
106+
@test repr(-a + 3) == "3 - a"
107+
@test repr(-(a + b)) == "-a - b"
108108
@test repr((2a)^(-2a)) == "(2a)^(-2a)"
109109
@test repr(1/2a) == "(1//2)*(a^-1)"
110110
@test repr(2/(2*a)) == "a^-1"
111-
@test repr(Term(*, [1, 1])) == "1*1"
112-
@test repr((a + b) - (b + c)) == "a + -1c"
111+
@test repr(Term(*, [1, 1])) == "*1"
112+
@test repr(Term(*, [2, 1])) == "2*1"
113+
@test repr((a + b) - (b + c)) == "a - c"
113114
end
114115

115116
toterm(t) = Term{symtype(t)}(operation(t), arguments(t))

0 commit comments

Comments
 (0)