@@ -15,7 +15,7 @@ const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
# Expands to a guarded early `return` inside the calling function:
# when `eval_options.early_exit isa Val{true}` and `is_valid(val)` is false, the caller
# returns `ResultOk(_similar(X, eval_options, axes(X, 2)), false)`; otherwise the
# expansion has no effect.
macro return_on_nonfinite_val(eval_options, val, X)
    # Escape each caller-side expression once so that it is resolved in the caller's
    # scope; the escaped forms are spliced in wherever the expansion needs them.
    opts_ex = esc(eval_options)
    val_ex = esc(val)
    X_ex = esc(X)
    :(
        if $(opts_ex).early_exit isa Val{true} && !is_valid($(val_ex))
            return $(ResultOk)(_similar($(X_ex), $(opts_ex), axes($(X_ex), 2)), false)
        end
    )
end
@@ -46,10 +46,12 @@ This holds options for expression evaluation, such as evaluation backend.
4646 Setting `Val{false}` will continue the computation as usual and thus result in
4747 `NaN`s only in the elements that actually have `NaN`s.
4848"""
49- struct EvalOptions{T,B,E}
49+ struct EvalOptions{T,B,E,A,R }
# `turbo`, `bumper` and `early_exit` are stored as `Val` instances, so their values are
# part of the type (parameters `T`, `B`, `E`).
5050 turbo:: Val{T}
5151 bumper:: Val{B}
# `early_exit isa Val{true}` is the condition checked by `@return_on_nonfinite_val`.
5252 early_exit:: Val{E}
# Optional scratch storage; `nothing` means the helpers allocate fresh arrays.
# The keyword constructor accepts `Union{AbstractMatrix,Nothing}` here.
53+ buffer:: A
# Counter read as `buffer_ref[]` to pick the row of `buffer` to hand out next, and
# incremented by one after each use by `_similar`, `_fill_similar` and `_index_X`.
# The keyword constructor accepts `Union{Base.RefValue{<:Integer},Nothing}` here.
# NOTE(review): nothing visible here enforces that `buffer` and `buffer_ref` are
# either both set or both `nothing` - confirm the constructor/callers guarantee it.
54+ buffer_ref:: R
5355end
5456
# Lift a runtime `Bool` to the matching `Val` instance: `true` gives `Val(true)`,
# `false` gives `Val(false)`.
@unstable @inline function _to_bool_val(x::Bool)
    if x
        return Val(true)
    else
        return Val(false)
    end
end
5961 turbo:: Union{Bool,Val} = Val (false ),
6062 bumper:: Union{Bool,Val} = Val (false ),
6163 early_exit:: Union{Bool,Val} = Val (true ),
64+ buffer:: Union{AbstractMatrix,Nothing} = nothing ,
65+ buffer_ref:: Union{Base.RefValue{<:Integer},Nothing} = nothing ,
6266)
63- return EvalOptions (_to_bool_val (turbo), _to_bool_val (bumper), _to_bool_val (early_exit))
67+ return EvalOptions (
68+ _to_bool_val (turbo),
69+ _to_bool_val (bumper),
70+ _to_bool_val (early_exit),
71+ buffer,
72+ buffer_ref,
73+ )
6474end
6575
6676@unstable function _process_deprecated_kws (eval_options, deprecated_kws)
@@ -193,12 +203,12 @@ function _eval_tree_array(
193203 # First, we see if there are only constants in the tree - meaning
194204 # we can just return the constant result.
195205 if tree. degree == 0
196- return deg0_eval (tree, cX)
206+ return deg0_eval (tree, cX, eval_options )
197207 elseif is_constant (tree)
198208 # Speed hack for constant trees.
199209 const_result = dispatch_constant_tree (tree, operators):: ResultOk{Vector{T}}
200- ! const_result. ok && return ResultOk (similar (cX, axes (cX, 2 )), false )
201- return ResultOk (fill_similar (const_result. x[], cX, axes (cX, 2 )), true )
210+ ! const_result. ok && return ResultOk (_similar (cX, eval_options , axes (cX, 2 )), false )
211+ return ResultOk (_fill_similar (const_result. x[], cX, eval_options , axes (cX, 2 )), true )
202212 elseif tree. degree == 1
203213 op_idx = tree. op
204214 return dispatch_deg1_eval (tree, cX, op_idx, operators, eval_options)
@@ -234,12 +244,46 @@ function deg1_eval(
234244end
235245
"""
    deg0_eval(tree, cX, eval_options)

