Skip to content

Commit e44c0f9

Browse files
committed
refactor: fix type instability in buffered eval
1 parent 56577b8 commit e44c0f9

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

src/Evaluate.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
1515
# Expand to an early-return guard for the enclosing evaluation function.
#
# The generated code checks whether `eval_options.early_exit isa Val{true}` and
# whether `val` fails `is_valid`. When both hold, the enclosing function returns
# `ResultOk(_similar(X, eval_options, axes(X, 2)), false)`; otherwise execution
# falls through. The three arguments are escaped so they resolve in the caller's
# scope; `eval_options` and `X` each appear more than once in the expansion.
macro return_on_nonfinite_val(eval_options, val, X)
    opts = esc(eval_options)
    value = esc(val)
    data = esc(X)
    return quote
        if $(opts).early_exit isa Val{true} && !is_valid($(value))
            return $(ResultOk)(_similar($(data), $(opts), axes($(data), 2)), false)
        end
    end
end
@@ -64,13 +66,13 @@ end
6466
buffer::Union{AbstractMatrix,Nothing}=nothing,
6567
buffer_ref::Union{Base.RefValue{<:Integer},Nothing}=nothing,
6668
)
67-
return EvalOptions(
68-
_to_bool_val(turbo),
69-
_to_bool_val(bumper),
70-
_to_bool_val(early_exit),
71-
buffer,
72-
buffer_ref,
73-
)
69+
v_turbo = _to_bool_val(turbo)
70+
v_bumper = _to_bool_val(bumper)
71+
v_early_exit = _to_bool_val(early_exit)
72+
if v_turbo isa Val{true} || v_bumper isa Val{true}
73+
@assert buffer === nothing && buffer_ref === nothing
74+
end
75+
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, buffer_ref)
7476
end
7577

7678
@unstable function _process_deprecated_kws(eval_options, deprecated_kws)
@@ -163,6 +165,8 @@ function eval_tree_array(
163165
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
164166
end
165167

168+
_reset_buffer_ref!(_eval_options)
169+
166170
result = _eval_tree_array(tree, cX, operators, _eval_options)
167171
return (
168172
result.x,
@@ -208,7 +212,9 @@ function _eval_tree_array(
208212
# Speed hack for constant trees.
209213
const_result = dispatch_constant_tree(tree, operators)::ResultOk{Vector{T}}
210214
!const_result.ok && return ResultOk(_similar(cX, eval_options, axes(cX, 2)), false)
211-
return ResultOk(_fill_similar(const_result.x[], cX, eval_options, axes(cX, 2)), true)
215+
return ResultOk(
216+
_fill_similar(const_result.x[], cX, eval_options, axes(cX, 2)), true
217+
)
212218
elseif tree.degree == 1
213219
op_idx = tree.op
214220
return dispatch_deg1_eval(tree, cX, op_idx, operators, eval_options)
@@ -253,13 +259,19 @@ function deg0_eval(
253259
end
254260
end
255261

262+
# Rewind the buffer cursor stored in `eval_options.buffer_ref` to 1.
#
# Does nothing when `buffer_ref` is `nothing`. Always returns `nothing`.
function _reset_buffer_ref!(eval_options::EvalOptions)
    cursor = eval_options.buffer_ref
    cursor === nothing && return nothing
    cursor[] = 1
    return nothing
end
256268
function _fill_similar(value, array, eval_options::EvalOptions, args...)
257269
if eval_options.buffer === nothing
258270
return fill_similar(value, array, args...)
259271
else
260272
# TODO HACK: Treat `axes` here explicitly!
261273
i = eval_options.buffer_ref[]
262-
out = @inbounds(@view(eval_options.buffer[i, :]))
274+
out = @view(eval_options.buffer[i, :])
263275
out .= value
264276
eval_options.buffer_ref[] = i + 1
265277
return out
@@ -270,7 +282,7 @@ function _similar(X, eval_options::EvalOptions, args...)
270282
return similar(X, args...)
271283
else
272284
i = eval_options.buffer_ref[]
273-
out = @inbounds(@view(eval_options.buffer[i, args...]))
285+
out = @view(eval_options.buffer[i, :])
274286
eval_options.buffer_ref[] = i + 1
275287
return out
276288
end
@@ -280,7 +292,7 @@ function _index_X(X, feature, eval_options::EvalOptions)
280292
return X[feature, :]
281293
else
282294
i = eval_options.buffer_ref[]
283-
out = @inbounds(@view(eval_options.buffer[i, :]))
295+
out = @view(eval_options.buffer[i, :])
284296
eval_options.buffer_ref[] = i + 1
285297
out .= X[feature, :]
286298
return out
@@ -708,9 +720,9 @@ end
708720
quote
709721
if tree.degree == 0
710722
if tree.constant
711-
ResultOk(_fill_similar(one(T), cX, eval_options, axes(cX, 2)) .* tree.val, true)
723+
ResultOk(fill_similar(one(T), cX, axes(cX, 2)) .* tree.val, true)
712724
else
713-
ResultOk(_index_X(cX, tree.feature, eval_options), true)
725+
ResultOk(cX[tree.feature, :], true)
714726
end
715727
elseif tree.degree == 1
716728
op_idx = tree.op

0 commit comments

Comments
 (0)