Skip to content

Commit 87d6fc6

Browse files
committed
trait-like dispatch
1 parent 2994e16 commit 87d6fc6

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

benchmark/benchmarks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using BenchmarkTools, SymbolicUtils
2-
using SymbolicUtils: isnumber
2+
using SymbolicUtils: is_literal_number
33

44
using Random
55

@@ -8,7 +8,7 @@ SUITE = BenchmarkGroup()
88
@syms a b c d; Random.seed!(123);
99

1010
let r = @rule(~x => ~x), rs = RuleSet([r]),
11-
acr = @rule(~x::isnumber + ~y => ~y)
11+
acr = @rule(~x::is_literal_number + ~y => ~y)
1212

1313
overhead = SUITE["overhead"] = BenchmarkGroup()
1414
overhead["rule"] = BenchmarkGroup()

src/methods.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@ const diadic = [+, -, max, min, *, /, \, hypot, atan, mod, rem, ^, copysign]
55
const previously_declared_for = Set([])
66

77
# TODO: it's not possible to dispatch on the symtype! (only problem is Parameter{})
8-
function assert_number(a, b)
9-
assert_number(a)
10-
assert_number(b)
8+
9+
assert_like(f, T) = nothing
10+
function assert_like(f, T, a, b...)
11+
islike(a, T) || throw(ArgumentError("The function $f cannot be applied to $a which is not a $T-like object." *
12+
"Define `isnumberlike(::$(typeof(a))) = true` to enable this."))
13+
assert_like(f, T, b...)
1114
end
1215

13-
assert_number(a) = symtype(a) <: Number || error("Can't apply this to not a number")
16+
islike(a, T) = symtype(a) <: T
17+
1418
# TODO: keep domains tighter than this
1519
function number_methods(T, rhs1, rhs2)
1620
exprs = []
1721

18-
rhs2 = :($assert_number(a, b); $rhs2)
19-
rhs1 = :($assert_number(a); $rhs1)
22+
rhs2 = :($assert_like(f, Number, a, b); $rhs2)
23+
rhs1 = :($assert_like(f, Number, a); $rhs1)
2024

2125
for f in diadic
2226
for S in previously_declared_for
@@ -101,8 +105,7 @@ for f in [identity, one, zero, *, +]
101105
end
102106

103107
promote_symtype(::typeof(Base.real), T::Type{<:Number}) = Real
104-
Base.real(s::Symbolic{<:Real}) = s
105-
Base.real(s::Symbolic{<:Number}) = term(real, s)
108+
Base.real(s::Symbolic) = islike(s, Real) ? s : term(real, s)
106109

107110
## Booleans
108111

src/ordering.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function <ₑ(a, b)
1717
elseif istree(b) && !istree(a)
1818
args = arguments(b)
1919
if length(args) === 2
20-
n1, n2 = !isnumber(args[1]) , !isnumber(args[2])
20+
n1, n2 = !is_literal_number(args[1]) , !is_literal_number(args[2])
2121
if n1 && n2
2222
# both subterms are terms, so it's definitely firster
2323
return true
@@ -102,7 +102,7 @@ function cmp_term_term(a, b)
102102
if length(aa) !== length(ab)
103103
return length(aa) < length(ab)
104104
else
105-
terms = zip(Iterators.filter(!isnumber, aa), Iterators.filter(!isnumber, ab))
105+
terms = zip(Iterators.filter(!is_literal_number, aa), Iterators.filter(!is_literal_number, ab))
106106

107107
for (x,y) in terms
108108
if x <ₑ y
@@ -113,8 +113,8 @@ function cmp_term_term(a, b)
113113
end
114114

115115
# compare the numbers
116-
nums = zip(Iterators.filter(isnumber, aa),
117-
Iterators.filter(isnumber, ab))
116+
nums = zip(Iterators.filter(is_literal_number, aa),
117+
Iterators.filter(is_literal_number, ab))
118118

119119
for (x,y) in nums
120120
if x <ₑ y

src/simplify_rules.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ let
44
PLUS_RULES = [
55
@rule(~x::isnotflat(+) => flatten_term(+, ~x))
66
@rule(~x::needs_sorting(+) => sort_args(+, ~x))
7-
@ordered_acrule(~a::isnumber + ~b::isnumber => ~a + ~b)
7+
@ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b)
88

99
@acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))
1010
@acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))
1111
@acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))
1212

1313
@acrule(~x + *(~β, ~x) => *(1 + ~β, ~x))
14-
@acrule(*(~α::isnumber, ~x) + ~x => *(~α + 1, ~x))
14+
@acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x))
1515
@rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...))
1616

1717
@ordered_acrule((~z::_iszero + ~x) => ~x)
@@ -22,7 +22,7 @@ let
2222
@rule(~x::isnotflat(*) => flatten_term(*, ~x))
2323
@rule(~x::needs_sorting(*) => sort_args(*, ~x))
2424

25-
@ordered_acrule(~a::isnumber * ~b::isnumber => ~a * ~b)
25+
@ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b)
2626
@rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...))
2727

2828
@acrule((~y)^(~n) * ~y => (~y)^(~n+1))
@@ -50,7 +50,7 @@ let
5050
@rule(~x / ~y => ~x * pow(~y, -1))
5151
@rule(one(~x) => one(symtype(~x)))
5252
@rule(zero(~x) => zero(symtype(~x)))
53-
@rule(cond(~x::isnumber, ~y, ~z) => ~x ? ~y : ~z)
53+
@rule(cond(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z)
5454
]
5555

5656
TRIG_RULES = [
@@ -92,9 +92,9 @@ let
9292
# simplify terms with no symbolic arguments
9393
# e.g. this simplifies term(isodd, 3, type=Bool)
9494
# or term(!, false)
95-
@rule((~f)(~x::isnumber) => (~f)(~x))
95+
@rule((~f)(~x::is_literal_number) => (~f)(~x))
9696
# and this simplifies any binary comparison operator
97-
@rule((~f)(~x::isnumber, ~y::isnumber) => (~f)(~x, ~y))
97+
@rule((~f)(~x::is_literal_number, ~y::is_literal_number) => (~f)(~x, ~y))
9898
]
9999

100100
function number_simplifier()

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T
109109
is_operation(f) = @nospecialize(x) -> istree(x) && (operation(x) == f)
110110

111111
isliteral(::Type{T}) where {T} = x -> x isa T
112-
isnumber(x) = isliteral(Number)(x)
112+
is_literal_number(x) = isliteral(Number)(x)
113113

114114
_iszero(t) = false
115115
_iszero(x::Number) = iszero(x)

0 commit comments

Comments
 (0)