Evaluate a leaf node. For a constant leaf the result is `tree.val` filled over
`axes(cX, 2)` (via `_fill_similar`); otherwise it is row `tree.feature` of `cX`
(via `_index_X`). The result is always wrapped as `ResultOk(..., true)`.
"""
function deg0_eval(
    tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, eval_options::EvalOptions
)::ResultOk where {T}
    leaf_values = if tree.constant
        _fill_similar(tree.val, cX, eval_options, axes(cX, 2))
    else
        _index_X(cX, tree.feature, eval_options)
    end
    return ResultOk(leaf_values, true)
end

255+
"""
    _fill_similar(value, array, eval_options::EvalOptions, args...)

Return an array filled with `value`.

When `eval_options.buffer === nothing` this forwards to
`fill_similar(value, array, args...)`. Otherwise it takes row
`eval_options.buffer_ref[]` of `eval_options.buffer` as a view, fills that view with
`value`, advances `eval_options.buffer_ref[]` by one, and returns the view.

On the buffered path, `args` are used as the column index of the row (the same way
`_similar` indexes the buffer), so the result covers exactly the requested axes.
With no `args` the whole row is used.

Throws a `BoundsError` if the requested row or columns fall outside the buffer.
"""
function _fill_similar(value, array, eval_options::EvalOptions, args...)
    buffer = eval_options.buffer
    buffer === nothing && return fill_similar(value, array, args...)

    row = eval_options.buffer_ref[]
    # Honour the requested axes instead of always taking the full row, so that the
    # result has the same extent as the unbuffered `fill_similar` result.
    cols = isempty(args) ? (:,) : args
    # Deliberately bounds-checked: `row` comes from a caller-supplied counter, and an
    # exhausted buffer must raise instead of touching memory outside the buffer.
    out = view(buffer, row, cols...)
    out .= value
    eval_options.buffer_ref[] = row + 1
    return out
end
"""
    _similar(X, eval_options::EvalOptions, args...)

Return uninitialised storage for a result.

When `eval_options.buffer === nothing` this forwards to `similar(X, args...)`.
Otherwise it returns a view of row `eval_options.buffer_ref[]` of
`eval_options.buffer`, indexed by `args...`, and advances
`eval_options.buffer_ref[]` by one. The contents of the view are whatever the buffer
currently holds.

Throws a `BoundsError` if the requested row or columns fall outside the buffer.
"""
function _similar(X, eval_options::EvalOptions, args...)
    buffer = eval_options.buffer
    buffer === nothing && return similar(X, args...)

    row = eval_options.buffer_ref[]
    # Deliberately bounds-checked: `row` comes from a caller-supplied counter, and an
    # exhausted buffer must raise instead of touching memory outside the buffer.
    out = view(buffer, row, args...)
    eval_options.buffer_ref[] = row + 1
    return out
end
"""
    _index_X(X, feature, eval_options::EvalOptions)

Return the values of row `feature` of `X`.

When `eval_options.buffer === nothing` this returns the copy `X[feature, :]`.
Otherwise it takes row `eval_options.buffer_ref[]` of `eval_options.buffer` as a
view, advances `eval_options.buffer_ref[]` by one, copies row `feature` of `X` into
the view, and returns the view.

