Skip to content

Commit 8707d24

Browse files
committed
refactor: no more need for memoize_on
1 parent b5285f7 commit 8707d24

File tree

3 files changed

+60
-190
lines changed

3 files changed

+60
-190
lines changed

src/Utils.jl

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -12,103 +12,6 @@ macro return_on_false2(flag, retval, retval2)
1212
)
1313
end
1414

15-
"""
16-
@memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode)
17-
...
18-
end
19-
20-
This macro takes a function definition and creates a second version of the
21-
function with an additional `id_map` argument. When passed this argument (an
22-
IdDict()), it will use use the `id_map` to avoid recomputing the same value
23-
for the same node in a tree. Use this to automatically create functions that
24-
work with trees that have shared child nodes.
25-
26-
Can optionally take a `postprocess` function, which will be applied to the
27-
result of the function before returning it, taking the result as the
28-
first argument and a boolean for whether the result was memoized as the
29-
second argument. This is useful for functions that need to count the number
30-
of unique nodes in a tree, for example.
31-
"""
32-
macro memoize_on(tree, args...)
33-
if length(args) (1, 2)
34-
error("Expected 2 or 3 arguments to @memoize_on")
35-
end
36-
postprocess = length(args) == 1 ? :((r, _) -> r) : args[1]
37-
def = length(args) == 1 ? args[1] : args[2]
38-
idmap_def = _memoize_on(tree, postprocess, def)
39-
40-
return quote
41-
$(esc(def)) # The normal function
42-
$(esc(idmap_def)) # The function with an id_map argument
43-
end
44-
end
45-
function _memoize_on(tree::Symbol, postprocess, def)
46-
sdef = splitdef(def)
47-
48-
# Add an id_map argument
49-
push!(sdef[:args], :(id_map::AbstractDict))
50-
51-
f_name = sdef[:name]
52-
53-
# Forward id_map argument to all calls of the same function
54-
# within the function body:
55-
sdef[:body] = postwalk(sdef[:body]) do ex
56-
if @capture(ex, f_(args__))
57-
if f == f_name
58-
return Expr(:call, f, args..., :id_map)
59-
end
60-
end
61-
return ex
62-
end
63-
64-
# Wrap the function body in a get!(id_map, tree) do ... end block:
65-
@gensym key is_memoized result body
66-
sdef[:body] = quote
67-
$key = objectid($tree)
68-
$is_memoized = haskey(id_map, $key)
69-
function $body()
70-
return $(sdef[:body])
71-
end
72-
$result = if $is_memoized
73-
@inbounds(id_map[$key])
74-
else
75-
id_map[$key] = $body()
76-
end
77-
return $postprocess($result, $is_memoized)
78-
end
79-
80-
return combinedef(sdef)
81-
end
82-
83-
"""
84-
@with_memoize(call, id_map)
85-
86-
This simple macro simply puts the `id_map`
87-
into the call, to be consistent with the `@memoize_on` macro.
88-
89-
```
90-
@with_memoize(_copy_node(tree), IdDict{Any,Any}())
91-
````
92-
93-
is converted to
94-
95-
```
96-
_copy_node(tree, IdDict{Any,Any}())
97-
```
98-
99-
"""
100-
macro with_memoize(def, id_map)
101-
idmap_def = _add_idmap_to_call(def, id_map)
102-
return quote
103-
$(esc(idmap_def))
104-
end
105-
end
106-
107-
function _add_idmap_to_call(def::Expr, id_map::Union{Symbol,Expr})
108-
@assert def.head == :call
109-
return Expr(:call, def.args[1], def.args[2:end]..., id_map)
110-
end
111-
11215
@inline function fill_similar(value::T, array, args...) where {T}
11316
out_array = similar(array, args...)
11417
fill!(out_array, value)

src/base.jl

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import Base:
2525

2626
using DispatchDoctor: @unstable
2727
using Compat: @inline, Returns
28-
using ..UtilsModule: @memoize_on, @with_memoize, Undefined
28+
using ..UtilsModule: Undefined
2929

