Skip to content

Commit fffaaee

Browse files
committed
feat: compat with turbo mode and buffered evals
1 parent 64956de commit fffaaee

File tree

4 files changed

+23
-19
lines changed

4 files changed

+23
-19
lines changed

ext/DynamicExpressionsLoopVectorizationExt.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module DynamicExpressionsLoopVectorizationExt
22

33
using LoopVectorization: @turbo
44
using DynamicExpressions: AbstractExpressionNode
5-
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
6-
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
5+
using DynamicExpressions.UtilsModule: ResultOk
6+
using DynamicExpressions.EvaluateModule:
7+
@return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array
78
import DynamicExpressions.EvaluateModule:
89
deg1_eval,
910
deg2_eval,
@@ -56,12 +57,12 @@ function deg1_l2_ll0_lr0_eval(
5657
@return_on_nonfinite_val(eval_options, x_l, cX)
5758
x = op(x_l)::T
5859
@return_on_nonfinite_val(eval_options, x, cX)
59-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
60+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)))
6061
elseif tree.l.l.constant
6162
val_ll = tree.l.l.val
6263
@return_on_nonfinite_val(eval_options, val_ll, cX)
6364
feature_lr = tree.l.r.feature
64-
cumulator = similar(cX, axes(cX, 2))
65+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
6566
@turbo for j in axes(cX, 2)
6667
x_l = op_l(val_ll, cX[feature_lr, j])
6768
x = op(x_l)
@@ -72,7 +73,7 @@ function deg1_l2_ll0_lr0_eval(
7273
feature_ll = tree.l.l.feature
7374
val_lr = tree.l.r.val
7475
@return_on_nonfinite_val(eval_options, val_lr, cX)
75-
cumulator = similar(cX, axes(cX, 2))
76+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
7677
@turbo for j in axes(cX, 2)
7778
x_l = op_l(cX[feature_ll, j], val_lr)
7879
x = op(x_l)
@@ -82,7 +83,7 @@ function deg1_l2_ll0_lr0_eval(
8283
else
8384
feature_ll = tree.l.l.feature
8485
feature_lr = tree.l.r.feature
85-
cumulator = similar(cX, axes(cX, 2))
86+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
8687
@turbo for j in axes(cX, 2)
8788
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
8889
x = op(x_l)
@@ -106,10 +107,10 @@ function deg1_l1_ll0_eval(
106107
@return_on_nonfinite_val(eval_options, x_l, cX)
107108
x = op(x_l)::T
108109
@return_on_nonfinite_val(eval_options, x, cX)
109-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
110+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)))
110111
else
111112
feature_ll = tree.l.l.feature
112-
cumulator = similar(cX, axes(cX, 2))
113+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
113114
@turbo for j in axes(cX, 2)
114115
x_l = op_l(cX[feature_ll, j])
115116
x = op(x_l)
@@ -132,9 +133,9 @@ function deg2_l0_r0_eval(
132133
@return_on_nonfinite_val(eval_options, val_r, cX)
133134
x = op(val_l, val_r)::T
134135
@return_on_nonfinite_val(eval_options, x, cX)
135-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
136+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)))
136137
elseif tree.l.constant
137-
cumulator = similar(cX, axes(cX, 2))
138+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
138139
val_l = tree.l.val
139140
@return_on_nonfinite_val(eval_options, val_l, cX)
140141
feature_r = tree.r.feature
@@ -144,7 +145,7 @@ function deg2_l0_r0_eval(
144145
end
145146
return ResultOk(cumulator, true)
146147
elseif tree.r.constant
147-
cumulator = similar(cX, axes(cX, 2))
148+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
148149
feature_l = tree.l.feature
149150
val_r = tree.r.val
150151
@return_on_nonfinite_val(eval_options, val_r, cX)
@@ -154,7 +155,7 @@ function deg2_l0_r0_eval(
154155
end
155156
return ResultOk(cumulator, true)
156157
else
157-
cumulator = similar(cX, axes(cX, 2))
158+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
158159
feature_l = tree.l.feature
159160
feature_r = tree.r.feature
160161
@turbo for j in axes(cX, 2)

src/Evaluate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ end
106106
v_bumper = _to_bool_val(bumper)
107107
v_early_exit = _to_bool_val(early_exit)
108108

109-
if v_turbo isa Val{true} || v_bumper isa Val{true}
109+
if v_bumper isa Val{true}
110110
@assert buffer === nothing && buffer_ref === nothing
111111
end
112112

src/EvaluateDerivative.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import ..OperatorEnumModule: OperatorEnum
55
import ..UtilsModule: fill_similar, ResultOk2
66
import ..ValueInterfaceModule: is_valid_array
77
import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, NodeIndex
8-
import ..EvaluateModule: deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
8+
import ..EvaluateModule:
9+
deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN, EvalOptions
910
import ..ExtensionInterfaceModule: _zygote_gradient
1011

1112
"""
@@ -120,7 +121,7 @@ end
120121
function diff_deg0_eval(
121122
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, direction::Integer
122123
) where {T<:Number}
123-
const_part = deg0_eval(tree, cX).x
124+
const_part = deg0_eval(tree, cX, EvalOptions()).x
124125
derivative_part = if ((!tree.constant) && tree.feature == direction)
125126
fill_similar(one(T), cX, axes(cX, 2))
126127
else
@@ -335,7 +336,7 @@ function grad_deg0_eval(
335336
cX::AbstractMatrix{T},
336337
::Val{mode},
337338
)::ResultOk2 where {T<:Number,mode}
338-
const_part = deg0_eval(tree, cX).x
339+
const_part = deg0_eval(tree, cX, EvalOptions()).x
339340

340341
zero_mat = if isa(cX, Array)
341342
fill_similar(zero(T), cX, n_gradients, axes(cX, 2))

test/test_buffered_evaluation.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ end
119119
@testitem "Random tree buffer evaluation" begin
120120
using DynamicExpressions
121121
using Random
122+
using LoopVectorization
122123
include("tree_gen_utils.jl")
123124

124125
# Test setup
@@ -127,18 +128,19 @@ end
127128
binary_operators=[+, -, *, /], unary_operators=[sin, cos, exp]
128129
)
129130

130-
for i in 1:100
131+
for turbo in (false, true), i in 1:100
131132
# Generate a random tree with varying size (1-10 nodes)
132133
n_nodes = rand(1:10)
133134
tree = gen_random_tree_fixed_size(n_nodes, operators, size(X, 1), Float64, Node)
134135

135136
# Regular evaluation
136-
result1, ok1 = eval_tree_array(tree, X, operators)
137+
eval_options_no_buffer = EvalOptions(; turbo)
138+
result1, ok1 = eval_tree_array(tree, X, operators; eval_options=eval_options_no_buffer)
137139

138140
# Buffer evaluation
139141
buffer = Array{Float64}(undef, 2n_nodes, size(X, 2))
140142
buffer_ref = Ref(rand(1:10)) # Random starting index (will be reset)
141-
eval_options = EvalOptions(; buffer=buffer, buffer_ref=buffer_ref)
143+
eval_options = EvalOptions(; turbo, buffer, buffer_ref)
142144
result2, ok2 = eval_tree_array(tree, X, operators; eval_options)
143145

144146
# Results should be identical

0 commit comments

Comments
 (0)