Skip to content

Commit 56577b8

Browse files
committed
feat: towards proper buffering of evals
1 parent 4a52c37 commit 56577b8

File tree

1 file changed

+65
-21
lines changed

1 file changed

+65
-21
lines changed

src/Evaluate.jl

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
1515
macro return_on_nonfinite_val(eval_options, val, X)
1616
:(
1717
if $(esc(eval_options)).early_exit isa Val{true} && !is_valid($(esc(val)))
18-
return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false)
18+
return $(ResultOk)(_similar($(esc(X)), $(esc(eval_options)), axes($(esc(X)), 2)), false)
1919
end
2020
)
2121
end
@@ -46,10 +46,12 @@ This holds options for expression evaluation, such as evaluation backend.
4646
Setting `Val{false}` will continue the computation as usual and thus result in
4747
`NaN`s only in the elements that actually have `NaN`s.
4848
"""
49-
struct EvalOptions{T,B,E}
49+
struct EvalOptions{T,B,E,A,R}
5050
turbo::Val{T}
5151
bumper::Val{B}
5252
early_exit::Val{E}
53+
buffer::A
54+
buffer_ref::R
5355
end
5456

5557
@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
@@ -59,8 +61,16 @@ end
5961
turbo::Union{Bool,Val}=Val(false),
6062
bumper::Union{Bool,Val}=Val(false),
6163
early_exit::Union{Bool,Val}=Val(true),
64+
buffer::Union{AbstractMatrix,Nothing}=nothing,
65+
buffer_ref::Union{Base.RefValue{<:Integer},Nothing}=nothing,
6266
)
63-
return EvalOptions(_to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit))
67+
return EvalOptions(
68+
_to_bool_val(turbo),
69+
_to_bool_val(bumper),
70+
_to_bool_val(early_exit),
71+
buffer,
72+
buffer_ref,
73+
)
6474
end
6575

6676
@unstable function _process_deprecated_kws(eval_options, deprecated_kws)
@@ -193,12 +203,12 @@ function _eval_tree_array(
193203
# First, we see if there are only constants in the tree - meaning
194204
# we can just return the constant result.
195205
if tree.degree == 0
196-
return deg0_eval(tree, cX)
206+
return deg0_eval(tree, cX, eval_options)
197207
elseif is_constant(tree)
198208
# Speed hack for constant trees.
199209
const_result = dispatch_constant_tree(tree, operators)::ResultOk{Vector{T}}
200-
!const_result.ok && return ResultOk(similar(cX, axes(cX, 2)), false)
201-
return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true)
210+
!const_result.ok && return ResultOk(_similar(cX, eval_options, axes(cX, 2)), false)
211+
return ResultOk(_fill_similar(const_result.x[], cX, eval_options, axes(cX, 2)), true)
202212
elseif tree.degree == 1
203213
op_idx = tree.op
204214
return dispatch_deg1_eval(tree, cX, op_idx, operators, eval_options)
@@ -234,12 +244,46 @@ function deg1_eval(
234244
end
235245

236246
function deg0_eval(
237-
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}
247+
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, eval_options::EvalOptions
238248
)::ResultOk where {T}
239249
if tree.constant
240-
return ResultOk(fill_similar(tree.val, cX, axes(cX, 2)), true)
250+
return ResultOk(_fill_similar(tree.val, cX, eval_options, axes(cX, 2)), true)
251+
else
252+
return ResultOk(_index_X(cX, tree.feature, eval_options), true)
253+
end
254+
end
255+
256+
function _fill_similar(value, array, eval_options::EvalOptions, args...)
257+
if eval_options.buffer === nothing
258+
return fill_similar(value, array, args...)
259+
else
260+
# TODO HACK: Treat `axes` here explicitly!
261+
i = eval_options.buffer_ref[]
262+
out = @inbounds(@view(eval_options.buffer[i, :]))
263+
out .= value
264+
eval_options.buffer_ref[] = i + 1
265+
return out
266+
end
267+
end
268+
function _similar(X, eval_options::EvalOptions, args...)
269+
if eval_options.buffer === nothing
270+
return similar(X, args...)
271+
else
272+
i = eval_options.buffer_ref[]
273+
out = @inbounds(@view(eval_options.buffer[i, args...]))
274+
eval_options.buffer_ref[] = i + 1
275+
return out
276+
end
277+
end
278+
function _index_X(X, feature, eval_options::EvalOptions)
279+
if eval_options.buffer === nothing
280+
return X[feature, :]
241281
else
242-
return ResultOk(cX[tree.feature, :], true)
282+
i = eval_options.buffer_ref[]
283+
out = @inbounds(@view(eval_options.buffer[i, :]))
284+
eval_options.buffer_ref[] = i + 1
285+
out .= X[feature, :]
286+
return out
243287
end
244288
end
245289

@@ -401,12 +445,12 @@ function deg1_l2_ll0_lr0_eval(
401445
@return_on_nonfinite_val(eval_options, x_l, cX)
402446
x = op(x_l)::T
403447
@return_on_nonfinite_val(eval_options, x, cX)
404-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
448+
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
405449
elseif tree.l.l.constant
406450
val_ll = tree.l.l.val
407451
@return_on_nonfinite_val(eval_options, val_ll, cX)
408452
feature_lr = tree.l.r.feature
409-
cumulator = similar(cX, axes(cX, 2))
453+
cumulator = _similar(cX, eval_options, axes(cX, 2))
410454
@inbounds @simd for j in axes(cX, 2)
411455
x_l = op_l(val_ll, cX[feature_lr, j])::T
412456
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -417,7 +461,7 @@ function deg1_l2_ll0_lr0_eval(
417461
feature_ll = tree.l.l.feature
418462
val_lr = tree.l.r.val
419463
@return_on_nonfinite_val(eval_options, val_lr, cX)
420-
cumulator = similar(cX, axes(cX, 2))
464+
cumulator = _similar(cX, eval_options, axes(cX, 2))
421465
@inbounds @simd for j in axes(cX, 2)
422466
x_l = op_l(cX[feature_ll, j], val_lr)::T
423467
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -427,7 +471,7 @@ function deg1_l2_ll0_lr0_eval(
427471
else
428472
feature_ll = tree.l.l.feature
429473
feature_lr = tree.l.r.feature
430-
cumulator = similar(cX, axes(cX, 2))
474+
cumulator = _similar(cX, eval_options, axes(cX, 2))
431475
@inbounds @simd for j in axes(cX, 2)
432476
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
433477
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -452,10 +496,10 @@ function deg1_l1_ll0_eval(
452496
@return_on_nonfinite_val(eval_options, x_l, cX)
453497
x = op(x_l)::T
454498
@return_on_nonfinite_val(eval_options, x, cX)
455-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
499+
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
456500
else
457501
feature_ll = tree.l.l.feature
458-
cumulator = similar(cX, axes(cX, 2))
502+
cumulator = _similar(cX, eval_options, axes(cX, 2))
459503
@inbounds @simd for j in axes(cX, 2)
460504
x_l = op_l(cX[feature_ll, j])::T
461505
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
@@ -479,9 +523,9 @@ function deg2_l0_r0_eval(
479523
@return_on_nonfinite_val(eval_options, val_r, cX)
480524
x = op(val_l, val_r)::T
481525
@return_on_nonfinite_val(eval_options, x, cX)
482-
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
526+
return ResultOk(_fill_similar(x, cX, eval_options, axes(cX, 2)), true)
483527
elseif tree.l.constant
484-
cumulator = similar(cX, axes(cX, 2))
528+
cumulator = _similar(cX, eval_options, axes(cX, 2))
485529
val_l = tree.l.val
486530
@return_on_nonfinite_val(eval_options, val_l, cX)
487531
feature_r = tree.r.feature
@@ -491,7 +535,7 @@ function deg2_l0_r0_eval(
491535
end
492536
return ResultOk(cumulator, true)
493537
elseif tree.r.constant
494-
cumulator = similar(cX, axes(cX, 2))
538+
cumulator = _similar(cX, eval_options, axes(cX, 2))
495539
feature_l = tree.l.feature
496540
val_r = tree.r.val
497541
@return_on_nonfinite_val(eval_options, val_r, cX)
@@ -501,7 +545,7 @@ function deg2_l0_r0_eval(
501545
end
502546
return ResultOk(cumulator, true)
503547
else
504-
cumulator = similar(cX, axes(cX, 2))
548+
cumulator = _similar(cX, eval_options, axes(cX, 2))
505549
feature_l = tree.l.feature
506550
feature_r = tree.r.feature
507551
@inbounds @simd for j in axes(cX, 2)
@@ -664,9 +708,9 @@ end
664708
quote
665709
if tree.degree == 0
666710
if tree.constant
667-
ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true)
711+
ResultOk(_fill_similar(one(T), cX, eval_options, axes(cX, 2)) .* tree.val, true)
668712
else
669-
ResultOk(cX[tree.feature, :], true)
713+
ResultOk(_index_X(cX, tree.feature, eval_options), true)
670714
end
671715
elseif tree.degree == 1
672716
op_idx = tree.op

0 commit comments

Comments
 (0)