11using DynamicExpressions, BenchmarkTools, Random
22
33# Trigger extensions:
4- using LoopVectorization
5- using Bumper
6- using StrideArrays
7- using Zygote
4+ using LoopVectorization, Bumper, StrideArrays, Zygote
85
96if PACKAGE_VERSION < v " 0.14.0"
107 @eval using DynamicExpressions: Node as GraphNode
1815 @eval using DynamicExpressions. NodeUtilsModule: is_constant
1916end
2017
18+ if PACKAGE_VERSION < v " 0.18.6"
19+ @eval using DynamicExpressions:
20+ index_constants as index_constant_nodes,
21+ count_constants as count_constant_nodes,
22+ get_constants as get_scalar_constants,
23+ set_constants! as set_scalar_constants!
24+ end
25+
2126include (" ../test/tree_gen_utils.jl" )
2227
2328const SUITE = BenchmarkGroup ()
@@ -113,15 +118,16 @@ end
113118 PACKAGE_VERSION < v " 0.14.0" && return :(copy_node (t; preserve_sharing= preserve_sharing))
114119 return :(copy_node (t)) # Assume type used to infer sharing
115120end
116- @generated function get_set_constants! (tree:: N ) where {T,N<: AbstractExpressionNode{T} }
117- if ! (@isdefined set_constants!)
118- return :(set_constants (tree, get_constants (tree)))
119- elseif hasmethod (set_constants!, Tuple{N, Vector{T}})
120- return :(set_constants! (tree, get_constants (tree)))
121+ @generated function get_set_constants! (tree:: N ) where {N}
122+ T = eltype (N)
123+ if ! (@isdefined set_scalar_constants!)
124+ return :(set_scalar_constants (tree, get_scalar_constants (tree)))
125+ elseif hasmethod (set_scalar_constants!, Tuple{N, Vector{T}})
126+ return :(set_scalar_constants! (tree, get_scalar_constants (tree)))
121127 else
122128 return quote
123- let (x, refs) = get_constants (tree)
124- set_constants ! (tree, x, refs)
129+ let (x, refs) = get_scalar_constants (tree)
130+ set_scalar_constants ! (tree, x, refs)
125131 end
126132 end
127133 end
@@ -141,12 +147,12 @@ function benchmark_utilities()
141147 :combine_operators ,
142148 :count_nodes ,
143149 :count_depth ,
144- :count_constants ,
150+ :count_constant_nodes ,
145151 :has_constants ,
146152 :has_operators ,
147153 :is_constant ,
148154 :get_set_constants! ,
149- :index_constants ,
155+ :index_constant_nodes ,
150156 :string_tree ,
151157 :hash ,
152158 )
@@ -157,9 +163,9 @@ function benchmark_utilities()
157163 [
158164 :simplify_tree ,
159165 :count_nodes ,
160- :count_constants ,
166+ :count_constant_nodes ,
161167 :get_set_constants! ,
162- :index_constants ,
168+ :index_constant_nodes ,
163169 :string_tree ,
164170 ],
165171 )
@@ -207,7 +213,8 @@ function benchmark_utilities()
207213 setup= (
208214 ntrees= 100 ;
209215 n= 20 ;
210- trees= [$ preprocess (gen_random_tree_fixed_size (n, $ operators, 5 , Float32)) for _ in 1 : ntrees]
216+ rng= Random. MersenneTwister (0 );
217+ trees= [$ preprocess (gen_random_tree_fixed_size (n, $ operators, 5 , Float32, Node, rng)) for _ in 1 : ntrees]
211218 )
212219 )
213220 # ! format: on
@@ -216,6 +223,37 @@ function benchmark_utilities()
216223 end
217224 end
218225
226+ # Additional methods
227+ @static if PACKAGE_VERSION >= v " 0.18.0"
228+ suite[" get_set_constants_parametric" ] = @benchmarkable (
229+ [get_set_constants! (ex) for ex in exs],
230+ seconds = 10.0 ,
231+ setup = (
232+ operators = $ operators;
233+ ntrees = 100 ;
234+ n = 20 ;
235+ n_features = 5 ;
236+ n_params = 3 ;
237+ n_param_classes = 10 ;
238+ rng = Random. MersenneTwister (0 );
239+ exs = [
240+ let tree = gen_random_tree_fixed_size (
241+ n, operators, n_features, Float32, ParametricNode, rng
242+ )
243+ ex = ParametricExpression (
244+ tree;
245+ operators,
246+ variable_names= map (i -> " x$i " , 1 : n_features),
247+ parameters= randn (rng, Float32, n_params, n_param_classes),
248+ parameter_names= map (i -> " p$i " , 1 : n_params),
249+ )
250+ ex
251+ end for _ in 1 : ntrees
252+ ]
253+ )
254+ )
255+ end
256+
219257 return suite
220258end
221259
0 commit comments