Skip to content

Commit e5cb78d

Browse files
committed
configurable similarterm in Walk
1 parent acaf8e3 commit e5cb78d

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

src/rewriters.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,20 @@ function (rw::Fixpoint)(x)
107107
return x
108108
end
109109

110-
struct Walk{ord, C, threaded}
110+
struct Walk{ord, C, F, threaded}
111111
rw::C
112112
thread_cutoff::Int
113+
similarterm::F
113114
end
114115

115116
using .Threads
116117

117-
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100)
118-
Walk{:post, typeof(rw), threaded}(rw, thread_cutoff)
118+
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
119+
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
119120
end
120121

121-
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100)
122-
Walk{:pre, typeof(rw), threaded}(rw, thread_cutoff)
122+
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
123+
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
123124
end
124125

125126
struct PassThrough{C}
@@ -128,22 +129,22 @@ end
128129
(p::PassThrough)(x) = (y=p.rw(x); isnothing(y) ? x : y)
129130

130131
passthrough(x, default) = isnothing(x) ? default : x
131-
function (p::Walk{ord, C, false})(x) where {ord, C}
132+
function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
132133
@assert ord === :pre || ord === :post
133134
if istree(x)
134135
if ord === :pre
135136
x = p.rw(x)
136137
end
137138
if istree(x)
138-
x = similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
139+
x = p.similarterm(x, operation(x), map(PassThrough(p), arguments(x)))
139140
end
140141
return ord === :post ? p.rw(x) : x
141142
else
142143
return p.rw(x)
143144
end
144145
end
145146

146-
function (p::Walk{ord, C, true})(x) where {ord, C}
147+
function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
147148
@assert ord === :pre || ord === :post
148149
if istree(x)
149150
if ord === :pre
@@ -158,7 +159,7 @@ function (p::Walk{ord, C, true})(x) where {ord, C}
158159
end
159160
end
160161
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
161-
t = similarterm(x, operation(x), args)
162+
t = p.similarterm(x, operation(x), args)
162163
end
163164
return ord === :post ? p.rw(t) : t
164165
else

0 commit comments

Comments
 (0)