Skip to content

Commit dbc4e33

Browse files
committed
get tests updated
1 parent 1301e4d commit dbc4e33

File tree

3 files changed

+30
-17
lines changed

3 files changed

+30
-17
lines changed

src/abstractalgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function labels!(dicts, t)
2323
return t
2424
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
2525
tt = arguments(t)
26-
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
26+
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt))
2727
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
2828
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
2929
else

src/types.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -359,25 +359,37 @@ end
359359
setargs(t, args) = Term{symtype(t)}(operation(t), args)
360360
cdrargs(args) = setargs(t, cdr(args))
361361

362-
print_arg(io, n::Union{Complex, Rational}) = print(io, "(", n, ")")
363-
print_arg(io, n) = print(io, n)
362+
print_arg(io, f, n::Union{Complex, Rational}) = print(io, "(", n, ")")
363+
print_arg(io, f::typeof(^), n) = print(IOContext(io, :paren=>true), n)
364+
function print_arg(io, f, n)
365+
f !== (*) && return print(io, n)
366+
if istree(n) && Base.isbinaryoperator(nameof(operation(n)))
367+
print(IOContext(io, :paren=>true), n)
368+
else
369+
print(io, n)
370+
end
371+
end
364372

365373
function show_add(io, args)
366374
negs = filter(isnegative, args)
367375
nnegs = filter(!isnegative, args)
368376
for (i, t) in enumerate(nnegs)
369377
i != 1 && print(io, " + ")
370-
print_arg(io, t)
378+
print_arg(io, +, t)
371379
end
372380

373381
for (i, t) in enumerate(negs)
374-
print(io, " - ")
375-
print_arg(io, -t)
382+
if i==1 && isempty(nnegs)
383+
print_arg(io, -, t)
384+
else
385+
print(io, " - ")
386+
print_arg(io, +, -t)
387+
end
376388
end
377389
end
378390

379391
function show_mul(io, args)
380-
length(args) == 1 && return print_arg(io, args[1])
392+
length(args) == 1 && return print_arg(io, *, args[1])
381393

382394
paren_scalar = args[1] isa Complex || args[1] isa Rational
383395
minus = args[1] isa Number && args[1] == -1
@@ -394,7 +406,7 @@ function show_mul(io, args)
394406
print(io, "-")
395407
elseif i == 1 && unit
396408
else
397-
print_arg(io, t)
409+
print_arg(io, *, t)
398410
end
399411
end
400412
end
@@ -403,12 +415,10 @@ function show_call(io, f, args)
403415
fname = nameof(f)
404416
binary = Base.isbinaryoperator(fname)
405417
if binary
406-
get(io, :paren, false) && print(io, "(")
407418
for (i, t) in enumerate(args)
408-
i != 1 && print(io, " $fname ")
409-
print_arg(io, t)
419+
i != 1 && print(io, fname == :^ ? fname : " $fname ")
420+
print_arg(io, (^), t)
410421
end
411-
get(io, :paren, false) && print(io, ")")
412422
else
413423
if f isa Sym
414424
Base.show_unquoted(io, nameof(f))
@@ -432,13 +442,15 @@ function show_term(io::IO, t)
432442
f = operation(t)
433443
args = arguments(t)
434444

445+
get(io, :paren, false) && print(io, "(")
435446
if f === (+)
436447
show_add(io, args)
437448
elseif f === (*)
438449
show_mul(io, args)
439450
else
440451
show_call(io, f, args)
441452
end
453+
get(io, :paren, false) && print(io, ")")
442454

443455
return nothing
444456
end

test/basics.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,15 @@ end
102102
@testset "printing" begin
103103
@syms a b c
104104
@test repr(a+b) == "a + b"
105-
@test repr(-a) == "-1a"
106-
@test repr(-a + 3) == "3 + -1a"
107-
@test repr(-(a + b)) == "-1a + -1b"
105+
@test repr(-a) == "-a"
106+
@test repr(-a + 3) == "3 - a"
107+
@test repr(-(a + b)) == "-a - b"
108108
@test repr((2a)^(-2a)) == "(2a)^(-2a)"
109109
@test repr(1/2a) == "(1//2)*(a^-1)"
110110
@test repr(2/(2*a)) == "a^-1"
111-
@test repr(Term(*, [1, 1])) == "1*1"
112-
@test repr((a + b) - (b + c)) == "a + -1c"
111+
@test repr(Term(*, [1, 1])) == "*1"
112+
@test repr(Term(*, [2, 1])) == "2*1"
113+
@test repr((a + b) - (b + c)) == "a - c"
113114
end
114115

115116
toterm(t) = Term{symtype(t)}(operation(t), arguments(t))

0 commit comments

Comments
 (0)