@@ -30,6 +30,47 @@ macro return_on_nonfinite_array(eval_options, array)
3030 )
3131end
3232
33+ """ Buffer management for array allocations during evaluation."""
34+ struct ArrayBuffer{A<: AbstractMatrix ,R<: Base.RefValue{<:Integer} }
35+ array:: A
36+ index:: R
37+ end
38+
39+ reset_index! (buffer:: ArrayBuffer ) = buffer. index[] = 0
40+ reset_index! (:: Nothing ) = nothing
41+
42+ next_index! (buffer:: ArrayBuffer ) = buffer. index[] += 1
43+
44+ function get_array (:: Nothing , template:: AbstractArray , axes... )
45+ return similar (template, axes... )
46+ end
47+
48+ function get_array (buffer:: ArrayBuffer , template:: AbstractArray , axes... )
49+ i = next_index! (buffer)
50+ out = @view (buffer. array[i, :])
51+ return out
52+ end
53+
54+ function get_filled_array (:: Nothing , value, template:: AbstractArray , axes... )
55+ return fill_similar (value, template, axes... )
56+ end
57+ function get_filled_array (buffer:: ArrayBuffer , value, template:: AbstractArray , axes... )
58+ i = next_index! (buffer)
59+ out = @view (buffer. array[i, :])
60+ out .= value
61+ return out
62+ end
63+
64+ function get_feature_array (:: Nothing , X:: AbstractMatrix , feature:: Integer )
65+ return X[feature, :]
66+ end
67+ function get_feature_array (buffer:: ArrayBuffer , X:: AbstractMatrix , feature:: Integer )
68+ i = next_index! (buffer)
69+ out = @view (buffer. array[i, :])
70+ out .= X[feature, :]
71+ return out
72+ end
73+
3374"""
3475 EvalOptions{T,B,E}
3576
@@ -48,17 +89,14 @@ This holds options for expression evaluation, such as evaluation backend.
4889 Setting `Val{false}` will continue the computation as usual and thus result in
4990 `NaN`s only in the elements that actually have `NaN`s.
5091"""
51- struct EvalOptions{T,B,E,A,R}
92+
93+ struct EvalOptions{T,B,E,BUF<: Union{ArrayBuffer,Nothing} }
5294 turbo:: Val{T}
5395 bumper:: Val{B}
5496 early_exit:: Val{E}
55- buffer:: A
56- buffer_ref:: R
97+ buffer:: BUF
5798end
5899
59- @unstable @inline _to_bool_val (x:: Bool ) = x ? Val (true ) : Val (false )
60- @inline _to_bool_val (x:: Val{T} ) where {T} = Val (T:: Bool )
61-
62100@unstable function EvalOptions (;
63101 turbo:: Union{Bool,Val} = Val (false ),
64102 bumper:: Union{Bool,Val} = Val (false ),
69107 v_turbo = _to_bool_val (turbo)
70108 v_bumper = _to_bool_val (bumper)
71109 v_early_exit = _to_bool_val (early_exit)
110+
72111 if v_turbo isa Val{true } || v_bumper isa Val{true }
73112 @assert buffer === nothing && buffer_ref === nothing
74113 end
75- return EvalOptions (v_turbo, v_bumper, v_early_exit, buffer, buffer_ref)
114+
115+ array_buffer = if buffer === nothing
116+ nothing
117+ else
118+ ArrayBuffer (buffer, buffer_ref)
119+ end
120+
121+ return EvalOptions (v_turbo, v_bumper, v_early_exit, array_buffer)
76122end
77123
124+ @unstable @inline _to_bool_val (x:: Bool ) = x ? Val (true ) : Val (false )
125+ @inline _to_bool_val (:: Val{T} ) where {T} = Val (T:: Bool )
126+
78127@unstable function _process_deprecated_kws (eval_options, deprecated_kws)
79128 turbo = get (deprecated_kws, :turbo , nothing )
80129 bumper = get (deprecated_kws, :bumper , nothing )
@@ -165,7 +214,7 @@ function eval_tree_array(
165214 return bumper_eval_tree_array (tree, cX, operators, _eval_options)
166215 end
167216
168- _reset_buffer_ref ! (_eval_options)
217+ reset_index ! (_eval_options. buffer )
169218
170219 result = _eval_tree_array (tree, cX, operators, _eval_options)
171220 return (
@@ -251,9 +300,11 @@ function deg0_eval(
251300 tree:: AbstractExpressionNode{T} , cX:: AbstractMatrix{T} , eval_options:: EvalOptions
252301):: ResultOk where {T}
253302 if tree. constant
254- return ResultOk (_fill_similar (tree. val, cX, eval_options, axes (cX, 2 )), true )
303+ return ResultOk (
304+ get_filled_array (eval_options. buffer, tree. val, cX, axes (cX, 2 )), true
305+ )
255306 else
256- return ResultOk (_index_X ( cX, tree. feature, eval_options ), true )
307+ return ResultOk (get_feature_array (eval_options . buffer, cX, tree. feature), true )
257308 end
258309end
259310
@@ -455,12 +506,12 @@ function deg1_l2_ll0_lr0_eval(
455506 @return_on_nonfinite_val (eval_options, x_l, cX)
456507 x = op (x_l):: T
457508 @return_on_nonfinite_val (eval_options, x, cX)
458- return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
509+ return ResultOk (get_filled_array (eval_options . buffer, x, cX , axes (cX, 2 )), true )
459510 elseif tree. l. l. constant
460511 val_ll = tree. l. l. val
461512 @return_on_nonfinite_val (eval_options, val_ll, cX)
462513 feature_lr = tree. l. r. feature
463- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
514+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
464515 @inbounds @simd for j in axes (cX, 2 )
465516 x_l = op_l (val_ll, cX[feature_lr, j]):: T
466517 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -471,7 +522,7 @@ function deg1_l2_ll0_lr0_eval(
471522 feature_ll = tree. l. l. feature
472523 val_lr = tree. l. r. val
473524 @return_on_nonfinite_val (eval_options, val_lr, cX)
474- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
525+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
475526 @inbounds @simd for j in axes (cX, 2 )
476527 x_l = op_l (cX[feature_ll, j], val_lr):: T
477528 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -481,7 +532,7 @@ function deg1_l2_ll0_lr0_eval(
481532 else
482533 feature_ll = tree. l. l. feature
483534 feature_lr = tree. l. r. feature
484- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
535+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
485536 @inbounds @simd for j in axes (cX, 2 )
486537 x_l = op_l (cX[feature_ll, j], cX[feature_lr, j]):: T
487538 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -506,10 +557,10 @@ function deg1_l1_ll0_eval(
506557 @return_on_nonfinite_val (eval_options, x_l, cX)
507558 x = op (x_l):: T
508559 @return_on_nonfinite_val (eval_options, x, cX)
509- return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
560+ return ResultOk (get_filled_array (eval_options . buffer, x, cX , axes (cX, 2 )), true )
510561 else
511562 feature_ll = tree. l. l. feature
512- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
563+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
513564 @inbounds @simd for j in axes (cX, 2 )
514565 x_l = op_l (cX[feature_ll, j]):: T
515566 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -533,9 +584,9 @@ function deg2_l0_r0_eval(
533584 @return_on_nonfinite_val (eval_options, val_r, cX)
534585 x = op (val_l, val_r):: T
535586 @return_on_nonfinite_val (eval_options, x, cX)
536- return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
587+ return ResultOk (get_filled_array (eval_options . buffer, x, cX , axes (cX, 2 )), true )
537588 elseif tree. l. constant
538- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
589+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
539590 val_l = tree. l. val
540591 @return_on_nonfinite_val (eval_options, val_l, cX)
541592 feature_r = tree. r. feature
@@ -545,7 +596,7 @@ function deg2_l0_r0_eval(
545596 end
546597 return ResultOk (cumulator, true )
547598 elseif tree. r. constant
548- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
599+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
549600 feature_l = tree. l. feature
550601 val_r = tree. r. val
551602 @return_on_nonfinite_val (eval_options, val_r, cX)
@@ -555,7 +606,7 @@ function deg2_l0_r0_eval(
555606 end
556607 return ResultOk (cumulator, true )
557608 else
558- cumulator = _similar (cX, eval_options , axes (cX, 2 ))
609+ cumulator = get_array (eval_options . buffer, cX , axes (cX, 2 ))
559610 feature_l = tree. l. feature
560611 feature_r = tree. r. feature
561612 @inbounds @simd for j in axes (cX, 2 )
0 commit comments