Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 99 additions & 37 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
macro return_on_nonfinite_val(eval_options, val, X)
:(
if $(esc(eval_options)).early_exit isa Val{true} && !is_valid($(esc(val)))
return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false)
return $(ResultOk)(
get_array($(esc(eval_options)).buffer, $(esc(X)), axes($(esc(X)), 2)), false
)
end
)
end
Expand All @@ -28,6 +30,45 @@ macro return_on_nonfinite_array(eval_options, array)
)
end

"""Buffer management for array allocations during evaluation."""
struct ArrayBuffer{A<:AbstractMatrix,R<:Base.RefValue{<:Integer}}
    # Backing storage; individual rows are handed out as work vectors.
    array::A
    # Row number most recently handed out (0 right after `reset_index!`).
    index::R
end

# Rewind the row counter to zero. Evaluates to the assigned value, `0`.
function reset_index!(buffer::ArrayBuffer)
    return buffer.index[] = 0
end

# No buffer is in use, so there is nothing to rewind.
function reset_index!(::Nothing)
    return nothing
end

# Advance the row counter by one and return the new row number.
function next_index!(buffer::ArrayBuffer)
    row = buffer.index[] + 1
    buffer.index[] = row
    return row
end

# Without a buffer: allocate a fresh, uninitialised array via `similar`.
get_array(::Nothing, template::AbstractArray, axes...) = similar(template, axes...)

# With a buffer: claim the next row of `buffer.array` and return it as a view.
# `template` and `axes` are not consulted on this path.
function get_array(buffer::ArrayBuffer, template::AbstractArray, axes...)
    return view(buffer.array, next_index!(buffer), :)
end

"""
    get_filled_array(buffer, value, template, axes...)

Return an array in which every element equals `value`.

- `buffer === nothing`: delegates to `fill_similar(value, template, axes...)`.
- `buffer::ArrayBuffer`: claims the next row of `buffer.array`, overwrites it with
  `value`, and returns a view of that row. `template` and `axes` are not used.

On the buffer path the row is obtained through a bounds-checked `view`, so a buffer
with too few rows raises a `BoundsError` rather than writing outside the matrix.
"""
function get_filled_array(::Nothing, value, template::AbstractArray, axes...)
    return fill_similar(value, template, axes...)
end
function get_filled_array(buffer::ArrayBuffer, value, template::AbstractArray, axes...)
    i = next_index!(buffer)
    # Take the view first, and without `@inbounds`: `i` comes from a runtime
    # counter, so it has to be validated before anything is written.
    row = view(buffer.array, i, :)
    row .= value
    return row
end

"""
    get_feature_array(buffer, X, feature)

Return the contents of row `feature` of `X` as a vector the caller may overwrite.

- `buffer === nothing`: returns a freshly allocated copy, `X[feature, :]`.
- `buffer::ArrayBuffer`: claims the next row of `buffer.array`, copies row `feature`
  of `X` into it, and returns a view of that buffer row.

Both paths are bounds-checked, so an invalid `feature` or a buffer with too few rows
raises a `BoundsError`.
"""
function get_feature_array(::Nothing, X::AbstractMatrix, feature::Integer)
    # Deliberately not `@inbounds`: `feature` is a runtime value, and one check per
    # call is negligible next to copying the whole row.
    return X[feature, :]
end
function get_feature_array(buffer::ArrayBuffer, X::AbstractMatrix, feature::Integer)
    i = next_index!(buffer)
    # Validate the destination row before writing into it.
    row = view(buffer.array, i, :)
    # Read the source through a view so no temporary copy of the row is allocated.
    row .= view(X, feature, :)
    return row
end

"""
EvalOptions{T,B,E}

Expand All @@ -46,23 +87,41 @@ This holds options for expression evaluation, such as evaluation backend.
Setting `Val{false}` will continue the computation as usual and thus result in
`NaN`s only in the elements that actually have `NaN`s.
"""
struct EvalOptions{T,B,E}

struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
turbo::Val{T}
bumper::Val{B}
early_exit::Val{E}
buffer::BUF
end

@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
@inline _to_bool_val(x::Val{T}) where {T} = Val(T::Bool)

@unstable function EvalOptions(;
turbo::Union{Bool,Val}=Val(false),
bumper::Union{Bool,Val}=Val(false),
early_exit::Union{Bool,Val}=Val(true),
buffer::Union{AbstractMatrix,Nothing}=nothing,
buffer_ref::Union{Base.RefValue{<:Integer},Nothing}=nothing,
)
return EvalOptions(_to_bool_val(turbo), _to_bool_val(bumper), _to_bool_val(early_exit))
v_turbo = _to_bool_val(turbo)
v_bumper = _to_bool_val(bumper)
v_early_exit = _to_bool_val(early_exit)

if v_turbo isa Val{true} || v_bumper isa Val{true}
@assert buffer === nothing && buffer_ref === nothing
end

array_buffer = if buffer === nothing
nothing
else
ArrayBuffer(buffer, buffer_ref)
end

return EvalOptions(v_turbo, v_bumper, v_early_exit, array_buffer)
end

@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
@inline _to_bool_val(::Val{T}) where {T} = Val(T::Bool)

@unstable function _process_deprecated_kws(eval_options, deprecated_kws)
turbo = get(deprecated_kws, :turbo, nothing)
bumper = get(deprecated_kws, :bumper, nothing)
Expand Down Expand Up @@ -153,6 +212,8 @@ function eval_tree_array(
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
end

reset_index!(_eval_options.buffer)

result = _eval_tree_array(tree, cX, operators, _eval_options)
return (
result.x,
Expand Down Expand Up @@ -193,12 +254,15 @@ function _eval_tree_array(
# First, we see if there are only constants in the tree - meaning
# we can just return the constant result.
if tree.degree == 0
return deg0_eval(tree, cX)
return deg0_eval(tree, cX, eval_options)
elseif is_constant(tree)
# Speed hack for constant trees.
const_result = dispatch_constant_tree(tree, operators)::ResultOk{Vector{T}}
!const_result.ok && return ResultOk(similar(cX, axes(cX, 2)), false)
return ResultOk(fill_similar(const_result.x[], cX, axes(cX, 2)), true)
const_result = dispatch_constant_tree(tree, operators)::ResultOk{T}
!const_result.ok &&
return ResultOk(get_array(eval_options.buffer, cX, axes(cX, 2)), false)
return ResultOk(
get_filled_array(eval_options.buffer, const_result.x, cX, axes(cX, 2)), true
)
elseif tree.degree == 1
op_idx = tree.op
return dispatch_deg1_eval(tree, cX, op_idx, operators, eval_options)
Expand Down Expand Up @@ -234,12 +298,14 @@ function deg1_eval(
end

function deg0_eval(
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}
tree::AbstractExpressionNode{T}, cX::AbstractMatrix{T}, eval_options::EvalOptions
)::ResultOk where {T}
if tree.constant
return ResultOk(fill_similar(tree.val, cX, axes(cX, 2)), true)
return ResultOk(
get_filled_array(eval_options.buffer, tree.val, cX, axes(cX, 2)), true
)
else
return ResultOk(cX[tree.feature, :], true)
return ResultOk(get_feature_array(eval_options.buffer, cX, tree.feature), true)
end
end

Expand Down Expand Up @@ -401,12 +467,12 @@ function deg1_l2_ll0_lr0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.l.constant
val_ll = tree.l.l.val
@return_on_nonfinite_val(eval_options, val_ll, cX)
feature_lr = tree.l.r.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(val_ll, cX[feature_lr, j])::T
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
Expand All @@ -417,7 +483,7 @@ function deg1_l2_ll0_lr0_eval(
feature_ll = tree.l.l.feature
val_lr = tree.l.r.val
@return_on_nonfinite_val(eval_options, val_lr, cX)
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], val_lr)::T
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
Expand All @@ -427,7 +493,7 @@ function deg1_l2_ll0_lr0_eval(
else
feature_ll = tree.l.l.feature
feature_lr = tree.l.r.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])::T
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
Expand All @@ -452,10 +518,10 @@ function deg1_l1_ll0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
else
feature_ll = tree.l.l.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@inbounds @simd for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j])::T
x = is_valid(x_l) ? op(x_l)::T : T(Inf)
Expand All @@ -479,9 +545,9 @@ function deg2_l0_r0_eval(
@return_on_nonfinite_val(eval_options, val_r, cX)
x = op(val_l, val_r)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.constant
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
val_l = tree.l.val
@return_on_nonfinite_val(eval_options, val_l, cX)
feature_r = tree.r.feature
Expand All @@ -491,7 +557,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
elseif tree.r.constant
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
val_r = tree.r.val
@return_on_nonfinite_val(eval_options, val_r, cX)
Expand All @@ -501,7 +567,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
else
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
feature_r = tree.r.feature
@inbounds @simd for j in axes(cX, 2)
Expand Down Expand Up @@ -578,37 +644,33 @@ over an entire array when the values are all the same.
nbin = get_nbin(operators)
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
quote
deg1_eval_constant(tree, operators.unaops[op_idx], operators)::ResultOk{Vector{T}}
deg1_eval_constant(tree, operators.unaops[op_idx], operators)::ResultOk{T}
end
else
quote
Base.Cartesian.@nif(
$nuna,
i -> i == op_idx,
i -> deg1_eval_constant(
tree, operators.unaops[i], operators
)::ResultOk{Vector{T}}
i -> deg1_eval_constant(tree, operators.unaops[i], operators)::ResultOk{T}
)
end
end
deg2_branch = if nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
quote
deg2_eval_constant(tree, operators.binops[op_idx], operators)::ResultOk{Vector{T}}
deg2_eval_constant(tree, operators.binops[op_idx], operators)::ResultOk{T}
end
else
quote
Base.Cartesian.@nif(
$nbin,
i -> i == op_idx,
i -> deg2_eval_constant(
tree, operators.binops[i], operators
)::ResultOk{Vector{T}}
i -> deg2_eval_constant(tree, operators.binops[i], operators)::ResultOk{T}
)
end
end
return quote
if tree.degree == 0
return deg0_eval_constant(tree)::ResultOk{Vector{T}}
return deg0_eval_constant(tree)::ResultOk{T}
elseif tree.degree == 1
op_idx = tree.op
return $deg1_branch
Expand All @@ -621,16 +683,16 @@ end

@inline function deg0_eval_constant(tree::AbstractExpressionNode{T}) where {T}
output = tree.val
return ResultOk([output], is_valid(output))::ResultOk{Vector{T}}
return ResultOk(output, is_valid(output))::ResultOk{T}
end

function deg1_eval_constant(
tree::AbstractExpressionNode{T}, op::F, operators::OperatorEnum
) where {T,F}
result = dispatch_constant_tree(tree.l, operators)
!result.ok && return result
output = op(result.x[])::T
return ResultOk([output], is_valid(output))::ResultOk{Vector{T}}
output = op(result.x)::T
return ResultOk(output, is_valid(output))::ResultOk{T}
end

function deg2_eval_constant(
Expand All @@ -640,8 +702,8 @@ function deg2_eval_constant(
!cumulator.ok && return cumulator
result_r = dispatch_constant_tree(tree.r, operators)
!result_r.ok && return result_r
output = op(cumulator.x[], result_r.x[])::T
return ResultOk([output], is_valid(output))::ResultOk{Vector{T}}
output = op(cumulator.x, result_r.x)::T
return ResultOk(output, is_valid(output))::ResultOk{T}
end

"""
Expand Down
Loading
Loading