Skip to content

Commit 30506e1

Browse files
Fix Enzyme extension: use targeted inactive rules for operator types
Changed from marking all Function types as inactive (too broad, caused segfaults in downstream packages) to only marking SciMLOperator abstract types as inactive. This prevents interference with other packages using Enzyme while still allowing gradients to flow through operator operations.
1 parent aeb863f commit 30506e1

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

ext/SciMLOperatorsEnzymeExt.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ function Enzyme.EnzymeRules.inactive_type(::Type{SciMLOperators.NoKwargFilter})
2121
return true
2222
end
2323

24-
# The key insight: Function-typed fields in operators are code (update functions),
25-
# not differentiable data. Tell Enzyme to treat them as inactive.
26-
# This prevents Enzyme from trying to differentiate through closure captures.
27-
function Enzyme.EnzymeRules.inactive_type(::Type{F}) where {F <: Function}
28-
return true
29-
end
24+
# For operator types with function fields, we need to tell Enzyme that the operators
25+
# themselves are inactive during forward/reverse passes - the differentiation happens
26+
# through the mathematical operations (mul!, ldiv!, etc.) not through the operator structures.
27+
# The function fields (update_func) are just code that computes coefficients.
28+
29+
# Mark specific scalar and matrix operator types that have function fields as inactive
30+
Enzyme.EnzymeRules.inactive_type(::Type{<:SciMLOperators.AbstractSciMLScalarOperator}) = true
31+
Enzyme.EnzymeRules.inactive_type(::Type{<:SciMLOperators.AbstractSciMLOperator}) = true
3032

3133
# Note: The actual differentiation will happen through the mathematical operations
32-
# (mul!, *, +, etc.) which Enzyme knows how to handle natively. The operator
33-
# structures just orchestrate these operations.
34+
# (mul!, *, +, etc.) which Enzyme knows how to handle natively.
3435

3536
end # module

0 commit comments

Comments
 (0)