Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ module DynamicExpressionsLoopVectorizationExt

using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateModule: @return_on_nonfinite_val, EvalOptions
using DynamicExpressions.UtilsModule: ResultOk
using DynamicExpressions.EvaluateModule:
@return_on_nonfinite_val, EvalOptions, get_array, get_feature_array, get_filled_array
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
Expand Down Expand Up @@ -56,12 +57,12 @@ function deg1_l2_ll0_lr0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.l.constant
val_ll = tree.l.l.val
@return_on_nonfinite_val(eval_options, val_ll, cX)
feature_lr = tree.l.r.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(val_ll, cX[feature_lr, j])
x = op(x_l)
Expand All @@ -72,7 +73,7 @@ function deg1_l2_ll0_lr0_eval(
feature_ll = tree.l.l.feature
val_lr = tree.l.r.val
@return_on_nonfinite_val(eval_options, val_lr, cX)
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], val_lr)
x = op(x_l)
Expand All @@ -82,7 +83,7 @@ function deg1_l2_ll0_lr0_eval(
else
feature_ll = tree.l.l.feature
feature_lr = tree.l.r.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j], cX[feature_lr, j])
x = op(x_l)
Expand All @@ -106,10 +107,10 @@ function deg1_l1_ll0_eval(
@return_on_nonfinite_val(eval_options, x_l, cX)
x = op(x_l)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
else
feature_ll = tree.l.l.feature
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
@turbo for j in axes(cX, 2)
x_l = op_l(cX[feature_ll, j])
x = op(x_l)
Expand All @@ -132,9 +133,9 @@ function deg2_l0_r0_eval(
@return_on_nonfinite_val(eval_options, val_r, cX)
x = op(val_l, val_r)::T
@return_on_nonfinite_val(eval_options, x, cX)
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
return ResultOk(get_filled_array(eval_options.buffer, x, cX, axes(cX, 2)), true)
elseif tree.l.constant
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
val_l = tree.l.val
@return_on_nonfinite_val(eval_options, val_l, cX)
feature_r = tree.r.feature
Expand All @@ -144,7 +145,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
elseif tree.r.constant
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
val_r = tree.r.val
@return_on_nonfinite_val(eval_options, val_r, cX)
Expand All @@ -154,7 +155,7 @@ function deg2_l0_r0_eval(
end
return ResultOk(cumulator, true)
else
cumulator = similar(cX, axes(cX, 2))
cumulator = get_array(eval_options.buffer, cX, axes(cX, 2))
feature_l = tree.l.feature
feature_r = tree.r.feature
@turbo for j in axes(cX, 2)
Expand Down
1 change: 1 addition & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import .StringsModule: get_op_name
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
import .EvaluateModule: ArrayBuffer
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
Loading
Loading