From 0e662423688babda273c13b88f207040481a35ed Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Jun 2025 21:18:49 +0100 Subject: [PATCH] fix: move non_differentiable to special module for JET masking --- src/EvaluationHelpers.jl | 2 +- src/NonDifferentiableDeclarations.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/EvaluationHelpers.jl b/src/EvaluationHelpers.jl index a302e533..131a3f63 100644 --- a/src/EvaluationHelpers.jl +++ b/src/EvaluationHelpers.jl @@ -8,11 +8,11 @@ import ..NodeModule: AbstractExpressionNode import ..EvaluateModule: eval_tree_array import ..EvaluateDerivativeModule: eval_grad_tree_array +# Needs to be special function so we can declare it non-differentiable to Zygote function _set_nan!(out) out .= convert(eltype(out), NaN) return nothing end -@non_differentiable _set_nan!(out) # Evaluation: """ diff --git a/src/NonDifferentiableDeclarations.jl b/src/NonDifferentiableDeclarations.jl index 0e523dab..cf4a2b73 100644 --- a/src/NonDifferentiableDeclarations.jl +++ b/src/NonDifferentiableDeclarations.jl @@ -6,6 +6,7 @@ import ..NodeModule: AbstractExpressionNode, AbstractNode import ..NodeUtilsModule: tree_mapreduce import ..ExpressionModule: AbstractExpression, get_operators, get_variable_names, _validate_input +import ..EvaluationHelpersModule: _set_nan! #! format: off @non_differentiable tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type) @@ -13,6 +14,7 @@ import ..ExpressionModule: @non_differentiable get_operators(ex::Union{AbstractExpression,AbstractExpressionNode}, operators::Union{AbstractOperatorEnum,Nothing}) @non_differentiable get_variable_names(ex::AbstractExpression, variable_names::Union{AbstractVector{<:AbstractString},Nothing}) @non_differentiable _validate_input(ex::AbstractExpression, X, operators::Union{AbstractOperatorEnum,Nothing}) +@non_differentiable _set_nan!(::Any) #! format: on end