Skip to content

Commit b798093

Browse files
authored
Merge pull request #549 from JuliaSymbolics/s/deglex-print
better ordering of terms
2 parents 9ef38f3 + f0cfce0 commit b798093

File tree

9 files changed

+158
-174
lines changed

9 files changed

+158
-174
lines changed

src/ordering.jl

Lines changed: 65 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -8,120 +8,84 @@
88
<(a::Symbolic, b::Number) = false
99
<(a::Number, b::Symbolic) = true
1010

11-
arglength(a) = length(arguments(a))
12-
function <(a, b)
13-
if isterm(a) && (b isa Symbolic && !isterm(b))
14-
return false
15-
elseif isterm(b) && (a isa Symbolic && !isterm(a))
16-
return true
17-
elseif (isadd(a) || ismul(a)) && (isadd(b) || ismul(b))
18-
return cmp_mul_adds(a, b)
19-
elseif issym(a) && issym(b)
20-
nameof(a) < nameof(b)
21-
elseif !istree(a) && !istree(b)
22-
T = typeof(a)
23-
S = typeof(b)
24-
return T===S ? (T <: Number ? isless(a, b) : hash(a) < hash(b)) : nameof(T) < nameof(S)
25-
elseif istree(b) && !istree(a)
26-
return true
27-
elseif istree(a) && istree(b)
28-
return cmp_term_term(a,b)
11+
<(a::Function, b::Function) = nameof(a) <nameof(b)
12+
13+
<(a::Type, b::Type) = nameof(a) <nameof(b)
14+
<(a::T, b::S) where{T,S} = T<S
15+
<(a::T, b::T) where{T} = a < b
16+
17+
18+
###### A variation on degree lexicographic order ########
19+
# find symbols and their corresponding degrees
20+
function get_degrees(expr)
21+
if issym(expr)
22+
((Symbol(expr),) => 1,)
23+
elseif istree(expr)
24+
op = operation(expr)
25+
args = arguments(expr)
26+
if operation(expr) == (^) && args[2] isa Number
27+
return map(get_degrees(args[1])) do (base, pow)
28+
(base => pow * args[2])
29+
end
30+
elseif operation(expr) == (*)
31+
return mapreduce(get_degrees,
32+
(x,y)->(x...,y...,), args)
33+
elseif operation(expr) == (+)
34+
ds = map(get_degrees, args)
35+
_, idx = findmax(x->sum(last.(x), init=0), ds)
36+
return ds[idx]
37+
elseif operation(expr) == (getindex)
38+
args = arguments(expr)
39+
return ((Symbol.(args)...,) => 1,)
40+
else
41+
return ((Symbol("zzzzzzz", hash(expr)),) => 1,)
42+
end
2943
else
30-
return !(b <ₑ a)
44+
return ()
3145
end
3246
end
3347

34-
function cmp_mul_adds(a, b)
35-
(isadd(a) && ismul(b)) && return true
36-
(ismul(a) && isadd(b)) && return false
37-
a_args = unsorted_arguments(a)
38-
b_args = unsorted_arguments(b)
39-
length(a_args) < length(b_args) && return true
40-
length(a_args) > length(b_args) && return false
41-
a_args = arguments(a)
42-
b_args = arguments(b)
43-
for (x, y) in zip(a_args, b_args)
44-
x <ₑ y && return true
45-
end
46-
return false
48+
function monomial_lt(degs1, degs2)
49+
d1 = sum(last, degs1, init=0)
50+
d2 = sum(last, degs2, init=0)
51+
d1 != d2 ? d1 < d2 : lexlt(degs1, degs2)
4752
end
4853