Throws a `BoundsError` if the buffer row is outside the buffer.
"""
function _index_X(X, feature, eval_options::EvalOptions)
    buffer = eval_options.buffer
    buffer === nothing && return X[feature, :]

    row = eval_options.buffer_ref[]
    # Deliberately bounds-checked: `row` comes from a caller-supplied counter, and an
    # exhausted buffer must raise instead of touching memory outside the buffer.
    out = view(buffer, row, :)
    eval_options.buffer_ref[] = row + 1
    # Copy straight from a view of `X`; indexing with `X[feature, :]` here would first
    # allocate a temporary copy of the row, which the buffer is meant to avoid.
    out .= view(X, feature, :)
    return out
end
245289
@@ -401,12 +445,12 @@ function deg1_l2_ll0_lr0_eval(
401445 @return_on_nonfinite_val (eval_options, x_l, cX)
402446 x = op (x_l):: T
403447 @return_on_nonfinite_val (eval_options, x, cX)
404- return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
448+ return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
405449 elseif tree. l. l. constant
406450 val_ll = tree. l. l. val
407451 @return_on_nonfinite_val (eval_options, val_ll, cX)
408452 feature_lr = tree. l. r. feature
409- cumulator = similar (cX, axes (cX, 2 ))
453+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
410454 @inbounds @simd for j in axes (cX, 2 )
411455 x_l = op_l (val_ll, cX[feature_lr, j]):: T
412456 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -417,7 +461,7 @@ function deg1_l2_ll0_lr0_eval(
417461 feature_ll = tree. l. l. feature
418462 val_lr = tree. l. r. val
419463 @return_on_nonfinite_val (eval_options, val_lr, cX)
420- cumulator = similar (cX, axes (cX, 2 ))
464+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
421465 @inbounds @simd for j in axes (cX, 2 )
422466 x_l = op_l (cX[feature_ll, j], val_lr):: T
423467 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -427,7 +471,7 @@ function deg1_l2_ll0_lr0_eval(
427471 else
428472 feature_ll = tree. l. l. feature
429473 feature_lr = tree. l. r. feature
430- cumulator = similar (cX, axes (cX, 2 ))
474+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
431475 @inbounds @simd for j in axes (cX, 2 )
432476 x_l = op_l (cX[feature_ll, j], cX[feature_lr, j]):: T
433477 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -452,10 +496,10 @@ function deg1_l1_ll0_eval(
452496 @return_on_nonfinite_val (eval_options, x_l, cX)
453497 x = op (x_l):: T
454498 @return_on_nonfinite_val (eval_options, x, cX)
455- return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
499+ return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
456500 else
457501 feature_ll = tree. l. l. feature
458- cumulator = similar (cX, axes (cX, 2 ))
502+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
459503 @inbounds @simd for j in axes (cX, 2 )
460504 x_l = op_l (cX[feature_ll, j]):: T
461505 x = is_valid (x_l) ? op (x_l):: T : T (Inf )
@@ -479,9 +523,9 @@ function deg2_l0_r0_eval(
479523 @return_on_nonfinite_val (eval_options, val_r, cX)
480524 x = op (val_l, val_r):: T
481525 @return_on_nonfinite_val (eval_options, x, cX)
482- return ResultOk (fill_similar (x, cX, axes (cX, 2 )), true )
526+ return ResultOk (_fill_similar (x, cX, eval_options , axes (cX, 2 )), true )
483527 elseif tree. l. constant
484- cumulator = similar (cX, axes (cX, 2 ))
528+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
485529 val_l = tree. l. val
486530 @return_on_nonfinite_val (eval_options, val_l, cX)
487531 feature_r = tree. r. feature
@@ -491,7 +535,7 @@ function deg2_l0_r0_eval(
491535 end
492536 return ResultOk (cumulator, true )
493537 elseif tree. r. constant
494- cumulator = similar (cX, axes (cX, 2 ))
538+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
495539 feature_l = tree. l. feature
496540 val_r = tree. r. val
497541 @return_on_nonfinite_val (eval_options, val_r, cX)
@@ -501,7 +545,7 @@ function deg2_l0_r0_eval(
501545 end
502546 return ResultOk (cumulator, true )
503547 else
504- cumulator = similar (cX, axes (cX, 2 ))
548+ cumulator = _similar (cX, eval_options , axes (cX, 2 ))
505549 feature_l = tree. l. feature
506550 feature_r = tree. r. feature
507551 @inbounds @simd for j in axes (cX, 2 )
664708 quote
665709 if tree. degree == 0
666710 if tree. constant
667- ResultOk (fill_similar (one (T), cX, axes (cX, 2 )) .* tree. val, true )
711+ ResultOk (_fill_similar (one (T), cX, eval_options , axes (cX, 2 )) .* tree. val, true )
668712 else
669- ResultOk (cX[ tree. feature, :] , true )
713+ ResultOk (_index_X (cX, tree. feature, eval_options) , true )
670714 end
671715 elseif tree. degree == 1
672716 op_idx = tree. op
0 commit comments