3030
"""
3131
tree_mapreduce(
@@ -89,46 +89,83 @@ function tree_mapreduce(
8989
f_leaf::F1,
9090
f_branch::F2,
9191
op::G,
92-
tree::AbstractNode,
92+
tree::AbstractNode{D},
9393
result_type::Type{RT}=Undefined;
9494
f_on_shared::H=(result, is_shared) -> result,
95-
break_sharing::Val=Val(false),
96-
) where {F1<:Function,F2<:Function,G<:Function,H<:Function,RT}
97-
98-
# Trick taken from here:
99-
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
100-
# to speed up recursive closure
101-
@memoize_on t f_on_shared function inner(inner, t)
102-
if t.degree == 0
103-
return @inline(f_leaf(t))
104-
elseif t.degree == 1
105-
return @inline(op(@inline(f_branch(t)), inner(inner, t.l)))
106-
else
107-
return @inline(op(@inline(f_branch(t)), inner(inner, t.l), inner(inner, t.r)))
108-
end
109-
end
110-
111-
sharing = preserve_sharing(typeof(tree)) && break_sharing === Val(false)
95+
break_sharing::Val{BS}=Val(false),
96+
) where {F1<:Function,F2<:Function,G<:Function,D,H<:Function,RT,BS}
97+
sharing = preserve_sharing(typeof(tree)) && !break_sharing
11298

11399
RT == Undefined &&
114100
sharing &&
115101
throw(ArgumentError("Need to specify `result_type` if nodes are shared.."))
116102

117103
if sharing && RT != Undefined
118-
d = allocate_id_map(tree, RT)
119-
return @with_memoize inner(inner, tree) d
104+
id_map = allocate_id_map(tree, RT)
105+
reducer = TreeMapreducer(Val(D), id_map, f_leaf, f_branch, op, f_on_shared)
106+
return reducer(tree)
107+
else
108+
reducer = TreeMapreducer(Val(D), nothing, f_leaf, f_branch, op, f_on_shared)
109+
return reducer(tree)
110+
end
111+
end
112+
113+
struct TreeMapreducer{D,ID,F1<:Function,F2<:Function,G<:Function,H<:Function}
114+
max_degree::Val{D}
115+
id_map::ID
116+
f_leaf::F1
117+
f_branch::F2
118+
op::G
119+
f_on_shared::H
120+
end
121+
122+
@generated function (mapreducer::TreeMapreducer{MAX_DEGREE,ID})(
123+
tree::AbstractNode
124+
) where {MAX_DEGREE,ID}
125+
base_expr = quote
126+
d = tree.degree
127+
Base.Cartesian.@nif(
128+
$(MAX_DEGREE + 1),
129+
d_p_one -> (d_p_one - 1) == d,
130+
d_p_one -> if d_p_one == 1
131+
mapreducer.f_leaf(tree)
132+
else
133+
mapreducer.op(
134+
mapreducer.f_branch(tree),
135+
Base.Cartesian.@ntuple(
136+
d_p_one - 1, i -> mapreducer(tree.children[i][])
137+
)...,
138+
)
139+
end
140+
)
141+
end
142+
if ID <: Nothing
143+
# No sharing of nodes (is a tree, not a graph)
144+
return base_expr
120145
else
121-
return inner(inner, tree)
146+
# Otherwise, we need to cache results in `id_map`
147+
# according to `objectid` of the node
148+
return quote
149+
key = objectid(tree)
150+
is_cached = haskey(mapreducer.id_map, key)
151+
if is_cached
152+
return mapreducer.f_on_shared(@inbounds(mapreducer.id_map[key]), true)
153+
else
154+
res = $base_expr
155+
mapreducer.id_map[key] = res
156+
return mapreducer.f_on_shared(res, false)
157+
end
158+
end
122159
end
123160
end
161+
124162
function allocate_id_map(tree::AbstractNode, ::Type{RT}) where {RT}
125163
d = Dict{UInt,RT}()
126164
# Preallocate maximum storage (counting with duplicates is fast)
127165
N = length(tree; break_sharing=Val(true))
128166
sizehint!(d, N)
129167
return d
130168
end
131-
132169
# TODO: Raise Julia issue for this.
133170
# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
134171
# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.

test/test_graphs.jl

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -120,76 +120,6 @@ end
120120

121121
@test expr_eql(ex, true_ex)
122122
end
123-
124-
@testset "@memoize_on" begin
125-
ex = @macroexpand DynamicExpressions.UtilsModule.@memoize_on tree ((x, _) -> x) function _copy_node(
126-
tree::Node{T}
127-
)::Node{T} where {T}
128-
if tree.degree == 0
129-
if tree.constant
130-
Node(; val=copy(tree.val))
131-
else
132-
Node(T; feature=copy(tree.feature))
133-
end
134-
elseif tree.degree == 1
135-
Node(copy(tree.op), _copy_node(tree.l))
136-
else
137-
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
138-
end
139-
end
140-
true_ex = quote
141-
function _copy_node(tree::Node{T})::Node{T} where {T}
142-
if tree.degree == 0
143-
if tree.constant
144-
Node(; val=copy(tree.val))
145-
else
146-
Node(T; feature=copy(tree.feature))
147-
end
148-
elseif tree.degree == 1
149-
Node(copy(tree.op), _copy_node(tree.l))
150-
else
151-
Node(copy(tree.op), _copy_node(tree.l), _copy_node(tree.r))
152-
end
153-
end
154-
function _copy_node(tree::Node{T}, id_map::AbstractDict;)::Node{T} where {T}
155-
key = objectid(tree)
156-
is_memoized = haskey(id_map, key)
157-
function body()
158-
return begin
159-
if tree.degree == 0
160-
if tree.constant
161-
Node(; val=copy(tree.val))
162-
else
163-
Node(T; feature=copy(tree.feature))
164-
end
165-
elseif tree.degree == 1
166-
Node(copy(tree.op), _copy_node(tree.l, id_map))
167-
else
168-
Node(
169-
copy(tree.op),
170-
_copy_node(tree.l, id_map),
171-
_copy_node(tree.r, id_map),
172-
)
173-
end
174-
end
175-
end
176-
result = if is_memoized
177-
begin
178-
$(Expr(:inbounds, true))
179-
local val = id_map[key]
180-
$(Expr(:inbounds, :pop))
181-
val
182-
end
183-
else
184-
id_map[key] = body()
185-
end
186-
return (((x, _) -> begin
187-
x
188-
end)(result, is_memoized))
189-
end
190-
end
191-
@test expr_eql(ex, true_ex)
192-
end
193123
end
194124

195125
@testset "Operations on graphs" begin

0 commit comments

Comments
 (0)