49-
function <(a::Symbol, b::Symbol)
50-
# Enforce the order [+,-,\,/,^,*]
51-
if b === :*
52-
a in (:^, :/, :\, :-, :+)
53-
elseif b === :^
54-
a in (:/, :\, :-, :+) && return true
55-
elseif b === :/
56-
a in (:\, :-, :+) && return true
57-
elseif b === :\
58-
a in (:-, :+) && return true
59-
elseif b === :-
60-
a === :+ && return true
61-
elseif a in (:*, :^, :/, :-, :+)
62-
false
63-
else
64-
a < b
54+
function lexlt(degs1, degs2)
55+
for (a, b) in zip(degs1, degs2)
56+
if a[1] == b[1] && a[2] != b[2]
57+
return a[2] > b[2]
58+
elseif a[1] != b[1]
59+
return a < b
60+
end
6561
end
62+
return false # they are equal
6663
end
6764

68-
<(a::Function, b::Function) = nameof(a) <nameof(b)
69-
70-
<(a::Type, b::Type) = nameof(a) <nameof(b)
71-
72-
function cmp_term_term(a, b)
73-
la = arglength(a)
74-
lb = arglength(b)
65+
_arglen(a) = istree(a) ? length(unsorted_arguments(a)) : 0
7566

76-
if la == 0 && lb == 0
77-
return operation(a) <operation(b)
78-
elseif la === 0
79-
return operation(a) <ₑ b
80-
elseif lb === 0
81-
return a <operation(b)
82-
end
83-
84-
na = operation(a)
85-
nb = operation(b)
86-
87-
if 0 < arglength(a) <= 2 && 0 < arglength(b) <= 2
88-
# e.g. a < sin(a) < b ^ 2 < b
89-
@goto compare_args
67+
function <(a::Tuple, b::Tuple)
68+
for (x, y) in zip(a, b)
69+
if x <ₑ y
70+
return true
71+
elseif y <ₑ x
72+
return false
73+
end
9074
end
75+
return length(a) < length(b)
76+
end
9177

92-
if na !== nb
93-
return na <ₑ nb
94-
elseif arglength(a) != arglength(b)
95-
return arglength(a) < arglength(b)
96-
else
97-
@label compare_args
98-
aa, ab = arguments(a), arguments(b)
99-
if length(aa) !== length(ab)
100-
return length(aa) < length(ab)
78+
function <(a::BasicSymbolic, b::BasicSymbolic)
79+
da, db = get_degrees(a), get_degrees(b)
80+
fw = monomial_lt(da, db)
81+
bw = monomial_lt(db, da)
82+
if fw === bw && !isequal(a, b)
83+
if _arglen(a) == _arglen(b)
84+
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
10185
else
102-
terms = zip(Iterators.filter(!is_literal_number, aa), Iterators.filter(!is_literal_number, ab))
103-
104-
for (x,y) in terms
105-
if x <ₑ y
106-
return true
107-
elseif y <ₑ x
108-
return false
109-
end
110-
end
111-
112-
# compare the numbers
113-
nums = zip(Iterators.filter(is_literal_number, aa),
114-
Iterators.filter(is_literal_number, ab))
115-
116-
for (x,y) in nums
117-
if x <ₑ y
118-
return true
119-
elseif y <ₑ x
120-
return false
121-
end
122-
end
123-
86+
return _arglen(a) < _arglen(b)
12487
end
125-
return na <ₑ nb # all args are equal, compare the name
88+
else
89+
return fw
12690
end
12791
end

src/types.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,20 @@ end
113113
function arguments(x::BasicSymbolic)
114114
args = unsorted_arguments(x)
115115
@compactified x::BasicSymbolic begin
116-
Add => @goto ADDMUL
117-
Mul => @goto ADDMUL
116+
Add => @goto ADD
117+
Mul => @goto MUL
118118
_ => return args
119119
end
120-
@label ADDMUL
120+
@label MUL
121+
if !x.issorted[]
122+
sort!(args, by=get_degrees)
123+
x.issorted[] = true
124+
end
125+
return args
126+
127+
@label ADD
121128
if !x.issorted[]
122-
sort!(args, lt = <)
129+
sort!(args, lt = monomial_lt, by=get_degrees)
123130
x.issorted[] = true
124131
end
125132
return args
@@ -668,20 +675,19 @@ function remove_minus(t)
668675
Any[-args[1], args[2:end]...]
669676
end
670677

