diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 18d5d741..355e7b98 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -81,7 +81,7 @@ import .StringsModule: get_op_name, get_pretty_op_name OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names! @reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvalOptions -import .EvaluateModule: ArrayBuffer +import .EvaluateModule: ArrayBuffer, ResultOk @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 37c0f19a..041194bc 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -94,12 +94,17 @@ This holds options for expression evaluation, such as evaluation backend. - `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation. This should be an instance of `ArrayBuffer` which has an `array` field and an `index` field used to iterate which buffer slot to use. +- `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster + evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that + you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom + evaluation. """ -struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}} +struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U} turbo::Val{T} bumper::Val{B} early_exit::Val{E} buffer::BUF + use_fused::Val{U} end @unstable function EvalOptions(; @@ -107,21 +112,25 @@ end bumper::Union{Bool,Val}=Val(false), early_exit::Union{Bool,Val}=Val(true), buffer::Union{ArrayBuffer,Nothing}=nothing, + use_fused::Union{Bool,Val}=Val(true), ) v_turbo = _to_bool_val(turbo) v_bumper = _to_bool_val(bumper) v_early_exit = _to_bool_val(early_exit) + v_use_fused = _to_bool_val(use_fused) if v_bumper isa Val{true} @assert buffer === nothing end - return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer) + return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused) end @unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false) @inline _to_bool_val(::Val{T}) where {T} = Val(T::Bool) +@inline use_fused(eval_options::EvalOptions) = eval_options.use_fused isa Val{true} + _copy(x) = copy(x) _copy(::Nothing) = nothing function Base.copy(eval_options::EvalOptions) @@ -130,6 +139,7 @@ function Base.copy(eval_options::EvalOptions) bumper=eval_options.bumper, early_exit=eval_options.early_exit, buffer=_copy(eval_options.buffer), + use_fused=eval_options.use_fused, ) end @@ -433,19 +443,20 @@ end end end return quote + fused = use_fused(eval_options) return Base.Cartesian.@nif( $nbin, i -> i == op_idx, # COV_EXCL_LINE i -> let op = operators.binops[i] # COV_EXCL_LINE - if get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0 + if fused && get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) - elseif get_child(tree, 2).degree == 0 + elseif fused && get_child(tree, 2).degree == 0 result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) - elseif get_child(tree, 1).degree == 0 + elseif fused && get_child(tree, 1).degree == 0 result_r = _eval_tree_array(get_child(tree, 2), cX, operators, eval_options) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) @@ -487,11 +498,13 @@ end # This @nif lets us generate an if statement over choice of operator, # which means the compiler will be able to completely avoid type inference on operators. return quote + fused = use_fused(eval_options) Base.Cartesian.@nif( $nuna, i -> i == op_idx, # COV_EXCL_LINE i -> let op = operators.unaops[i] # COV_EXCL_LINE - if get_child(tree, 1).degree == 2 && + if fused && + get_child(tree, 1).degree == 2 && get_child(get_child(tree, 1), 1).degree == 0 && get_child(get_child(tree, 1), 2).degree == 0 # op(op2(x, y)), where x, y, z are constants or variables. @@ -499,7 +512,8 @@ end dispatch_deg1_l2_ll0_lr0_eval( tree, cX, op, l_op_idx, operators.binops, eval_options ) - elseif get_child(tree, 1).degree == 1 && + elseif fused && + get_child(tree, 1).degree == 1 && get_child(get_child(tree, 1), 1).degree == 0 # op(op2(x)), where x is a constant or variable. l_op_idx = get_child(tree, 1).op