Skip to content

Commit 20f777b

Browse files
authored
Merge pull request #164 from JuliaSymbolics/myb/mul
Fix division and printing
2 parents ed67d95 + 85e294b commit 20f777b

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

src/types.jl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,15 @@ function show_term(io::IO, t)
343343
paren_scalar = args[i] isa Complex || args[i] isa Rational
344344

345345
paren_scalar && Base.print(io, "(")
346-
# Do not put parenthesis if it's a multiplication
347-
paren = !(istree(args[i]) && operation(args[i]) == (*))
346+
# Do not put parenthesis if it's a multiplication and not args
347+
# of power
348+
paren = !(istree(args[i]) && operation(args[i]) == (*)) || fname === :^
348349
Base.print(IOContext(io, :paren => paren), args[i])
349350
paren_scalar && Base.print(io, ")")
350351

351352
if i != length(args)
352353
if fname == :*
353-
if i == 1 && args[1] isa Number && !paren_scalar
354+
if i == 1 && args[1] isa Number && !(args[2] isa Number) && !paren_scalar
354355
# skip
355356
# do not show * if it's a scalar times something
356357
else
@@ -414,7 +415,7 @@ function Add(T, coeff, dict)
414415
return coeff
415416
elseif _iszero(coeff) && length(dict) == 1
416417
k,v = first(dict)
417-
return _isone(v) ? k : Mul(T, makemul(1, v, k)...)
418+
return _isone(v) ? k : Mul(T, makemul(v, k)...)
418419
end
419420

420421
Add{T, typeof(coeff), typeof(dict)}(coeff, dict, Ref{Any}(nothing))
@@ -563,23 +564,21 @@ Base.isequal(a::Mul, b::Mul) = isequal(a.coeff, b.coeff) && isequal(a.dict, b.di
563564

564565
Base.show(io::IO, a::Mul) = show_term(io, a)
565566

566-
function makemul(sign, coeff, xs...; d=sdict())
567+
function makemul(coeff, xs...; d=sdict())
567568
for x in xs
568569
if x isa Pow && x.exp isa Number
569-
d[x.base] = sign * x.exp + get(d, x.base, 0)
570+
d[x.base] = x.exp + get(d, x.base, 0)
570571
elseif x isa Number
571572
coeff *= x
572573
elseif x isa Mul
573574
coeff *= x.coeff
574-
dict = isone(sign) ? x.dict : mapvalues((_,v)->sign*v, x.dict)
575-
d = _merge(+, d, dict, filter=_iszero)
575+
d = _merge(+, d, x.dict, filter=_iszero)
576576
else
577-
k = x
578-
v = sign + get(d, x, 0)
577+
v = 1 + get(d, x, 0)
579578
if _iszero(v)
580-
delete!(d, k)
579+
delete!(d, x)
581580
else
582-
d[k] = v
581+
d[x] = v
583582
end
584583
end
585584
end
@@ -591,19 +590,17 @@ mul_t(a) = promote_symtype(*, symtype(a))
591590

592591
*(a::SN) = a
593592

594-
*(a::SN, b::SN) = Mul(mul_t(a,b), makemul(1, 1, a, b)...)
593+
*(a::SN, b::SN) = Mul(mul_t(a,b), makemul(1, a, b)...)
595594

596595
*(a::Mul, b::Mul) = Mul(mul_t(a, b),
597596
a.coeff * b.coeff,
598597
_merge(+, a.dict, b.dict, filter=_iszero))
599598

600-
*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(1,a, b)...)
599+
*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(a, b)...)
601600

602-
*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(1,a, b)...)
601+
*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(a, b)...)
603602

604-
function /(a::Union{SN,Number}, b::SN)
605-
a * Mul(promote_symtype(/, Int, symtype(b)), makemul(-1, 1, b)...)
606-
end
603+
/(a::Union{SN,Number}, b::SN) = a * b^(-1)
607604

608605
\(a::SN, b::Union{Number, SN}) = b / a
609606

@@ -648,8 +645,9 @@ Base.show(io::IO, p::Pow) = show_term(io, p)
648645
^(a::Number, b::SN) = Pow(a, b)
649646

650647
function ^(a::Mul, b::Number)
648+
coeff = a.coeff isa Integer && b isa Integer ? (a.coeff//1) ^ b : a.coeff ^ b
651649
Mul(promote_symtype(^, symtype(a), symtype(b)),
652-
a.coeff ^ b, mapvalues((k, v) -> b*v, a.dict))
650+
coeff, mapvalues((k, v) -> b*v, a.dict))
653651
end
654652

655653
function *(a::Mul, b::Pow)
@@ -693,7 +691,7 @@ function similarterm(p::Union{Mul, Add, Pow}, f, args)
693691
if f === (+)
694692
Add(symtype(p), makeadd(1, 0, args...)...)
695693
elseif f == (*)
696-
Mul(symtype(p), makemul(1, 1, args...)...)
694+
Mul(symtype(p), makemul(1, args...)...)
697695
elseif f == (^)
698696
Pow(args...)
699697
else

test/basics.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,8 @@ end
9898
@test repr(-a) == "-1a"
9999
@test repr(-a + 3) == "3 + -1a"
100100
@test repr(-(a + b)) == "-1a + -1b"
101+
@test repr((2a)^(-2a)) == "(2a)^(-2a)"
102+
@test repr(1/2a) == "(1//2)*(a^-1)"
103+
@test repr(2/(2*a)) == "a^-1"
104+
@test repr(Term(*, [1, 1])) == "1*1"
101105
end

0 commit comments

Comments
 (0)