@@ -5,8 +5,7 @@ using DynamicExpressions:
55 OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
66using DynamicExpressions. UtilsModule: ResultOk, counttuple
77
8- import DynamicExpressions. ExtensionInterfaceModule:
9- bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
8+ import DynamicExpressions. ExtensionInterfaceModule: bumper_eval_tree_array, bumper_kern!
109
1110function bumper_eval_tree_array (
1211 tree:: AbstractExpressionNode{T} ,
@@ -37,8 +36,7 @@ function bumper_eval_tree_array(
3736 branch_node -> branch_node,
3837 # In the evaluation kernel, we combine the branch nodes
3938 # with the arrays created by the leaf nodes:
40- ((args:: Vararg{Any,M} ) where {M}) ->
41- dispatch_kerns! (operators, args... , eval_options),
39+ KernelDispatcher (operators, eval_options),
4240 tree;
4341 break_sharing= Val (true ),
4442 )
@@ -49,63 +47,44 @@ function bumper_eval_tree_array(
4947 return (result, all_ok[])
5048end
5149
52- function dispatch_kerns! (
53- operators, branch_node, cumulator, eval_options:: EvalOptions{<:Any,true,early_exit}
54- ) where {early_exit}
55- cumulator. ok || return cumulator
56-
57- out = dispatch_kern1! (operators. unaops, branch_node. op, cumulator. x, eval_options)
58- return ResultOk (out, early_exit ? is_valid_array (out) : true )
59- end
60- function dispatch_kerns! (
61- operators,
62- branch_node,
63- cumulator1,
64- cumulator2,
65- eval_options:: EvalOptions{<:Any,true,early_exit} ,
66- ) where {early_exit}
67- cumulator1. ok || return cumulator1
68- cumulator2. ok || return cumulator2
69-
70- out = dispatch_kern2! (
71- operators. binops, branch_node. op, cumulator1. x, cumulator2. x, eval_options
72- )
73- return ResultOk (out, early_exit ? is_valid_array (out) : true )
50+ struct KernelDispatcher{O<: OperatorEnum ,E<: EvalOptions{<:Any,true,<:Any} } <: Function
51+ operators:: O
52+ eval_options:: E
7453end
7554
76- @generated function dispatch_kern1! (unaops, op_idx, cumulator, eval_options:: EvalOptions )
77- nuna = counttuple (unaops)
55+ @generated function (kd:: KernelDispatcher{<:Any,<:EvalOptions{<:Any,true,early_exit}} )(
56+ branch_node, inputs:: Vararg{Any,degree}
57+ ) where {degree,early_exit}
7858 quote
79- Base. @nif (
80- $ nuna,
81- i -> i == op_idx,
82- i -> let op = unaops[i]
83- return bumper_kern1! (op, cumulator, eval_options)
84- end ,
85- )
59+ Base. Cartesian. @nexprs ($ degree, i -> inputs[i]. ok || return inputs[i])
60+ cumulators = Base. Cartesian. @ntuple ($ degree, i -> inputs[i]. x)
61+ out = dispatch_kerns! (kd. operators, branch_node, cumulators, kd. eval_options)
62+ return ResultOk (out, early_exit ? is_valid_array (out) : true )
8663 end
8764end
88- @generated function dispatch_kern2! (
89- binops, op_idx, cumulator1, cumulator2, eval_options:: EvalOptions
90- )
91- nbin = counttuple (binops)
65+ @generated function dispatch_kerns! (
66+ operators:: OperatorEnum{OPS} ,
67+ branch_node,
68+ cumulators:: Tuple{Vararg{Any,degree}} ,
69+ eval_options:: EvalOptions ,
70+ ) where {OPS,degree}
71+ nops = length (OPS. types[degree]. types)
9272 quote
93- Base. @nif (
94- $ nbin,
73+ op_idx = branch_node. op
74+ Base. Cartesian. @nif (
75+ $ nops,
9576 i -> i == op_idx,
96- i -> let op = binops[i]
97- return bumper_kern2! (op, cumulator1, cumulator2, eval_options)
98- end ,
77+ i -> bumper_kern! (operators[$ degree][i], cumulators, eval_options)
9978 )
10079 end
10180end
102- function bumper_kern1! (op :: F , cumulator, :: EvalOptions{false,true} ) where {F}
103- @. cumulator = op (cumulator)
104- return cumulator
105- end
106- function bumper_kern2! (op :: F , cumulator1, cumulator2, :: EvalOptions{false,true} ) where {F}
107- @. cumulator1 = op (cumulator1, cumulator2 )
108- return cumulator1
81+
82+ function bumper_kern! (
83+ op :: F , cumulators :: Tuple{Vararg{Any,degree}} , :: EvalOptions{false,true,early_exit}
84+ ) where {F,degree,early_exit}
85+ cumulator_1 = first (cumulators)
86+ @. cumulator_1 = op (cumulators ... )
87+ return cumulator_1
10988end
11089
11190end
0 commit comments