@@ -24,8 +24,7 @@ import Base:
2424 sum
2525
2626using DispatchDoctor: @unstable
27- using Compat: @inline , Returns
28- using .. UtilsModule: @memoize_on , @with_memoize , Undefined
27+ using .. UtilsModule: Undefined
2928
3029"""
3130 tree_mapreduce(
@@ -94,41 +93,66 @@ function tree_mapreduce(
9493 f_on_shared:: H = (result, is_shared) -> result,
9594 break_sharing:: Val{BS} = Val (false ),
9695) where {F1<: Function ,F2<: Function ,G<: Function ,H<: Function ,RT,BS}
97-
98- # Trick taken from here:
99- # https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
100- # to speed up recursive closure
101- @memoize_on t f_on_shared function inner (inner, t)
102- if t. degree == 0
103- return @inline (f_leaf (t))
104- elseif t. degree == 1
105- return @inline (op (@inline (f_branch (t)), inner (inner, t. l)))
106- else
107- return @inline (op (@inline (f_branch (t)), inner (inner, t. l), inner (inner, t. r)))
108- end
109- end
110-
11196 sharing = preserve_sharing (typeof (tree)) && ! BS
11297
11398 RT == Undefined &&
11499 sharing &&
115100 throw (ArgumentError (" Need to specify `result_type` if nodes are shared.." ))
116101
117102 if sharing && RT != Undefined
118- d = allocate_id_map (tree, RT)
119- return @with_memoize inner (inner, tree) d
103+ id_map = allocate_id_map (tree, RT)
104+ reducer = TreeMapreducer (Val (2 ), id_map, f_leaf, f_branch, op, f_on_shared)
105+ return call_mapreducer (reducer, tree)
106+ else
107+ reducer = TreeMapreducer (Val (2 ), nothing , f_leaf, f_branch, op, f_on_shared)
108+ return call_mapreducer (reducer, tree)
109+ end
110+ end
111+
112+ struct TreeMapreducer{
113+ D,ID<: Union{Nothing,Dict} ,F1<: Function ,F2<: Function ,G<: Function ,H<: Function
114+ }
115+ max_degree:: Val{D}
116+ id_map:: ID
117+ f_leaf:: F1
118+ f_branch:: F2
119+ op:: G
120+ f_on_shared:: H
121+ end
122+
123+ function call_mapreducer (mapreducer:: TreeMapreducer{2,ID} , tree:: AbstractNode ) where {ID}
124+ key = ID <: Dict ? objectid (tree) : nothing
125+ if ID <: Dict && haskey (mapreducer. id_map, key)
126+ result = @inbounds (mapreducer. id_map[key])
127+ return mapreducer. f_on_shared (result, true )
120128 else
121- return inner (inner, tree)
129+ result = if tree. degree == 0
130+ mapreducer. f_leaf (tree)
131+ elseif tree. degree == 1
132+ mapreducer. op (mapreducer. f_branch (tree), call_mapreducer (mapreducer, tree. l))
133+ else
134+ mapreducer. op (
135+ mapreducer. f_branch (tree),
136+ call_mapreducer (mapreducer, tree. l),
137+ call_mapreducer (mapreducer, tree. r),
138+ )
139+ end
140+ if ID <: Dict
141+ mapreducer. id_map[key] = result
142+ return mapreducer. f_on_shared (result, false )
143+ else
144+ return result
145+ end
122146 end
123147end
148+
124149function allocate_id_map (tree:: AbstractNode , :: Type{RT} ) where {RT}
125150 d = Dict {UInt,RT} ()
126151 # Preallocate maximum storage (counting with duplicates is fast)
127152 N = length (tree; break_sharing= Val (true ))
128153 sizehint! (d, N)
129154 return d
130155end
131-
132156# TODO : Raise Julia issue for this.
133157# Surprisingly Dict{UInt,RT} is faster than IdDict{Node{T},RT} here!
134158# I think it's because `setindex!` is declared with `@nospecialize` in IdDict.
0 commit comments