Skip to content

Commit d9d05e2

Browse files
committed
feat: require user to pass ArrayBuffer object explicitly
1 parent e6067d1 commit d9d05e2

File tree

3 files changed

+19
-19
lines changed

3 files changed

+19
-19
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ import .StringsModule: get_op_name
7373
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
7474
@reexport import .EvaluateModule:
7575
eval_tree_array, differentiable_eval_tree_array, EvalOptions
76+
import .EvaluateModule: ArrayBuffer
7677
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
7778
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
7879
@reexport import .SimplifyModule: combine_operators, simplify_tree!

src/Evaluate.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function get_feature_array(buffer::ArrayBuffer, X::AbstractMatrix, feature::Inte
7070
end
7171

7272
"""
73-
EvalOptions{T,B,E}
73+
EvalOptions
7474
7575
This holds options for expression evaluation, such as evaluation backend.
7676
@@ -86,8 +86,10 @@ This holds options for expression evaluation, such as evaluation backend.
8686
the entire buffer. This early exit is performed to avoid wasting compute cycles.
8787
Setting `Val{false}` will continue the computation as usual and thus result in
8888
`NaN`s only in the elements that actually have `NaN`s.
89+
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
90+
This should be an instance of `ArrayBuffer` which has an `array` field and an
91+
`index` field used to iterate which buffer slot to use.
8992
"""
90-
9193
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
9294
turbo::Val{T}
9395
bumper::Val{B}
@@ -99,24 +101,17 @@ end
99101
turbo::Union{Bool,Val}=Val(false),
100102
bumper::Union{Bool,Val}=Val(false),
101103
early_exit::Union{Bool,Val}=Val(true),
102-
buffer::Union{AbstractMatrix,Nothing}=nothing,
103-
buffer_ref::Union{Base.RefValue{<:Integer},Nothing}=nothing,
104+
buffer::Union{ArrayBuffer,Nothing}=nothing,
104105
)
105106
v_turbo = _to_bool_val(turbo)
106107
v_bumper = _to_bool_val(bumper)
107108
v_early_exit = _to_bool_val(early_exit)
108109

109110
if v_bumper isa Val{true}
110-
@assert buffer === nothing && buffer_ref === nothing
111-
end
112-
113-
array_buffer = if buffer === nothing
114-
nothing
115-
else
116-
ArrayBuffer(buffer, buffer_ref)
111+
@assert buffer === nothing
117112
end
118113

119-
return EvalOptions(v_turbo, v_bumper, v_early_exit, array_buffer)
114+
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer)
120115
end
121116

122117
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)

test/test_buffered_evaluation.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@testitem "Buffer creation and validation" begin
22
using DynamicExpressions
3+
using DynamicExpressions: ArrayBuffer
34

45
# Test data setup
56
X = rand(2, 10) # 2 features, 10 samples
@@ -11,13 +12,13 @@
1112
# Basic buffer creation - buffer shape should match (num_leafs, num_samples)
1213
buffer = zeros(5, size(X, 2)) # 5 leaves should be enough for our test tree
1314
buffer_ref = Ref(0)
14-
eval_options = EvalOptions(; buffer=buffer, buffer_ref=buffer_ref)
15+
eval_options = EvalOptions(; buffer=ArrayBuffer(buffer, buffer_ref))
1516
@test eval_options.buffer.array === buffer
1617
@test eval_options.buffer.index === buffer_ref
1718

1819
# Test buffer is not allowed with bumper
1920
@test_throws AssertionError EvalOptions(;
20-
bumper=true, buffer=buffer, buffer_ref=buffer_ref
21+
bumper=true, buffer=ArrayBuffer(buffer, buffer_ref)
2122
)
2223

2324
# Basic evaluation should work
@@ -28,6 +29,7 @@ end
2829

2930
@testitem "Buffer correctness" begin
3031
using DynamicExpressions
32+
using DynamicExpressions: ArrayBuffer
3133

3234
X = rand(2, 10)
3335
operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin])
@@ -45,7 +47,7 @@ end
4547
# Evaluation with buffer
4648
buffer = zeros(5, size(X, 2))
4749
buffer_ref = Ref(0)
48-
eval_options = EvalOptions(; buffer=buffer, buffer_ref=buffer_ref)
50+
eval_options = EvalOptions(; buffer=ArrayBuffer(buffer, buffer_ref))
4951
result2, ok2 = eval_tree_array(tree, X, operators; eval_options)
5052

5153
# Results should be identical
@@ -56,7 +58,7 @@ end
5658

5759
@testitem "Buffer index management" begin
5860
using DynamicExpressions
59-
61+
using DynamicExpressions: ArrayBuffer
6062
X = rand(2, 10)
6163
operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[sin])
6264

@@ -70,7 +72,7 @@ end
7072
# This tree needs more buffer space due to intermediate computations
7173
buffer = zeros(10, size(X, 2))
7274
buffer_ref = Ref(0)
73-
eval_options = EvalOptions(; buffer=buffer, buffer_ref=buffer_ref)
75+
eval_options = EvalOptions(; buffer=ArrayBuffer(buffer, buffer_ref))
7476

7577
# Index should start at 1
7678
@test buffer_ref[] == 0
@@ -93,6 +95,7 @@ end
9395

9496
@testitem "Buffer error handling" begin
9597
using DynamicExpressions
98+
using DynamicExpressions: ArrayBuffer
9699

97100
X = rand(2, 10)
98101
operators = OperatorEnum(; binary_operators=[+, /, *], unary_operators=[sin])
@@ -106,7 +109,7 @@ end
106109

107110
buffer = zeros(5, size(X, 2))
108111
buffer_ref = Ref(0)
109-
eval_options = EvalOptions(; buffer=buffer, buffer_ref=buffer_ref)
112+
eval_options = EvalOptions(; buffer=ArrayBuffer(buffer, buffer_ref))
110113

111114
# Test with early_exit=true
112115
result1, ok1 = eval_tree_array(tree, X, operators; eval_options)
@@ -115,6 +118,7 @@ end
115118

116119
@testitem "Random tree buffer evaluation" begin
117120
using DynamicExpressions
121+
using DynamicExpressions: ArrayBuffer
118122
using Random
119123
using LoopVectorization
120124
include("tree_gen_utils.jl")
@@ -139,7 +143,7 @@ end
139143
# Buffer evaluation
140144
buffer = Array{Float64}(undef, 2n_nodes, size(X, 2))
141145
buffer_ref = Ref(rand(1:10)) # Random starting index (will be reset)
142-
eval_options = EvalOptions(; turbo, buffer, buffer_ref)
146+
eval_options = EvalOptions(; turbo, buffer=ArrayBuffer(buffer, buffer_ref))
143147
result2, ok2 = eval_tree_array(tree, X, operators; eval_options)
144148

145149
# Results should be identical

0 commit comments

Comments
 (0)