Skip to content

Commit c5bc4bd

Browse files
authored
Merge pull request #102 from JuliaSymbolics/s/hastrig
Check if an expression has trigonometric functions and use a different ruleset
2 parents 43a8b18 + c0021df commit c5bc4bd

File tree

4 files changed

+32
-4
lines changed

4 files changed

+32
-4
lines changed

src/rule_dsl.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,14 @@ function _recurse_apply_ruleset_threaded(r::RuleSet, term, context; depth, threa
264264
Term{symtype(term)}(operation(term), args)
265265
end
266266

267-
function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall::Bool=false, recurse::Bool=true,
268-
threaded::Bool=false, thread_subtree_cutoff::Int=100)
267+
const rule_repr = IdDict()
268+
269+
function (r::RuleSet)(term, context=EmptyCtx();
270+
depth=typemax(Int),
271+
applyall::Bool=false,
272+
recurse::Bool=true,
273+
threaded::Bool=false,
274+
thread_subtree_cutoff::Int=100)
269275
rules = r.rules
270276
term = to_symbolic(term)
271277
# simplify the subexpressions
@@ -286,7 +292,7 @@ function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall::B
286292
end
287293
for i in 1:length(rules)
288294
expr′ = try
289-
@timer(repr(rules[i]), rules[i](expr, context))
295+
@timer(Base.@get!(rule_repr, rules[i], repr(rules[i])), rules[i](expr, context))
290296
catch err
291297
throw(RuleRewriteError(rules[i], expr))
292298
end

src/rulesets.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@ const SIMPLIFY_RULES = RuleSet([
44
@rule ~t::sym_isa(Number) => NUMBER_RULES(~t, applyall=true, recurse=true)
55
])
66

7+
const SIMPLIFY_RULES_TRIG = RuleSet([
8+
@rule ~t::sym_isa(Bool) => BOOLEAN_RULES(~t, applyall=true, recurse=true)
9+
@rule ~t::sym_isa(Number) => NUMBER_RULES(~t, applyall=true, recurse=true)
10+
@rule ~t::sym_isa(Number) => TRIG_RULES(~t, recurse=true)
11+
])
12+
713
const NUMBER_RULES = RuleSet([
814
@rule ~t => ASSORTED_RULES(~t, recurse=false)
915
@rule ~t::is_operation(+) => PLUS_RULES(~t, recurse=false)
1016
@rule ~t::is_operation(*) => TIMES_RULES(~t, recurse=false)
1117
@rule ~t::is_operation(^) => POW_RULES(~t, recurse=false)
12-
@rule ~t => TRIG_RULES(~t, recurse=false)
1318
])
1419

1520
const PLUS_RULES = RuleSet([

src/simplify.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ of symtype Number.
1111
"""
1212
default_rules(x, ctx) = SIMPLIFY_RULES
1313

14+
function default_rules(x, ctx::EmptyCtx)
15+
has_trig(x) ?
16+
SIMPLIFY_RULES_TRIG :
17+
SIMPLIFY_RULES
18+
end
19+
1420
"""
1521
simplify(x, ctx=EmptyCtx();
1622
rules=default_rules(x, ctx),

src/util.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,14 @@ drop_n(ll, n) = n === 0 ? ll : drop_n(cdr(ll), n-1)
3434
@inline drop_n(ll::AbstractArray, n) = drop_n(LL(ll, 1), n)
3535
@inline drop_n(ll::LL, n) = LL(ll.v, ll.i+n)
3636

37+
has_trig(x) = false
38+
function has_trig(term::Term)
39+
fns = (sin, cos, tan, cot, sec, csc)
40+
op = operation(term)
41+
42+
if Base.@nany 6 i->fns[i] === op
43+
return true
44+
else
45+
return any(has_trig, arguments(term))
46+
end
47+
end

0 commit comments

Comments
 (0)