671-
function show_add(io, args)
672-
negs = filter(isnegative, args)
673-
nnegs = filter(!isnegative, args)
674-
for (i, t) in enumerate(nnegs)
675-
i != 1 && print(io, " + ")
676-
print_arg(io, +, t)
677-
end
678678

679-
for (i, t) in enumerate(negs)
680-
if i==1 && isempty(nnegs)
681-
print_arg(io, -, t)
682-
else
683-
print(io, " - ")
679+
function show_add(io, args)
680+
for (i, t) in enumerate(args)
681+
neg = isnegative(t)
682+
if i != 1
683+
print(io, neg ? " - " : " + ")
684+
elseif isnegative(t)
685+
print(io, "-")
686+
end
687+
if neg
684688
show_mul(io, remove_minus(t))
689+
else
690+
print_arg(io, +, t)
685691
end
686692
end
687693
end

test/basics.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,29 @@ end
188188
@test repr((-1)^a) == "(-1)^a"
189189
end
190190

191+
@testset "polynomial printing" begin
192+
@syms a b c x[1:3]
193+
@test repr(b+a) == "a + b"
194+
@test repr(b-a) == "-a + b"
195+
@test repr(2a+1+3a^2) == "1 + 2a + 3(a^2)"
196+
@test repr(2a+1+3a^2+2b+3b^2+4a*b) == "1 + 2a + 2b + 3(a^2) + 4a*b + 3(b^2)"
197+
198+
@syms a b[1:3] c d[1:3]
199+
get(x, i) = term(getindex, x, i, type=Number)
200+
b1, b3, d1, d2 = get(b,1),get(b,3), get(d,1), get(d,2)
201+
@test repr(a + b3 + b1 + d2 + c) == "a + b[1] + b[3] + c + d[2]"
202+
@test repr(expand((c + b3 - d1)^3)) == "b[3]^3 + 3(b[3]^2)*c - 3(b[3]^2)*d[1] + 3b[3]*(c^2) - 6b[3]*c*d[1] + 3b[3]*(d[1]^2) + c^3 - 3(c^2)*d[1] + 3c*(d[1]^2) - (d[1]^3)"
203+
end
204+
191205
@testset "inspect" begin
192206
@syms x y z
193207
y = SymbolicUtils.setmetadata(y, Integer, 42) # Set some metadata
194208
ex = z*(2x + 3y + 1)^2/(z+2x)
195209
@test_reference "inspect_output/ex.txt" sprint(io->SymbolicUtils.inspect(io, ex))
196210
@test_reference "inspect_output/ex-md.txt" sprint(io->SymbolicUtils.inspect(io, ex, metadata=true))
197211
@test_reference "inspect_output/ex-nohint.txt" sprint(io->SymbolicUtils.inspect(io, ex, hint=false))
198-
@test SymbolicUtils.pluck(ex, 8) == 2
199-
@test_reference "inspect_output/sub10.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 10)))
212+
@test SymbolicUtils.pluck(ex, 12) == 2
213+
@test_reference "inspect_output/sub10.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 9)))
200214
@test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14)))
201215
end
202216

test/inspect_output/ex-md.txt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
1 DIV
22
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3-
3 │ ├─ SYM(z)
4-
4 │ └─ POW
5-
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6-
6 │ │ ├─ 1
7-
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8-
8 │ │ │ ├─ 2
9-
9 │ │ │ └─ SYM(x)
10-
10 │ └─ MUL(scalar = 3, powers = (y => 1,))
11-
11 │ ├─ 3
12-
12 │ └─ SYM(y) metadata=(Integer => 42,)
13-
13 │ └─ 2
3+
3 │ ├─ POW
4+
4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
5+
5 │ │ │ ├─ 1
6+
6 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
7+
7 │ │ │ │ ├─ 2
8+
8 │ │ │ └─ SYM(x)
9+
9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,))
10+
10 │ ├─ 3
11+
11 │ └─ SYM(y) metadata=(Integer => 42,)
12+
12 │ │ └─ 2
13+
13 │ └─ SYM(z)
1414
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15-
15 ├─ SYM(z)
16-
16 └─ MUL(scalar = 2, powers = (x => 1,))
17-
17 ├─ 2
18-
18 └─ SYM(x)
15+
15 ├─ MUL(scalar = 2, powers = (x => 1,))
16+
16 │ ├─ 2
17+
17 │ └─ SYM(x)
18+
18 └─ SYM(z)
1919

