Skip to content

Commit 15d9efb

Browse files
authored
Merge pull request #112 from SymbolicML/buffered-evals
Zero-allocation tree evaluation with buffer
2 parents 4a52c37 + 0b5e4a8 commit 15d9efb

File tree

6 files changed

+267
-54
lines changed

6 files changed

+267
-54
lines changed

ext/DynamicExpressionsLoopVectorizationExt.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module DynamicExpressionsLoopVectorizationExt
22

33
using LoopVectorization: @turbo
44
using DynamicExpressions: AbstractExpressionNode
5-
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
6-
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
5+
using DynamicExpressions.UtilsModule: ResultOk
6+
using DynamicExpressions.EvaluateModule:
7+
@return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array
78
import DynamicExpressions.EvaluateModule:
89
deg1_eval,
910
deg2_eval,
@@ -56,12 +57,12 @@ function deg1_l2_ll0_lr0_eval(
5657
@return_on_nonfinite_val(eval_options, x_l, cX)
5758
x = op(x_l)::T
5859
@return_on_nonfinite_val(eval_options, x, cX)
59-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
60+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
6061
elseif tree.l.l.constant
6162
val_ll = tree.l.l.val
6263
@return_on_nonfinite_val(eval_options, val_ll, cX)
6364
feature_lr = tree.l.r.feature
64-
cumulator = similar(cX, axes(cX, 2))
65+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
6566
@turbo for j in axes(cX, 2)
6667
x_l = op_l(val_ll, cX[feature_lr, j])
6768
x = op(x_l)
@@ -72,7 +73,7 @@ function deg1_l2_ll0_lr0_eval(
7273
feature_ll = tree.l.l.feature
7374
val_lr = tree.l.r.val
7475
@return_on_nonfinite_val(eval_options, val_lr, cX)
75-
cumulator = similar(cX, axes(cX, 2))
76+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
7677
@turbo for j in axes(cX, 2)
7778
x_l = op_l(cX[feature_ll, j], val_lr)
7879
x = op(x_l)
@@ -82,7 +83,7 @@ function deg1_l2_ll0_lr0_eval(
8283
else
8384
feature_ll = tree.l.l.feature
8485
feature_lr = tree.l.r.feature
85-
cumulator = similar(cX, axes(cX, 2))
86+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
8687
@turbo for j in axes(cX, 2)
8788
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
8889
x = op(x_l)
@@ -106,10 +107,10 @@ function deg1_l1_ll0_eval(
106107
@return_on_nonfinite_val(eval_options, x_l, cX)
107108
x = op(x_l)::T
108109
@return_on_nonfinite_val(eval_options, x, cX)
109-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
110+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
110111
else
111112
feature_ll = tree.l.l.feature
112-
cumulator = similar(cX, axes(cX, 2))
113+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
113114
@turbo for j in axes(cX, 2)
114115
x_l = op_l(cX[feature_ll, j])
115116
x = op(x_l)
@@ -132,9 +133,9 @@ function deg2_l0_r0_eval(
132133
@return_on_nonfinite_val(eval_options, val_r, cX)
133134
x = op(val_l, val_r)::T
134135
@return_on_nonfinite_val(eval_options, x, cX)
135-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
136+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
136137
elseif tree.l.constant
137-
cumulator = similar(cX, axes(cX, 2))
138+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
138139
val_l = tree.l.val
139140
@return_on_nonfinite_val(eval_options, val_l, cX)
140141
feature_r = tree.r.feature
@@ -144,7 +145,7 @@ function deg2_l0_r0_eval(
144145
end
145146
return ResultOk(cumulator, true)
146147
elseif tree.r.constant
147-
cumulator = similar(cX, axes(cX, 2))
148+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
148149
feature_l = tree.l.feature
149150
val_r = tree.r.val
150151
@return_on_nonfinite_val(eval_options, val_r, cX)
@@ -154,7 +155,7 @@ function deg2_l0_r0_eval(
154155
end
155156
return ResultOk(cumulator, true)
156157
else
157-
cumulator = similar(cX, axes(cX, 2))
158+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
158159
feature_l = tree.l.feature
159160
feature_r = tree.r.feature
160161
@turbo for j in axes(cX, 2)

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ import .StringsModule: get_op_name
7373
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
7474
@reexport import .EvaluateModule:
7575
eval_tree_array, differentiable_eval_tree_array, EvalOptions
76+
import .EvaluateModule: ArrayBuffer
7677
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
7778
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
7879
@reexport import .SimplifyModule: combine_operators, simplify_tree!

0 commit comments

Comments
 (0)