Skip to content

Commit 8f933e2

Browse files
authored
Merge pull request #68 from JuliaSymbolics/s/ctx
Pass around a context object from simplify to matchers and rewriters
2 parents 6a58962 + 41240e9 commit 8f933e2

File tree

5 files changed

+69
-34
lines changed

5 files changed

+69
-34
lines changed

src/matchers.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@ struct Segment{F}
2222
predicate::F
2323
end
2424

25-
ismatch(s::Segment, t) = s.predicate(t)
26-
2725
Segment(s) = Segment(s, alwaystrue)
2826

2927
Base.show(io::IO, s::Segment) = (print(io, "~~"); print(io, s.name))
3028

3129
makesegment(s::Symbol, keys) = (push!(keys, s); Segment(s))
3230

31+
"""
32+
A wrapper for slot and segment predicates which allows them to
33+
take two arguments: the value and a Context
34+
"""
35+
struct Contextual{F}
36+
f::F
37+
end
38+
(c::Contextual)(args...) = c.f(args...)
39+
40+
ctxcall(f, x, ctx) = f isa Contextual ? f(x, ctx) : f(x)
41+
3342
function makesegment(s::Expr, keys)
3443
if !(s.head == :(::))
3544
error("Syntax for specifying a segment is ~~x::\$predicate, where predicate is a boolean function")
@@ -91,6 +100,16 @@ function makeconsequent(expr)
91100
return Expr(:call, map(makeconsequent, expr.args)...)
92101
end
93102
else
103+
if expr.head == :macrocall
104+
if expr.args[1] === Symbol("@ctx")
105+
if length(filter(x->!(x isa LineNumberNode), expr.args)) != 1
106+
error("@ctx takes no arguments. try (@ctx)")
107+
end
108+
return :__CTX__
109+
else
110+
return esc(expr)
111+
end
112+
end
94113
return Expr(expr.head, map(makeconsequent, expr.args)...)
95114
end
96115
else
@@ -106,21 +125,21 @@ end
106125
# 3. Callback: takes arguments Dictionary × Number of elements matched
107126
#
108127
function matcher(val::Any)
109-
function literal_matcher(data, bindings, next)
128+
function literal_matcher(next, data, bindings, ctx)
110129
!isempty(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
111130
end
112131
end
113132

114133
function matcher(slot::Slot)
115-
function slot_matcher(data, bindings, next)
134+
function slot_matcher(next, data, bindings, ctx)
116135
isempty(data) && return
117136
val = get(bindings, slot.name, nothing)
118137
if val !== nothing
119138
if isequal(val, car(data))
120139
return next(bindings, 1)
121140
end
122141
else
123-
if slot.predicate(car(data))
142+
if ctxcall(slot.predicate, car(data), ctx)
124143
next(assoc(bindings, slot.name, car(data)), 1)
125144
end
126145
end
@@ -156,7 +175,7 @@ function trymatchexpr(data, value, n)
156175
end
157176

158177
function matcher(segment::Segment)
159-
function segment_matcher(data, bindings, success)
178+
function segment_matcher(success, data, bindings, ctx)
160179
val = get(bindings, segment.name, nothing)
161180

162181
if val !== nothing
@@ -170,7 +189,7 @@ function matcher(segment::Segment)
170189
for i=length(data):-1:0
171190
subexpr = take_n(data, i)
172191

173-
if segment.predicate(subexpr)
192+
if ctxcall(segment.predicate, subexpr, ctx)
174193
res = success(assoc(bindings, segment.name, subexpr), i)
175194
if res !== nothing
176195
break
@@ -185,7 +204,7 @@ end
185204

186205
function matcher(term::Term)
187206
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
188-
function term_matcher(data, bindings, success)
207+
function term_matcher(success, data, bindings, ctx)
189208

190209
isempty(data) && return nothing
191210
!(car(data) isa Term) && return nothing
@@ -197,8 +216,9 @@ function matcher(term::Term)
197216
end
198217
return nothing
199218
end
200-
res = car(matchers′)(term, bindings′,
201-
(b, n) -> loop(drop_n(term, n), b, cdr(matchers′)))
219+
car(matchers′)(term, bindings′, ctx) do b, n
220+
loop(drop_n(term, n), b, cdr(matchers′))
221+
end
202222
end
203223

204224
loop(car(data), bindings, matchers) # Try to eat exactly one term

src/rule_dsl.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,15 @@ function Base.show(io::IO, r::Rule)
2828
end
2929

3030
const EMPTY_DICT = ImmutableDict{Symbol, Any}(:____, nothing)
31+
struct EmptyCtx end
3132

32-
function (r::Rule)(term)
33-
match_function = r.matcher
33+
function (r::Rule)(term, ctx=EmptyCtx())
3434
rhs = r.rhs
3535

36-
match_function((term,),
37-
EMPTY_DICT,
38-
(d, n) -> n === 1 ? (@timer "RHS" rhs(d)) : nothing)
36+
r.matcher((term,), EMPTY_DICT, ctx) do bindings, n
37+
# n == 1 means that exactly one term of the input (term,) was matched
38+
n === 1 ? (@timer "RHS" rhs(bindings, ctx)) : nothing
39+
end
3940
end
4041

4142
"""
@@ -152,7 +153,7 @@ macro rule(expr)
152153
Rule($(QuoteNode(expr)),
153154
lhs_pattern,
154155
matcher(lhs_pattern),
155-
__MATCHES__ -> $(makeconsequent(rhs)),
156+
(__MATCHES__, __CTX__) -> $(makeconsequent(rhs)),
156157
rule_depth($lhs_term))
157158
end
158159
end
@@ -177,7 +178,7 @@ end
177178

178179
Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")")
179180

180-
function (acr::ACRule)(term)
181+
function (acr::ACRule)(term, ctx=EmptyCtx())
181182
r = Rule(acr)
182183
if !(term isa Term)
183184
r(term)
@@ -187,12 +188,12 @@ function (acr::ACRule)(term)
187188
if f != operation(r.lhs) # Maybe offer a fallback if m.term errors.
188189
return nothing
189190
end
190-
191+
191192
T = symtype(term)
192193
args = arguments(term)
193-
194+
194195
for inds in permutations(eachindex(args), acr.arity)
195-
result = r(Term{T}(f, args[inds]))
196+
result = r(Term{T}(f, args[inds]), ctx)
196197
if !isnothing(result)
197198
return Term{T}(f, [result, (args[i] for i in eachindex(args) if i inds)...])
198199
end
@@ -233,7 +234,7 @@ struct RuleRewriteError
233234
expr
234235
end
235236

236-
function (r::RuleSet)(term; depth=typemax(Int), applyall=false, recurse=true)
237+
function (r::RuleSet)(term, context=EmptyCtx(); depth=typemax(Int), applyall=false, recurse=true)
237238
rules = r.rules
238239
term = to_symbolic(term)
239240
# simplify the subexpressions
@@ -243,21 +244,21 @@ function (r::RuleSet)(term; depth=typemax(Int), applyall=false, recurse=true)
243244
if term isa Symbolic
244245
if term isa Term && recurse
245246
expr = Term{symtype(term)}(operation(term),
246-
map(t -> r(t, depth=depth-1), arguments(term)))
247+
map(t -> r(t, context, depth=depth-1), arguments(term)))
247248
else
248249
expr = term
249250
end
250251
for i in 1:length(rules)
251252
expr′ = try
252-
@timer(repr(rules[i]), rules[i](expr))
253+
@timer(repr(rules[i]), rules[i](expr, context))
253254
catch err
254255
throw(RuleRewriteError(rules[i], expr))
255256
end
256257
if expr′ === nothing
257258
# this rule doesn't apply
258259
continue
259260
else
260-
expr = r(expr′, depth=getdepth(rules[i]))# levels touched
261+
expr = r(expr′, context, depth=getdepth(rules[i]))# levels touched
261262
applyall || return expr
262263
end
263264
end
@@ -269,17 +270,15 @@ end
269270

270271
getdepth(::RuleSet) = typemax(Int)
271272

272-
function fixpoint(f, x; kwargs...)
273-
x1 = f(x; kwargs...)
273+
function fixpoint(f, x, ctx; kwargs...)
274+
x1 = f(x, ctx; kwargs...)
274275
while !isequal(x1, x)
275276
x = x1
276-
x1 = f(x; kwargs...)
277+
x1 = f(x, ctx; kwargs...)
277278
end
278279
return x1
279280
end
280281

281-
fixpoint(f; kwargs...) = x -> fixpoint(f, x; kwargs...)
282-
283282
@noinline function Base.showerror(io::IO, err::RuleRewriteError)
284283
msg = "Failed to apply rule $(err.rule) on expression "
285284
msg *= sprint(io->showraw(io, err.expr))

src/simplify.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@ Applies them once if `fixpoint=false`.
1111
The `applyall` and `recurse` keywords are forwarded to the enclosed
1212
`RuleSet`.
1313
"""
14-
function simplify(x, rules=SIMPLIFY_RULES; fixpoint=true, applyall=true, recurse=true)
14+
function simplify(x, ctx=EmptyCtx(); rules=SIMPLIFY_RULES, fixpoint=true, applyall=true, recurse=true)
1515
if fixpoint
16-
SymbolicUtils.fixpoint(rules; recurse=recurse, applyall=recurse)(x)
16+
SymbolicUtils.fixpoint(rules, x, ctx; recurse=recurse, applyall=recurse)
1717
else
18-
rules(x; recurse=recurse, applyall=recurse)
18+
rules(x, ctx; recurse=recurse, applyall=recurse)
1919
end
2020
end
2121

22+
23+
Base.@deprecate simplify(x, rules::RuleSet; kwargs...) simplify(x, rules=rules; kwargs...)
24+
2225
"""
2326
substitute(expr, dict)
2427

test/basics.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicUtils: Sym, FnType, Term, symtype
1+
using SymbolicUtils: Sym, FnType, Term, symtype, Contextual, EmptyCtx
22
using SymbolicUtils
33
using Test
44

@@ -68,6 +68,19 @@ end
6868
@test_throws ErrorException t(2)
6969
end
7070

71+
@testset "Contexts" begin
72+
@syms a b c
73+
74+
@test @rule(~x::Contextual((x, ctx) -> ctx==EmptyCtx()) => "yes")(1) == "yes"
75+
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => true)(a, Dict(a=>1))
76+
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => true)(b, Dict(a=>1)) === nothing
77+
@test_throws UndefVarError @rule(~x => __CTX__)(a, "test")
78+
@test @rule(~x => @ctx)(a, "test") == "test"
79+
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => (@ctx)[~x])(a, Dict(a=>1)) === 1
80+
@test @rule(~x::Contextual((x, ctx) -> haskey(ctx, x)) => (@ctx)[~x])(b, Dict(a=>1)) === nothing
81+
@test simplify(a+a, "test", rules=RuleSet([@rule ~x => @ctx])) == "test"
82+
end
83+
7184
@testset "substitute" begin
7285
@syms a b
7386
@test substitute(a, Dict(a=>1)) == 1

test/rulesets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using SymbolicUtils: fixpoint, getdepth
1515
@test rset(ex) == (((2 * w) + (2 * w)) + (2 * α)) + (2 * β)
1616
@test rset(ex) == simplify(ex, rset; fixpoint=false, applyall=false)
1717

18-
@test fixpoint(rset, ex) == ((2 * (2 * w)) + (2 * α)) + (2 * β)
18+
@test fixpoint(rset, ex, "ctx") == ((2 * (2 * w)) + (2 * α)) + (2 * β)
1919
end
2020

2121
@testset "Numeric" begin

0 commit comments

Comments
 (0)