Skip to content

Commit 6e41726

Browse files
committed
refactor: better design for buffered evals
1 parent a9d5a75 commit 6e41726

File tree

1 file changed

+71
-20
lines changed

1 file changed

+71
-20
lines changed

src/Evaluate.jl

Lines changed: 71 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,47 @@ macro return_on_nonfinite_array(eval_options, array)
3030
)
3131
end
3232

33+
"""Buffer management for array allocations during evaluation."""
34+
struct ArrayBuffer{A<:AbstractMatrix,R<:Base.RefValue{<:Integer}}
35+
array::A
36+
index::R
37+
end
38+
39+
reset_index!(buffer::ArrayBuffer) = buffer.index[] = 0
40+
reset_index!(::Nothing) = nothing
41+
42+
next_index!(buffer::ArrayBuffer) = buffer.index[] += 1
43+
44+
function get_array(::Nothing, template::AbstractArray, axes...)
45+
return similar(template, axes...)
46+
end
47+
48+
function get_array(buffer::ArrayBuffer, template::AbstractArray, axes...)
49+
i = next_index!(buffer)
50+
out = @view(buffer.array[i, :])
51+
return out
52+
end
53+
54+
function get_filled_array(::Nothing, value, template::AbstractArray, axes...)
55+
return fill_similar(value, template, axes...)
56+
end
57+
function get_filled_array(buffer::ArrayBuffer, value, template::AbstractArray, axes...)
58+
i = next_index!(buffer)
59+
out = @view(buffer.array[i, :])
60+
out .= value
61+
return out
62+
end
63+
64+
function get_feature_array(::Nothing, X::AbstractMatrix, feature::Integer)
65+
return X[feature, :]
66+
end
67+
function get_feature_array(buffer::ArrayBuffer, X::AbstractMatrix, feature::Integer)
68+
i = next_index!(buffer)
69+
out = @view(buffer.array[i, :])
70+
out .= X[feature, :]
71+
return out
72+
end
73+
3374
"""
3475
EvalOptions{T,B,E}
3576
@@ -48,17 +89,14 @@ This holds options for expression evaluation, such as evaluation backend.
4889
Setting `Val{false}` will continue the computation as usual and thus result in
4990
`NaN`s only in the elements that actually have `NaN`s.
5091
"""
51-
struct EvalOptions{T,B,E,A,R}
92+
93+
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
5294
turbo::Val{T}
5395
bumper::Val{B}
5496
early_exit::Val{E}
55-
buffer::A
56-
buffer_ref::R
97+
buffer::BUF
5798
end
5899

59-
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
60-
@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool)
61-
62100
@unstable function EvalOptions(;
63101
turbo::Union{Bool,Val}=Val(false),
64102
bumper::Union{Bool,Val}=Val(false),
@@ -69,12 +107,23 @@ end
69107
v_turbo = _to_bool_val(turbo)
70108
v_bumper = _to_bool_val(bumper)
71109
v_early_exit = _to_bool_val(early_exit)
110+
72111
if v_turbo isa Val{true} || v_bumper isa Val{true}
73112
@assert buffer === nothing && buffer_ref === nothing
74113
end
75-
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, buffer_ref)
114+
115+
array_buffer = if buffer === nothing
116+
nothing
117+
else
118+
ArrayBuffer(buffer, buffer_ref)
119+
end
120+
121+
return EvalOptions(v_turbo, v_bumper, v_early_exit, array_buffer)
76122
end
77123

