@@ -25,7 +25,7 @@ import Base:
2525
2626using DispatchDoctor: @unstable
2727using Compat: @inline , Returns
28- using .. UtilsModule: @memoize_on , @with_memoize , Undefined
28+ using .. UtilsModule: Undefined
2929
3030"""
3131 tree_mapreduce(
@@ -89,46 +89,83 @@ function tree_mapreduce(
8989 f_leaf:: F1 ,
9090 f_branch:: F2 ,
9191 op:: G ,
92- tree:: AbstractNode ,
92+ tree:: AbstractNode{D} ,
9393 result_type:: Type{RT} = Undefined;
9494 f_on_shared:: H = (result, is_shared) -> result,
95- break_sharing:: Val = Val (false ),
96- ) where {F1<: Function ,F2<: Function ,G<: Function ,H<: Function ,RT}
97-
98- # Trick taken from here:
99- # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
100- # to speed up recursive closure
101- @memoize_on t f_on_shared function inner (inner, t)
102- if t. degree == 0
103- return @inline (f_leaf (t))
104- elseif t. degree == 1
105- return @inline (op (@inline (f_branch (t)), inner (inner, t. l)))
106- else
107- return @inline (op (@inline (f_branch (t)), inner (inner, t. l), inner (inner, t. r)))
108- end
109- end
110-
111- sharing = preserve_sharing (typeof (tree)) && break_sharing === Val (false )
95+ break_sharing:: Val{BS} = Val (false ),
96+ ) where {F1<: Function ,F2<: Function ,G<: Function ,D,H<: Function ,RT,BS}
97+ sharing = preserve_sharing (typeof (tree)) && ! break_sharing
11298
11399 RT == Undefined &&
114100 sharing &&
115101 throw (ArgumentError (" Need to specify `result_type` if nodes are shared.." ))
116102
117103 if sharing && RT != Undefined
118- d = allocate_id_map (tree, RT)
119- return @with_memoize inner (inner, tree) d
104+ id_map = allocate_id_map (tree, RT)
105+ reducer = TreeMapreducer (Val (D), id_map, f_leaf, f_branch, op, f_on_shared)
106+ return reducer (tree)
107+ else
108+ reducer = TreeMapreducer (Val (D), nothing , f_leaf, f_branch, op, f_on_shared)
109+ return reducer (tree)
110+ end
111+ end
112+
113+ struct TreeMapreducer{D,ID,F1<: Function ,F2<: Function ,G<: Function ,H<: Function }
114+ max_degree:: Val{D}
115+ id_map:: ID
116+ f_leaf:: F1
117+ f_branch:: F2
118+ op:: G
119+ f_on_shared:: H
120+ end
121+
122+ @generated function (mapreducer:: TreeMapreducer{MAX_DEGREE,ID} )(
123+ tree:: AbstractNode
124+ ) where {MAX_DEGREE,ID}
125+ base_expr = quote
126+ d = tree. degree
127+ Base. Cartesian. @nif (
128+ $ (MAX_DEGREE + 1 ),
129+ d_p_one -> (d_p_one - 1 ) == d,
130+ d_p_one -> if d_p_one == 1
131+ mapreducer. f_leaf (tree)
132+ else
133+ mapreducer. op (
134+ mapreducer. f_branch (tree),
135+ Base. Cartesian. @ntuple (
136+ d_p_one - 1 , i -> mapreducer (tree. children[i][])
137+ ). .. ,
138+ )
139+ end
140+ )
141+ end
142+ if ID <: Nothing
143+ # No sharing of nodes (is a tree, not a graph)
144+ return base_expr
120145 else
121- return inner (inner, tree)
146+ # Otherwise, we need to cache results in `id_map`
147+ # according to `objectid` of the node
148+ return quote
149+ key = objectid (tree)
150+ is_cached = haskey (mapreducer. id_map, key)
151+ if is_cached
152+ return mapreducer. f_on_shared (@inbounds (mapreducer. id_map[key]), true )
153+ else
154+ res = $ base_expr
155+ mapreducer. id_map[key] = res
156+ return mapreducer. f_on_shared (res, false )
157+ end
158+ end
122159 end
123160end
161+
124162function allocate_id_map (tree:: AbstractNode , :: Type{RT} ) where {RT}
125163 d = Dict {UInt,RT} ()
126164 # Preallocate maximum storage (counting with duplicates is fast)
127165 N = length (tree; break_sharing= Val (true ))
128166 sizehint! (d, N)
129167 return d
130168end
131-
132169# TODO : Raise Julia issue for this.
133170# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
134171# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.
0 commit comments