Skip to content

Commit b9843c6

Browse files
authored
Merge pull request #261 from JuliaSymbolics/ys/canon
Formalize canonical form
2 parents 7e7a306 + e0c1beb commit b9843c6

File tree

5 files changed

+55
-10
lines changed

5 files changed

+55
-10
lines changed

src/SymbolicUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include("types.jl")
1313

1414
# Methods on symbolic objects
1515
using SpecialFunctions, NaNMath
16-
import IfElse: ifelse # need to not bring IfElse name in or it will clash
16+
import IfElse: ifelse # need to not bring IfElse name in or it will clash with Rewriters.IfElse
1717
include("methods.jl")
1818

1919
# LinkedList, simplification utilities

src/code.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ function_to_expr(op, args, st) = nothing
100100

101101
function function_to_expr(::typeof(^), O, st)
102102
args = arguments(O)
103-
if length(args) == 2 && args[2] isa Number && args[2] < 0
103+
if length(args) == 2 && args[2] isa Real && args[2] < 0
104104
ex = args[1]
105105
if args[2] == -1
106106
return toexpr(Term{Any}(inv, [ex]), st)

src/types.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ function show_call(io, f, args)
503503
if binary
504504
for (i, t) in enumerate(args)
505505
i != 1 && print(io, " $fname ")
506-
print_arg(io, t)
506+
print_arg(io, t, paren=true)
507507
end
508508
else
509509
if f isa Sym
@@ -513,7 +513,7 @@ function show_call(io, f, args)
513513
end
514514
print(io, "(")
515515
for i=1:length(args)
516-
print(IOContext(io, :paren => false), args[i])
516+
print(io, args[i])
517517
i != length(args) && print(io, ", ")
518518
end
519519
print(io, ")")
@@ -775,9 +775,21 @@ mul_t(a) = promote_symtype(*, symtype(a))
775775
a.coeff * b.coeff,
776776
_merge(+, a.dict, b.dict, filter=_iszero))
777777

778-
*(a::Number, b::SN) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(a, b)...)
778+
function *(a::Number, b::SN)
779+
if iszero(a)
780+
a
781+
elseif isone(a)
782+
b
783+
elseif b isa Add
784+
# 2(a+b) -> 2a + 2b
785+
T = promote_symtype(+, typeof(a), symtype(b))
786+
Add(T, b.coeff * a, Dict(k=>v*a for (k, v) in b.dict))
787+
else
788+
Mul(mul_t(a, b), makemul(a, b)...)
789+
end
790+
end
779791

780-
*(b::SN, a::Number) = iszero(a) ? a : isone(a) ? b : Mul(mul_t(a, b), makemul(a, b)...)
792+
*(a::SN, b::Number) = b * a
781793

782794
/(a::Union{SN,Number}, b::SN) = a * b^(-1)
783795

@@ -853,7 +865,13 @@ end
853865

854866
*(a::Pow, b::Mul) = b * a
855867

856-
_merge(f, d, others...; filter=x->false) = _merge!(f, copy(d), others...; filter=filter)
868+
function copy_similar(d, others)
869+
K = promote_type(keytype(d), keytype.(others)...)
870+
V = promote_type(valtype(d), valtype.(others)...)
871+
Dict{K, V}(d)
872+
end
873+
874+
_merge(f, d, others...; filter=x->false) = _merge!(f, copy_similar(d, others), others...; filter=filter)
857875
function _merge!(f, d, others...; filter=x->false)
858876
acc = d
859877
for other in others

test/basics.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ end
151151
@test repr(Term(*, [1, 1])) == "*1"
152152
@test repr(Term(*, [2, 1])) == "2*1"
153153
@test repr((a + b) - (b + c)) == "a - c"
154-
@test repr(a + -1*(b + c)) == "a - (b + c)"
154+
@test repr(a + -1*(b + c)) == "a - b - c"
155155
@test repr(a + -1*b) == "a - b"
156156
end
157157

@@ -165,7 +165,7 @@ toterm(t) = Term{symtype(t)}(operation(t), arguments(t))
165165
@testset "diffs" begin
166166
@syms a b c
167167
@test isequal(toterm(-1c), Term{Number}(*, [-1, c]))
168-
@test isequal(toterm(-1(a+b)), Term{Number}(*, [-1, a+b]))
168+
@test isequal(toterm(-1(a+b)), Term{Number}(+, [-1a, -b]))
169169
@test isequal(toterm((a + b) - (b + c)), Term{Number}(+, [a, -1c]))
170170
end
171171

@@ -181,3 +181,17 @@ end
181181
@test_throws MethodError a * b
182182
@test_throws MethodError a + b
183183
end
184+
185+
@testset "canonical form" begin
186+
@syms a b c
187+
for x in [a, a*b, a+b, a-b, a^2, sin(a)]
188+
@test isequal(x * 1, x)
189+
@test x * 0 === 0
190+
@test isequal(x + 0, x)
191+
@test isequal(x + x, 2x)
192+
@test isequal(x + 2x, 3x)
193+
@test x - x === 0
194+
@test isequal(-x, -1x)
195+
@test isequal(x^1, x)
196+
end
197+
end

test/fuzzlib.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
using SymbolicUtils
2+
using SymbolicUtils: Term
3+
using SpecialFunctions
24
using Test
5+
import IfElse: ifelse
6+
import IfElse
37

48
using SymbolicUtils: showraw, Symbolic
59

@@ -101,6 +105,9 @@ function gen_rand_expr(inputs;
101105
min_depth=min_depth,
102106
max_depth=max_depth)
103107
else
108+
@show f
109+
@show arity
110+
@show args
104111
rethrow(err)
105112
end
106113
end
@@ -114,9 +121,15 @@ function fuzz_test(ntrials, spec, simplify=simplify;kwargs...)
114121
inputs = Set()
115122
expr = gen_rand_expr(inputs; spec=spec, kwargs...)
116123
inputs = collect(inputs)
124+
code = try
125+
SymbolicUtils.Code.toexpr(expr)
126+
catch err
127+
@show expr
128+
rethrow(err)
129+
end
117130
unsimplifiedstr = """
118131
function $(tuple(inputs...))
119-
$(sprint(io->showraw(io, expr)))
132+
$(sprint(io->print(io, code)))
120133
end
121134
"""
122135

0 commit comments

Comments
 (0)