@@ -218,6 +218,10 @@ function eval_tree_array(
218218 " Bumper and LoopVectorization features are only compatible with numeric element types" ,
219219 )
220220 end
221+ if any_special_operators (typeof (operators))
222+ cX = copy (cX)
223+ # TODO : This is dangerous if the element type is mutable
224+ end
221225 if _eval_options. bumper isa Val{true }
222226 return bumper_eval_tree_array (tree, cX, operators, _eval_options)
223227 end
@@ -329,22 +333,26 @@ end
329333 long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
330334 if long_compilation_time
331335 return quote
336+ op = operators. binops[op_idx]
337+ special_operator (op) && return deg2_eval_special (tree, cX, op, eval_options)
332338 result_l = _eval_tree_array (tree. l, cX, operators, eval_options)
333339 ! result_l. ok && return result_l
334340 @return_on_nonfinite_array (eval_options, result_l. x)
335341 result_r = _eval_tree_array (tree. r, cX, operators, eval_options)
336342 ! result_r. ok && return result_r
337343 @return_on_nonfinite_array (eval_options, result_r. x)
338344 # op(x, y), for any x or y
339- deg2_eval (result_l. x, result_r. x, operators . binops[op_idx] , eval_options)
345+ deg2_eval (result_l. x, result_r. x, op , eval_options)
340346 end
341347 end
342348 return quote
343349 return Base. Cartesian. @nif (
344350 $ nbin,
345351 i -> i == op_idx,
346352 i -> let op = operators. binops[i]
347- if tree. l. degree == 0 && tree. r. degree == 0
353+ if special_operator (op)
354+ deg2_eval_special (tree, cX, op, eval_options)
355+ elseif tree. l. degree == 0 && tree. r. degree == 0
348356 deg2_l0_r0_eval (tree, cX, op, eval_options)
349357 elseif tree. r. degree == 0
350358 result_l = _eval_tree_array (tree. l, cX, operators, eval_options)
@@ -380,13 +388,16 @@ end
380388 eval_options:: EvalOptions ,
381389) where {T}
382390 nuna = get_nuna (operators)
391+ special_operators = any_special_operators (operators)
383392 long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
384393 if long_compilation_time
385394 return quote
395+ op = operators. unaops[op_idx]
396+ special_operator (op) && return deg1_eval_special (tree, cX, op, eval_options)
386397 result = _eval_tree_array (tree. l, cX, operators, eval_options)
387398 ! result. ok && return result
388399 @return_on_nonfinite_array (eval_options, result. x)
389- deg1_eval (result. x, operators . unaops[op_idx] , eval_options)
400+ deg1_eval (result. x, op , eval_options)
390401 end
391402 end
392403 # This @nif lets us generate an if statement over choice of operator,
@@ -396,13 +407,18 @@ end
396407 $ nuna,
397408 i -> i == op_idx,
398409 i -> let op = operators. unaops[i]
399- if tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
410+ if special_operator (op)
411+ deg1_eval_special (tree, cX, op, eval_options)
412+ elseif ! special_operators &&
413+ tree. l. degree == 2 &&
414+ tree. l. l. degree == 0 &&
415+ tree. l. r. degree == 0
400416 # op(op2(x, y)), where x, y, z are constants or variables.
401417 l_op_idx = tree. l. op
402418 dispatch_deg1_l2_ll0_lr0_eval (
403419 tree, cX, op, l_op_idx, operators. binops, eval_options
404420 )
405- elseif tree. l. degree == 1 && tree. l. l. degree == 0
421+ elseif ! special_operators && tree. l. degree == 1 && tree. l. l. degree == 0
406422 # op(op2(x)), where x is a constant or variable.
407423 l_op_idx = tree. l. op
408424 dispatch_deg1_l1_ll0_eval (
925941 end
926942end
927943
944+ # Overloaded by SpecialOperators.jl:
945+ function any_special_operators end
946+ function special_operator end
947+ function deg2_eval_special end
948+ function deg1_eval_special end
949+
928950end
0 commit comments