Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ import .StringsModule: get_op_name, get_pretty_op_name
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
import .EvaluateModule: ArrayBuffer
import .EvaluateModule: ArrayBuffer, ResultOk
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
28 changes: 21 additions & 7 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,34 +94,43 @@
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
This should be an instance of `ArrayBuffer` which has an `array` field and an
`index` field used to iterate which buffer slot to use.
- `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster
evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that
you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
evaluation.
"""
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U}
turbo::Val{T}
bumper::Val{B}
early_exit::Val{E}
buffer::BUF
use_fused::Val{U}
end

@unstable function EvalOptions(;
turbo::Union{Bool,Val}=Val(false),
bumper::Union{Bool,Val}=Val(false),
early_exit::Union{Bool,Val}=Val(true),
buffer::Union{ArrayBuffer,Nothing}=nothing,
use_fused::Union{Bool,Val}=Val(true),
)
v_turbo = _to_bool_val(turbo)
v_bumper = _to_bool_val(bumper)
v_early_exit = _to_bool_val(early_exit)
v_use_fused = _to_bool_val(use_fused)

Check warning on line 120 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L120

Added line #L120 was not covered by tests

if v_bumper isa Val{true}
@assert buffer === nothing
end

return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer)
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused)

Check warning on line 126 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L126

Added line #L126 was not covered by tests
end

@unstable @inline _to_bool_val(x::Bool) = x ? Val(true) : Val(false)
@inline _to_bool_val(::Val{T}) where {T} = Val(T::Bool)

@inline use_fused(eval_options::EvalOptions) = eval_options.use_fused isa Val{true}

Check warning on line 132 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L132

Added line #L132 was not covered by tests

_copy(x) = copy(x)
_copy(::Nothing) = nothing
function Base.copy(eval_options::EvalOptions)
Expand All @@ -130,6 +139,7 @@
bumper=eval_options.bumper,
early_exit=eval_options.early_exit,
buffer=_copy(eval_options.buffer),
use_fused=eval_options.use_fused,
)
end

Expand Down Expand Up @@ -433,19 +443,20 @@
end
end
return quote
fused = use_fused(eval_options)

Check warning on line 446 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L446

Added line #L446 was not covered by tests
return Base.Cartesian.@nif(
$nbin,
i -> i == op_idx, # COV_EXCL_LINE
i -> let op = operators.binops[i] # COV_EXCL_LINE
if get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0
if fused && get_child(tree, 1).degree == 0 && get_child(tree, 2).degree == 0

Check warning on line 451 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L451

Added line #L451 was not covered by tests
deg2_l0_r0_eval(tree, cX, op, eval_options)
elseif get_child(tree, 2).degree == 0
elseif fused && get_child(tree, 2).degree == 0

Check warning on line 453 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L453

Added line #L453 was not covered by tests
result_l = _eval_tree_array(get_child(tree, 1), cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
# op(x, y), where y is a constant or variable but x is not.
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
elseif get_child(tree, 1).degree == 0
elseif fused && get_child(tree, 1).degree == 0

Check warning on line 459 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L459

Added line #L459 was not covered by tests
result_r = _eval_tree_array(get_child(tree, 2), cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
Expand Down Expand Up @@ -487,19 +498,22 @@
# This @nif lets us generate an if statement over choice of operator,
# which means the compiler will be able to completely avoid type inference on operators.
return quote
fused = use_fused(eval_options)

Check warning on line 501 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L501

Added line #L501 was not covered by tests
Base.Cartesian.@nif(
$nuna,
i -> i == op_idx, # COV_EXCL_LINE
i -> let op = operators.unaops[i] # COV_EXCL_LINE
if get_child(tree, 1).degree == 2 &&
if fused &&

Check warning on line 506 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L506

Added line #L506 was not covered by tests
get_child(tree, 1).degree == 2 &&
get_child(get_child(tree, 1), 1).degree == 0 &&
get_child(get_child(tree, 1), 2).degree == 0
# op(op2(x, y)), where x, y, z are constants or variables.
l_op_idx = get_child(tree, 1).op
dispatch_deg1_l2_ll0_lr0_eval(
tree, cX, op, l_op_idx, operators.binops, eval_options
)
elseif get_child(tree, 1).degree == 1 &&
elseif fused &&

Check warning on line 515 in src/Evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/Evaluate.jl#L515

Added line #L515 was not covered by tests
get_child(tree, 1).degree == 1 &&
get_child(get_child(tree, 1), 1).degree == 0
# op(op2(x)), where x is a constant or variable.
l_op_idx = get_child(tree, 1).op
Expand Down
Loading