Skip to content

Commit e39b109

Browse files
Merge pull request #548 from JuliaSymbolics/s/rewriter-tweaks
WIP: Various tweaks to the Rewriters
2 parents e9a96bd + 1457894 commit e39b109

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

src/rewriters.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ module Rewriters
3333
using SymbolicUtils: @timer
3434
using TermInterface
3535

36-
import SymbolicUtils: similarterm
36+
import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count
3737
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
3838

3939
# Cache of printed rules to speed up @timer
@@ -63,18 +63,24 @@ If(f, x) = IfElse(f, x, Empty())
6363

6464
struct Chain
6565
rws
66+
stop_on_match::Bool
6667
end
68+
Chain(rws) = Chain(rws, false)
6769

6870
function (rw::Chain)(x)
6971
for f in rw.rws
7072
y = @timer cached_repr(f) f(x)
73+
if rw.stop_on_match && !isnothing(y) && !isequal(y, x)
74+
return y
75+
end
76+
7177
if y !== nothing
7278
x = y
7379
end
7480
end
7581
return x
76-
end
7782

83+
end
7884
instrument(c::Chain, f) = Chain(map(x->instrument(x,f), c.rws))
7985

8086
struct RestartedChain{Cs}
@@ -145,8 +151,8 @@ function (rw::FixpointNoCycle)(x)
145151
f = rw.rw
146152
push!(rw.hist, hash(x))
147153
y = @timer cached_repr(f) f(x)
148-
while x !== y && hash(x) hist
149-
if y === nothing
154+
while x !== y && hash(x) rw.hist
155+
if y === nothing
150156
empty!(rw.hist)
151157
return x
152158
end
@@ -195,9 +201,10 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
195201
if ord === :pre
196202
x = p.rw(x)
197203
end
198-
if iscall(x)
199-
x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)))
200-
end
204+
205+
x = p.similarterm(x, operation(x), map(PassThrough(p),
206+
unsorted_arguments(x)), metadata=metadata(x))
207+
201208
return ord === :post ? p.rw(x) : x
202209
else
203210
return p.rw(x)
@@ -219,7 +226,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
219226
end
220227
end
221228
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
222-
t = p.similarterm(x, operation(x), args)
229+
t = p.similarterm(x, operation(x), args, metadata=metadata(x))
223230
end
224231
return ord === :post ? p.rw(t) : t
225232
else

src/rule.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function (r::Rule)(term)
139139

140140
try
141141
# n == 1 means that exactly one term of the input (term,) was matched
142-
success(bindings, n) = n == 1 ? (@timer "RHS" rhs(bindings)) : nothing
142+
success(bindings, n) = n == 1 ? (@timer "RHS" rhs(assoc(bindings, :MATCH, term))) : nothing
143143
return r.matcher(success, (term,), EMPTY_IMMUTABLE_DICT)
144144
catch err
145145
throw(RuleRewriteError(r, term))

test/rewrite.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,32 @@ using SymbolicUtils: @capture
7676
@eqtest f(b^b) == b
7777
@test f(b+b) == nothing
7878
end
79+
80+
@testset "Rewriter tweaks #548" begin
81+
struct MetaData end
82+
ex = a + b
83+
ex = setmetadata(ex, MetaData, :metadata)
84+
ex1 = ex + c
85+
86+
@test SymbolicUtils.isterm(ex1)
87+
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata
88+
89+
ex = a
90+
ex = setmetadata(ex, MetaData, :metadata)
91+
ex1 = ex + b
92+
93+
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata
94+
95+
ex = a * b
96+
ex = setmetadata(ex, MetaData, :metadata)
97+
ex1 = ex * c
98+
99+
@test SymbolicUtils.isterm(ex1)
100+
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata
101+
102+
ex = a
103+
ex = setmetadata(ex, MetaData, :metadata)
104+
ex1 = ex * b
105+
106+
@test getmetadata(arguments(ex1)[1], MetaData) == :metadata
107+
end

0 commit comments

Comments
 (0)