Skip to content

Commit ef8a51a

Browse files
committed
feat: create EvalOptions(use_fused=false) to avoid fused operations
1 parent e92645d commit ef8a51a

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/Evaluate.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,34 +93,43 @@ This holds options for expression evaluation, such as evaluation backend.
9393
- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
9494
This should be an instance of `ArrayBuffer` which has an `array` field and an
9595
`index` field used to iterate which buffer slot to use.
96+
- `use_fused::Val{U}=Val(true)`: If `Val(true)` (or `true`), use fused kernels for faster
97+
evaluation. Setting this to `Val(false)` (or `false`) will skip the fused kernels, meaning that
98+
you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
99+
evaluation.
96100
"""
97-
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing}}
101+
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U}
98102
# Each flag is held as a `Val`, so its value is carried by a type parameter.
turbo::Val{T}
99103
bumper::Val{B}
100104
early_exit::Val{E}
101105
# Either an `ArrayBuffer` or `nothing`, as constrained by the bound on `BUF`.
buffer::BUF
106+
# Fused-kernel switch, typed by the additional parameter `U`.
use_fused::Val{U}
102107
end
103108

104109
# Keyword constructor for `EvalOptions`. `turbo`, `bumper`, `early_exit` and
# `use_fused` each accept either a `Bool` or a `Val`; every one of them is passed
# through `_to_bool_val` before being stored. `buffer` is stored as given.
@unstable function EvalOptions(;
105110
turbo::Union{Bool,Val}=Val(false),
106111
bumper::Union{Bool,Val}=Val(false),
107112
early_exit::Union{Bool,Val}=Val(true),
108113
buffer::Union{ArrayBuffer,Nothing}=nothing,
114+
use_fused::Union{Bool,Val}=Val(true),
109115
)
110116
v_turbo = _to_bool_val(turbo)
111117
v_bumper = _to_bool_val(bumper)
112118
v_early_exit = _to_bool_val(early_exit)
119+
v_use_fused = _to_bool_val(use_fused)
113120

114121
# A buffer may not be supplied together with `bumper=true`: the assertion
# below requires `buffer === nothing` in that case.
if v_bumper isa Val{true}
115122
@assert buffer === nothing
116123
end
117124

118-
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer)
125+
return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused)
119126
end
120127

121128
# Lift a runtime `Bool` into the type domain: `true` becomes `Val(true)` and
# `false` becomes `Val(false)`.
@unstable @inline function _to_bool_val(x::Bool)
    if x
        return Val(true)
    else
        return Val(false)
    end
end
122129
# Normalise an existing `Val`: hand back `Val(T)`, with a type assertion that
# the parameter `T` is a `Bool`.
@inline function _to_bool_val(::Val{T}) where {T}
    flag = T::Bool
    return Val(flag)
end
123130

131+
# Report whether fused kernels are enabled: `true` exactly when the `use_fused`
# field of `eval_options` is a `Val{true}`.
@inline function use_fused(eval_options::EvalOptions)
    return eval_options.use_fused isa Val{true}
end
132+
124133
# Generic fallback: defer to `copy`.
function _copy(x)
    return copy(x)
end
125134
# `nothing` has no contents to copy; return it unchanged.
function _copy(::Nothing)
    return nothing
end
126135
function Base.copy(eval_options::EvalOptions)
@@ -129,6 +138,7 @@ function Base.copy(eval_options::EvalOptions)
129138
bumper=eval_options.bumper,
130139
early_exit=eval_options.early_exit,
131140
buffer=_copy(eval_options.buffer),
141+
use_fused=eval_options.use_fused,
132142
)
133143
end
134144

@@ -340,19 +350,20 @@ end
340350
end
341351
end
342352
return quote
353+
fused = use_fused(eval_options)
343354
return Base.Cartesian.@nif(
344355
$nbin,
345356
i -> i == op_idx,
346357
i -> let op = operators.binops[i]
347-
if tree.l.degree == 0 && tree.r.degree == 0
358+
if fused && tree.l.degree == 0 && tree.r.degree == 0
348359
deg2_l0_r0_eval(tree, cX, op, eval_options)
349-
elseif tree.r.degree == 0
360+
elseif fused && tree.r.degree == 0
350361
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
351362
!result_l.ok && return result_l
352363
@return_on_nonfinite_array(eval_options, result_l.x)
353364
# op(x, y), where y is a constant or variable but x is not.
354365
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
355-
elseif tree.l.degree == 0
366+
elseif fused && tree.l.degree == 0
356367
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
357368
!result_r.ok && return result_r
358369
@return_on_nonfinite_array(eval_options, result_r.x)
@@ -392,17 +403,18 @@ end
392403
# This @nif lets us generate an if statement over choice of operator,
393404
# which means the compiler will be able to completely avoid type inference on operators.
394405
return quote
406+
fused = use_fused(eval_options)
395407
Base.Cartesian.@nif(
396408
$nuna,
397409
i -> i == op_idx,
398410
i -> let op = operators.unaops[i]
399-
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
411+
if fused && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
400412
# op(op2(x, y)), where x, y are constants or variables.
401413
l_op_idx = tree.l.op
402414
dispatch_deg1_l2_ll0_lr0_eval(
403415
tree, cX, op, l_op_idx, operators.binops, eval_options
404416
)
405-
elseif tree.l.degree == 1 && tree.l.l.degree == 0
417+
elseif fused && tree.l.degree == 1 && tree.l.l.degree == 0
406418
# op(op2(x)), where x is a constant or variable.
407419
l_op_idx = tree.l.op
408420
dispatch_deg1_l1_ll0_eval(

0 commit comments

Comments
 (0)