2020
Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number

test/inspect_output/ex-nohint.txt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
1 DIV
22
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3-
3 │ ├─ SYM(z)
4-
4 │ └─ POW
5-
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6-
6 │ │ ├─ 1
7-
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8-
8 │ │ │ ├─ 2
9-
9 │ │ │ └─ SYM(x)
10-
10 │ └─ MUL(scalar = 3, powers = (y => 1,))
11-
11 │ ├─ 3
12-
12 │ └─ SYM(y)
13-
13 │ └─ 2
3+
3 │ ├─ POW
4+
4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
5+
5 │ │ │ ├─ 1
6+
6 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
7+
7 │ │ │ │ ├─ 2
8+
8 │ │ │ └─ SYM(x)
9+
9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,))
10+
10 │ ├─ 3
11+
11 │ └─ SYM(y)
12+
12 │ │ └─ 2
13+
13 │ └─ SYM(z)
1414
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15-
15 ├─ SYM(z)
16-
16 └─ MUL(scalar = 2, powers = (x => 1,))
17-
17 ├─ 2
18-
18 └─ SYM(x)
15+
15 ├─ MUL(scalar = 2, powers = (x => 1,))
16+
16 │ ├─ 2
17+
17 │ └─ SYM(x)
18+
18 └─ SYM(z)

test/inspect_output/ex.txt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
1 DIV
22
2 ├─ MUL(scalar = 1, powers = (z => 1, 1 + 2x + 3y => 2))
3-
3 │ ├─ SYM(z)
4-
4 │ └─ POW
5-
5 │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
6-
6 │ │ ├─ 1
7-
7 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
8-
8 │ │ │ ├─ 2
9-
9 │ │ │ └─ SYM(x)
10-
10 │ └─ MUL(scalar = 3, powers = (y => 1,))
11-
11 │ ├─ 3
12-
12 │ └─ SYM(y)
13-
13 │ └─ 2
3+
3 │ ├─ POW
4+
4 │ │ ├─ ADD(scalar = 1, coeffs = (y => 3, x => 2))
5+
5 │ │ │ ├─ 1
6+
6 │ │ ├─ MUL(scalar = 2, powers = (x => 1,))
7+
7 │ │ │ │ ├─ 2
8+
8 │ │ │ └─ SYM(x)
9+
9 │ │ │ └─ MUL(scalar = 3, powers = (y => 1,))
10+
10 │ ├─ 3
11+
11 │ └─ SYM(y)
12+
12 │ │ └─ 2
13+
13 │ └─ SYM(z)
1414
14 └─ ADD(scalar = 0, coeffs = (z => 1, x => 2))
15-
15 ├─ SYM(z)
16-
16 └─ MUL(scalar = 2, powers = (x => 1,))
17-
17 ├─ 2
18-
18 └─ SYM(x)
15+
15 ├─ MUL(scalar = 2, powers = (x => 1,))
16+
16 │ ├─ 2
17+
17 │ └─ SYM(x)
18+
18 └─ SYM(z)
1919

2020
Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number

test/inspect_output/sub14.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
1 ADD(scalar = 0, coeffs = (z => 1, x => 2))
2-
2 ├─ SYM(z)
3-
3 └─ MUL(scalar = 2, powers = (x => 1,))
4-
4 ├─ 2
5-
5 └─ SYM(x)
2+
2 ├─ MUL(scalar = 2, powers = (x => 1,))
3+
3 │ ├─ 2
4+
4 │ └─ SYM(x)
5+
5 └─ SYM(z)
66

77
Hint: call SymbolicUtils.pluck(expr, line_number) to get the subexpression starting at line_number

0 commit comments

Comments
 (0)