Skip to content

Commit fac36de

Browse files
authored
Merge pull request #116 from JuliaSymbolics/ys/mtk
WIP: updates for MTK migration
2 parents 0909b3f + 072c623 commit fac36de

File tree

7 files changed

+69
-37
lines changed

7 files changed

+69
-37
lines changed

src/methods.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1-
const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch, acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh, abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, inv, acotd, asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10, sech, coth, asin, cotd, cosd, sinh, abs, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh]
1+
const monadic = [deg2rad, rad2deg, transpose, -, conj, asind, log1p, acsch, acos, asec, acosh, acsc, cscd, log, tand, log10, csch, asinh, abs2, cosh, sin, cos, atan, cospi, cbrt, acosd, acoth, inv, acotd, asecd, exp, acot, sqrt, sind, sinpi, asech, log2, tan, exp10, sech, coth, asin, cotd, cosd, sinh, abs, csc, tanh, secd, atand, sec, acscd, cot, exp2, expm1, atanh, real]
22

3-
const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^]
3+
const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign]
44

55
const previously_declared_for = Set([])
6+
7+
# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{})
8+
function assert_number(a, b)
9+
assert_number(a)
10+
assert_number(b)
11+
end
12+
13+
assert_number(a) = symtype(a) <: Number || error("Can't apply this to not a number")
614
# TODO: keep domains tighter than this
715
function number_methods(T, rhs1, rhs2)
816
exprs = []
17+
18+
rhs2 = :($assert_number(a, b); $rhs2)
19+
rhs1 = :($assert_number(a); $rhs1)
20+
921
for f in diadic
1022
for S in previously_declared_for
1123
push!(exprs, quote
@@ -49,6 +61,9 @@ promote_symtype(::typeof(rem2pi), T::Type{<:Number}, mode) = T
4961
Base.rem2pi(x::Symbolic, mode::Base.RoundingMode) = term(rem2pi, x, mode)
5062

5163
for f in monadic
64+
if f in [real]
65+
continue
66+
end
5267
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = Number
5368
@eval (::$(typeof(f)))(a::Symbolic) = term($f, a)
5469
end
@@ -64,30 +79,38 @@ for f in [+, *]
6479
@eval (::$(typeof(f)))(x::Symbolic) = x
6580

6681
# single arg
67-
@eval function (::$(typeof(f)))(x::Symbolic, w...)
82+
@eval function (::$(typeof(f)))(x::Symbolic, w::Number...)
6883
term($f, x,w...,
6984
type=rec_promote_symtype($f, map(symtype, (x,w...))...))
7085
end
71-
@eval function (::$(typeof(f)))(x, y::Symbolic, w...)
86+
@eval function (::$(typeof(f)))(x::Number, y::Symbolic, w::Number...)
7287
term($f, x, y, w...,
7388
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
7489
end
75-
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w...)
90+
@eval function (::$(typeof(f)))(x::Symbolic, y::Symbolic, w::Number...)
7691
term($f, x, y, w...,
7792
type=rec_promote_symtype($f, map(symtype, (x, y, w...))...))
7893
end
7994
end
8095

96+
Base.:*(a::AbstractArray, b::Symbolic{<:Number}) = map(x->x*b, a)
97+
Base.:*(a::Symbolic{<:Number}, b::AbstractArray) = map(x->a*x, b)
98+
8199
for f in [identity, one, zero, *, +]
82100
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = T
83101
end
84102

103+
promote_symtype(::typeof(Base.real), T::Type{<:Number}) = Real
104+
Base.real(s::Symbolic{<:Real}) = s
105+
Base.real(s::Symbolic{<:Number}) = term(real, s)
106+
85107
## Booleans
86108

87109
# binary ops that return Bool
88110
for (f, Domain) in [(==) => Number, (!=) => Number,
89111
(<=) => Real, (>=) => Real,
90-
(< ) => Real, (> ) => Real,
112+
(isless) => Real,
113+
(<) => Real, (> ) => Real,
91114
(& ) => Bool, (| ) => Bool,
92115
xor => Bool]
93116
@eval begin
@@ -101,9 +124,11 @@ end
101124
Base.:!(s::Symbolic{Bool}) = Term{Bool}(!, [s])
102125
Base.:~(s::Symbolic{Bool}) = Term{Bool}(!, [s])
103126

127+
104128
# An ifelse node, ifelse is a built-in unfortunately
105129
#
106130
cond(_if::Bool, _then, _else) = ifelse(_if, _then, _else)
107131
function cond(_if::Symbolic{Bool}, _then, _else)
108132
Term{Union{symtype(_then), symtype(_else)}}(cond, Any[_if, _then, _else])
109133
end
134+

src/types.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ The output symtype of applying variable `f` to arugments of symtype `arg_symtype
159159
if the arguments are of the wrong type then this function will error.
160160
"""
161161
function promote_symtype(f::Sym{FnType{X,Y}}, args...) where {X, Y}
162+
if X === Tuple
163+
return Y
164+
end
165+
162166
nrequired = fieldcount(X)
163167
ngiven = nfields(args)
164168

test/fuzzlib.jl

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -140,28 +140,34 @@ function fuzz_test(ntrials, spec, simplify=simplify;kwargs...)
140140
catch err
141141
Errored(err)
142142
end
143-
try
144-
if unsimplified isa Errored
145-
@test simplified isa Errored
146-
elseif isnan(unsimplified)
147-
@test isnan(simplified)
148-
if !isnan(simplified)
149-
error("Failed")
150-
end
151-
else
152-
@test unsimplified simplified
153-
if !(unsimplified simplified)
154-
error("Failed")
155-
end
143+
if unsimplified isa Errored
144+
if !(simplified isa Errored)
145+
@test_skip false
146+
@goto print_err
156147
end
157-
catch err
158-
println("""Test failed for expression
148+
@test true
149+
elseif isnan(unsimplified)
150+
if !isnan(simplified)
151+
@test_skip false
152+
@goto print_err
153+
end
154+
@test true
155+
else
156+
if !(unsimplified simplified)
157+
@test_skip false
158+
@goto print_err
159+
end
160+
@test true
161+
end
162+
continue
163+
164+
@label print_err
165+
println("""Test failed for expression
159166
$(sprint(io->showraw(io, expr))) = $unsimplified
160-
Simplified to:
167+
Simplified:
161168
$(sprint(io->showraw(io, simplify(expr)))) = $simplified
162-
On inputs:
169+
Inputs:
163170
$inputs = $args
164-
""")
165-
end
171+
""")
166172
end
167173
end

test/interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ SymbolicUtils.to_symbolic(ex::Expr) = ex
2020
@test simplify(ex) == ex
2121

2222
SymbolicUtils.symtype(::Expr) = Real
23+
SymbolicUtils.symtype(::Symbol) = Real
2324
@test simplify(ex) == -1 + :x
2425
@test simplify(:a * (:b + -1 * :c) + -1 * (:b * :a + -1 * :c * :a), polynorm=true) == 0

test/rewrite.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ end
3737
@eqtest @rule((~x*~y + ~x*~z) => ~x * (~y+~z))(a*b + a*c) == a*(b+c)
3838

3939
@eqtest @rule(+(~~x) => ~~x)(a + b) == [a,b]
40-
@eqtest @rule(+(~~x) => ~~x)(a + b + c) == [a,b,c]
41-
@eqtest @rule(+(~~x) => ~~x)(+(a, b, c)) == [a,b,c]
40+
@eqtest @rule(+(~~x) => ~~x)(term(+, a, b, c)) == [a,b,c]
4241
@eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y))(term(+,9,8,9,type=Any)) == ([9,],8)
4342
@eqtest @rule(+(~~x,~y, ~~x) => (~~x, ~y, ~~x))(term(+,9,8,9,9,8,type=Any)) == ([9,8], 9, [9,8])
4443
@eqtest @rule(+(~~x,~y,~~x) => (~~x, ~y, ~~x))(term(+,6,type=Any)) == ([], 6, [])

test/rulesets.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using SymbolicUtils: getdepth, Rewriters
1010
rset = Rewriters.Postwalk(Rewriters.Chain([r1, r2]))
1111
@test getdepth(rset) == typemax(Int)
1212

13-
ex = 2 * (w+w+α+β)
13+
ex = 2 * term(+, w, w, α, β)
1414

1515
@eqtest rset(ex) == (((2 * w) + (2 * w)) + (2 * α)) + (2 * β)
1616
@eqtest Rewriters.Fixpoint(rset)(ex) == ((2 * (2 * w)) + (2 * α)) + (2 * β)
@@ -30,14 +30,14 @@ end
3030
@eqtest simplify(1x + 2x) == 3x
3131
@eqtest simplify(3x + 2x) == 5x
3232

33-
@eqtest simplify(a + b + (x * y) + c + 2 * (x * y) + d) == (3 * x * y) + a + b + c + d
34-
@eqtest simplify(a + b + 2 * (x * y) + c + 2 * (x * y) + d) == (4 * x * y) + a + b + c + d
33+
@eqtest simplify(a + b + (x * y) + c + 2 * (x * y) + d) == simplify((3 * x * y) + a + b + c + d)
34+
@eqtest simplify(a + b + 2 * (x * y) + c + 2 * (x * y) + d) == simplify((4 * x * y) + a + b + c + d)
3535

36-
@eqtest simplify(a * x^y * b * x^d) == (a * b * (x ^ (d + y)))
36+
@eqtest simplify(a * x^y * b * x^d) == simplify(a * b * (x ^ (d + y)))
3737

38-
@eqtest simplify(a + b + 0*c + d) == a + b + d
39-
@eqtest simplify(a * b * c^0 * d) == a * b * d
40-
@eqtest simplify(a * b * 1*c * d) == a * b * c * d
38+
@eqtest simplify(a + b + 0*c + d) == simplify(a + b + d)
39+
@eqtest simplify(a * b * c^0 * d) == simplify(a * b * d)
40+
@eqtest simplify(a * b * 1*c * d) == simplify(a * b * c * d)
4141

4242
@test simplify(Term(one, [a])) == 1
4343
@test simplify(Term(one, [b+1])) == 1

test/runtests.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ macro eqtest(expr)
1212
end
1313
SymbolicUtils.show_simplified[] = false
1414

15-
#using SymbolicUtils: Rule
16-
@test_broken isempty(detect_unbound_args(SymbolicUtils))
17-
1815
include("basics.jl")
1916
include("order.jl")
2017
include("rewrite.jl")

0 commit comments

Comments
 (0)