124+
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
125+
@inline _to_bool_val(::Val{T}) where {T} = Val(T::Bool)
126+
78127
@unstable function _process_deprecated_kws(eval_options, deprecated_kws)
79128
turbo = get(deprecated_kws, :turbo, nothing)
80129
bumper = get(deprecated_kws, :bumper, nothing)
@@ -165,7 +214,7 @@ function eval_tree_array(
165214
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
166215
end
167216

168-
_reset_buffer_ref!(_eval_options)
217+
reset_index!(_eval_options.buffer)
169218

170219
result = _eval_tree_array(tree, cX, operators, _eval_options)
171220
return (
@@ -251,9 +300,11 @@ function deg0_eval(
251300
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, eval_options::EvalOptions
252301
)::ResultOk where {T}
253302
if tree.constant
254-
return ResultOk(_fill_similar(tree.val, cX, eval_options, axes(cX, 2)), true)
303+
return ResultOk(
304+
get_filled_array(eval_options.buffer, tree.val, cX, axes(cX, 2)), true
305+
)
255306
else
256-
return ResultOk(_index_X(cX, tree.feature, eval_options), true)
307+
return ResultOk(get_feature_array(eval_options.buffer, cX, tree.feature), true)
257308
end
258309
end
259310

@@ -455,12 +506,12 @@ function deg1_l2_ll0_lr0_eval(
455506
@return_on_nonfinite_val(eval_options, x_l, cX)
456507
x = op(x_l)::T
457508
@return_on_nonfinite_val(eval_options, x, cX)
458-
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
509+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
459510
elseif tree.l.l.constant
460511
val_ll = tree.l.l.val
461512
@return_on_nonfinite_val(eval_options, val_ll, cX)
462513
feature_lr = tree.l.r.feature
463-
cumulator = _similar(cX, eval_options, axes(cX, 2))
514+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
464515
@inbounds @simd for j in axes(cX, 2)
465516
x_l = op_l(val_ll, cX[feature_lr, j])::T
466517
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -471,7 +522,7 @@ function deg1_l2_ll0_lr0_eval(
471522
feature_ll = tree.l.l.feature
472523
val_lr = tree.l.r.val
473524
@return_on_nonfinite_val(eval_options, val_lr, cX)
474-
cumulator = _similar(cX, eval_options, axes(cX, 2))
525+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
475526
@inbounds @simd for j in axes(cX, 2)
476527
x_l = op_l(cX[feature_ll, j], val_lr)::T
477528
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -481,7 +532,7 @@ function deg1_l2_ll0_lr0_eval(
481532
else
482533
feature_ll = tree.l.l.feature
483534
feature_lr = tree.l.r.feature
484-
cumulator = _similar(cX, eval_options, axes(cX, 2))
535+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
485536
@inbounds @simd for j in axes(cX, 2)
486537
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
487538
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -506,10 +557,10 @@ function deg1_l1_ll0_eval(
506557
@return_on_nonfinite_val(eval_options, x_l, cX)
507558
x = op(x_l)::T
508559
@return_on_nonfinite_val(eval_options, x, cX)
509-
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
560+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
510561
else
511562
feature_ll = tree.l.l.feature
512-
cumulator = _similar(cX, eval_options, axes(cX, 2))
563+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
513564
@inbounds @simd for j in axes(cX, 2)
514565
x_l = op_l(cX[feature_ll, j])::T
515566
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -533,9 +584,9 @@ function deg2_l0_r0_eval(
533584
@return_on_nonfinite_val(eval_options, val_r, cX)
534585
x = op(val_l, val_r)::T
535586
@return_on_nonfinite_val(eval_options, x, cX)
536-
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
587+
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
537588
elseif tree.l.constant
538-
cumulator = _similar(cX, eval_options, axes(cX, 2))
589+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
539590
val_l = tree.l.val
540591
@return_on_nonfinite_val(eval_options, val_l, cX)
541592
feature_r = tree.r.feature
@@ -545,7 +596,7 @@ function deg2_l0_r0_eval(
545596
end
546597
return ResultOk(cumulator, true)
547598
elseif tree.r.constant
548-
cumulator = _similar(cX, eval_options, axes(cX, 2))
599+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
549600
feature_l = tree.l.feature
550601
val_r = tree.r.val
551602
@return_on_nonfinite_val(eval_options, val_r, cX)
@@ -555,7 +606,7 @@ function deg2_l0_r0_eval(
555606
end
556607
return ResultOk(cumulator, true)
557608
else
558-
cumulator = _similar(cX, eval_options, axes(cX, 2))
609+
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
559610
feature_l = tree.l.feature
560611
feature_r = tree.r.feature
561612
@inbounds @simd for j in axes(cX, 2)

0 commit comments

Comments
 (0)