11module EvaluationHelpersModule
22
3+ using ChainRulesCore: @non_differentiable
4+
35import Base: adjoint
46import .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
57import .. NodeModule: AbstractExpressionNode
68import .. EvaluateModule: eval_tree_array
79import .. EvaluateDerivativeModule: eval_grad_tree_array
810
11+ function _set_nan! (out)
12+ out .= convert (eltype (out), NaN )
13+ return nothing
14+ end
15+ @non_differentiable _set_nan! (out)
16+
917# Evaluation:
1018"""
1119 (tree::AbstractExpressionNode)(X, operators::OperatorEnum; kws...)
@@ -27,7 +35,7 @@ and triplets of operations for lower memory usage.
2735"""
2836function (tree:: AbstractExpressionNode )(X, operators:: OperatorEnum ; kws... )
2937 out, did_finish = eval_tree_array (tree, X, operators; kws... )
30- ! did_finish && (out . = convert ( eltype (out), NaN ) )
38+ ! did_finish && _set_nan! (out)
3139 return out
3240end
3341"""
@@ -56,7 +64,7 @@ function _grad_evaluator(
5664 tree:: AbstractExpressionNode , X, operators:: OperatorEnum ; variable= Val (true ), kws...
5765)
5866 _, grad, did_complete = eval_grad_tree_array (tree, X, operators; variable, kws... )
59- ! did_complete && (grad . = convert ( eltype (grad), NaN ) )
67+ ! did_complete && _set_nan! (grad)
6068 return grad
6169end
6270function _grad_evaluator (
0 commit comments