@@ -93,34 +93,43 @@ This holds options for expression evaluation, such as evaluation backend.
9393- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
9494 This should be an instance of `ArrayBuffer` which has an `array` field and an
9595 `index` field used to iterate which buffer slot to use.
96+ - `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster
97+ evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that
98+ you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
99+ evaluation.
96100"""
struct EvalOptions{T,B,E,BUF<:Union{ArrayBuffer,Nothing},U}
    # Each boolean flag is carried as a `Val`, so its value lives in the type.
    turbo::Val{T}
    bumper::Val{B}
    early_exit::Val{E}
    # Either an `ArrayBuffer` or `nothing`.
    buffer::BUF
    # Queried through `use_fused(eval_options)` to decide whether fused kernels run.
    use_fused::Val{U}
end
103108
# Keyword constructor for `EvalOptions`.
#
# Every flag may be given either as a plain `Bool` or as a `Val`; each one is
# normalized with `_to_bool_val` before being stored. `buffer` is stored as given.
#
# Throws an `AssertionError` if `bumper` is enabled while a `buffer` is supplied.
@unstable function EvalOptions(;
    turbo::Union{Bool,Val}=Val(false),
    bumper::Union{Bool,Val}=Val(false),
    early_exit::Union{Bool,Val}=Val(true),
    buffer::Union{ArrayBuffer,Nothing}=nothing,
    use_fused::Union{Bool,Val}=Val(true),
)
    v_turbo = _to_bool_val(turbo)
    v_bumper = _to_bool_val(bumper)
    v_early_exit = _to_bool_val(early_exit)
    v_use_fused = _to_bool_val(use_fused)

    if v_bumper isa Val{true}
        # Same exception type as before; the message tells the caller what went wrong.
        @assert buffer === nothing "`buffer` must be `nothing` when `bumper` is enabled"
    end

    return EvalOptions(v_turbo, v_bumper, v_early_exit, buffer, v_use_fused)
end
120127
# Normalize a flag to a `Val`: a `Bool` is lifted to `Val(true)` / `Val(false)`.
@unstable @inline function _to_bool_val(x::Bool)
    if x
        return Val(true)
    else
        return Val(false)
    end
end

# A `Val` is passed through after asserting that its parameter is a `Bool`.
@inline function _to_bool_val(::Val{T}) where {T}
    return Val(T::Bool)
end
123130
"""
    use_fused(eval_options::EvalOptions)

Return `true` if the `use_fused` field of `eval_options` is `Val(true)`,
and `false` otherwise.
"""
@inline function use_fused(eval_options::EvalOptions)
    # `Val{true}` is a singleton type, so identity with `Val(true)` is the same
    # check as `isa Val{true}`.
    return eval_options.use_fused === Val(true)
end
132+
# `copy` that also accepts `nothing`, which is returned unchanged.
_copy(::Nothing) = nothing
function _copy(x)
    return copy(x)
end
126135function Base. copy (eval_options:: EvalOptions )
@@ -129,6 +138,7 @@ function Base.copy(eval_options::EvalOptions)
129138 bumper= eval_options. bumper,
130139 early_exit= eval_options. early_exit,
131140 buffer= _copy (eval_options. buffer),
141+ use_fused= eval_options. use_fused,
132142 )
133143end
134144
@@ -340,19 +350,20 @@ end
340350 end
341351 end
342352 return quote
353+ fused = use_fused (eval_options)
343354 return Base. Cartesian. @nif (
344355 $ nbin,
345356 i -> i == op_idx,
346357 i -> let op = operators. binops[i]
347- if tree. l. degree == 0 && tree. r. degree == 0
358+ if fused && tree. l. degree == 0 && tree. r. degree == 0
348359 deg2_l0_r0_eval (tree, cX, op, eval_options)
349- elseif tree. r. degree == 0
360+ elseif fused && tree. r. degree == 0
350361 result_l = _eval_tree_array (tree. l, cX, operators, eval_options)
351362 ! result_l. ok && return result_l
352363 @return_on_nonfinite_array (eval_options, result_l. x)
353364 # op(x, y), where y is a constant or variable but x is not.
354365 deg2_r0_eval (tree, result_l. x, cX, op, eval_options)
355- elseif tree. l. degree == 0
366+ elseif fused && tree. l. degree == 0
356367 result_r = _eval_tree_array (tree. r, cX, operators, eval_options)
357368 ! result_r. ok && return result_r
358369 @return_on_nonfinite_array (eval_options, result_r. x)
@@ -392,17 +403,18 @@ end
392403 # This @nif lets us generate an if statement over choice of operator,
393404 # which means the compiler will be able to completely avoid type inference on operators.
394405 return quote
406+ fused = use_fused (eval_options)
395407 Base. Cartesian. @nif (
396408 $ nuna,
397409 i -> i == op_idx,
398410 i -> let op = operators. unaops[i]
399- if tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
411+ if fused && tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
400412 # op(op2(x, y)), where x and y are constants or variables.
401413 l_op_idx = tree. l. op
402414 dispatch_deg1_l2_ll0_lr0_eval (
403415 tree, cX, op, l_op_idx, operators. binops, eval_options
404416 )
405- elseif tree. l. degree == 1 && tree. l. l. degree == 0
417+ elseif fused && tree. l. degree == 1 && tree. l. l. degree == 0
406418 # op(op2(x)), where x is a constant or variable.
407419 l_op_idx = tree. l. op
408420 dispatch_deg1_l1_ll0_eval (
0 commit comments