Skip to content

Commit 4f31a85

Browse files
committed
feat: n-arity compat with bumper
1 parent 7b51c06 commit 4f31a85

File tree

3 files changed

+39
-66
lines changed

3 files changed

+39
-66
lines changed

ext/DynamicExpressionsBumperExt.jl

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ using DynamicExpressions:
55
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array, EvalOptions
66
using DynamicExpressions.UtilsModule: ResultOk, counttuple
77

8-
import DynamicExpressions.ExtensionInterfaceModule:
9-
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
8+
import DynamicExpressions.ExtensionInterfaceModule: bumper_eval_tree_array, bumper_kern!
109

1110
function bumper_eval_tree_array(
1211
tree::AbstractExpressionNode{T},
@@ -37,8 +36,7 @@ function bumper_eval_tree_array(
3736
branch_node -> branch_node,
3837
# In the evaluation kernel, we combine the branch nodes
3938
# with the arrays created by the leaf nodes:
40-
((args::Vararg{Any,M}) where {M}) ->
41-
dispatch_kerns!(operators, args..., eval_options),
39+
KernelDispatcher(operators, eval_options),
4240
tree;
4341
break_sharing=Val(true),
4442
)
@@ -49,63 +47,44 @@ function bumper_eval_tree_array(
4947
return (result, all_ok[])
5048
end
5149

52-
function dispatch_kerns!(
53-
operators, branch_node, cumulator, eval_options::EvalOptions{<:Any,true,early_exit}
54-
) where {early_exit}
55-
cumulator.ok || return cumulator
56-
57-
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, eval_options)
58-
return ResultOk(out, early_exit ? is_valid_array(out) : true)
59-
end
60-
function dispatch_kerns!(
61-
operators,
62-
branch_node,
63-
cumulator1,
64-
cumulator2,
65-
eval_options::EvalOptions{<:Any,true,early_exit},
66-
) where {early_exit}
67-
cumulator1.ok || return cumulator1
68-
cumulator2.ok || return cumulator2
69-
70-
out = dispatch_kern2!(
71-
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, eval_options
72-
)
73-
return ResultOk(out, early_exit ? is_valid_array(out) : true)
50+
struct KernelDispatcher{O<:OperatorEnum,E<:EvalOptions{<:Any,true,<:Any}} <: Function
51+
operators::O
52+
eval_options::E
7453
end
7554

76-
@generated function dispatch_kern1!(unaops, op_idx, cumulator, eval_options::EvalOptions)
77-
nuna = counttuple(unaops)
55+
@generated function (kd::KernelDispatcher{<:Any,<:EvalOptions{<:Any,true,early_exit}})(
56+
branch_node, inputs::Vararg{Any,degree}
57+
) where {degree,early_exit}
7858
quote
79-
Base.@nif(
80-
$nuna,
81-
i -> i == op_idx,
82-
i -> let op = unaops[i]
83-
return bumper_kern1!(op, cumulator, eval_options)
84-
end,
85-
)
59+
Base.Cartesian.@nexprs($degree, i -> inputs[i].ok || return inputs[i])
60+
cumulators = Base.Cartesian.@ntuple($degree, i -> inputs[i].x)
61+
out = dispatch_kerns!(kd.operators, branch_node, cumulators, kd.eval_options)
62+
return ResultOk(out, early_exit ? is_valid_array(out) : true)
8663
end
8764
end
88-
@generated function dispatch_kern2!(
89-
binops, op_idx, cumulator1, cumulator2, eval_options::EvalOptions
90-
)
91-
nbin = counttuple(binops)
65+
@generated function dispatch_kerns!(
66+
operators::OperatorEnum{OPS},
67+
branch_node,
68+
cumulators::Tuple{Vararg{Any,degree}},
69+
eval_options::EvalOptions,
70+
) where {OPS,degree}
71+
nops = length(OPS.types[degree].types)
9272
quote
93-
Base.@nif(
94-
$nbin,
73+
op_idx = branch_node.op
74+
Base.Cartesian.@nif(
75+
$nops,
9576
i -> i == op_idx,
96-
i -> let op = binops[i]
97-
return bumper_kern2!(op, cumulator1, cumulator2, eval_options)
98-
end,
77+
i -> bumper_kern!(operators[$degree][i], cumulators, eval_options)
9978
)
10079
end
10180
end
102-
function bumper_kern1!(op::F, cumulator, ::EvalOptions{false,true}) where {F}
103-
@. cumulator = op(cumulator)
104-
return cumulator
105-
end
106-
function bumper_kern2!(op::F, cumulator1, cumulator2, ::EvalOptions{false,true}) where {F}
107-
@. cumulator1 = op(cumulator1, cumulator2)
108-
return cumulator1
81+
82+
function bumper_kern!(
83+
op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{false,true,early_exit}
84+
) where {F,degree,early_exit}
85+
cumulator_1 = first(cumulators)
86+
@. cumulator_1 = op(cumulators...)
87+
return cumulator_1
10988
end
11089

11190
end

ext/DynamicExpressionsLoopVectorizationExt.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import DynamicExpressions.EvaluateModule:
1515
deg2_l0_eval,
1616
deg2_r0_eval
1717
import DynamicExpressions.ExtensionInterfaceModule:
18-
_is_loopvectorization_loaded, bumper_kern1!, bumper_kern2!
18+
_is_loopvectorization_loaded, bumper_kern!
1919

2020
_is_loopvectorization_loaded(::Int) = true
2121

@@ -208,18 +208,13 @@ function deg2_r0_eval(
208208
end
209209
end
210210

211-
## Interface with Bumper.jl
212-
function bumper_kern1!(
213-
op::F, cumulator, ::EvalOptions{true,true,early_exit}
214-
) where {F,early_exit}
215-
@turbo @. cumulator = op(cumulator)
216-
return cumulator
217-
end
218-
function bumper_kern2!(
219-
op::F, cumulator1, cumulator2, ::EvalOptions{true,true,early_exit}
220-
) where {F,early_exit}
221-
@turbo @. cumulator1 = op(cumulator1, cumulator2)
222-
return cumulator1
211+
# Interface with Bumper.jl
212+
function bumper_kern!(
213+
op::F, cumulators::Tuple{Vararg{Any,degree}}, ::EvalOptions{true,true,early_exit}
214+
) where {F,degree,early_exit}
215+
cumulator_1 = first(cumulators)
216+
@turbo @. cumulator_1 = op(cumulators...)
217+
return cumulator_1
223218
end
224219

225220
end

src/ExtensionInterface.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ end
2525
function bumper_eval_tree_array(args...)
2626
return error("Please load the Bumper.jl package to use this feature.")
2727
end
28-
function bumper_kern1! end
29-
function bumper_kern2! end
28+
function bumper_kern! end
3029

3130
_is_loopvectorization_loaded(_) = false
3231

0 commit comments

Comments
 (0)