@@ -89,13 +89,55 @@ function is_closure_model(model_f_tangent)
89
89
return false
90
90
end
91
91
92
+ """
93
+ Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields).
94
+ """
95
+ function needs_caching_for_varinfo (x)
96
+ # Check if it's a VarInfo tangent
97
+ is_dppl_varinfo_tangent (x) || return false
98
+
99
+ # Check if the logp field contains a Ref-like tangent structure
100
+ hasfield (typeof (x. fields), :logp ) || return false
101
+ logp_tangent = x. fields. logp
102
+
103
+ # Ref types in tangents often appear as MutableTangent with circular references
104
+ return logp_tangent isa MutableTangent
105
+ end
106
+
107
+ """
108
+ Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion.
109
+ """
110
+ function contains_possibly_uninit_any (x)
111
+ x isa Mooncake. PossiblyUninitTangent{Any} && return true
112
+
113
+ if x isa Tangent && hasfield (typeof (x), :fields )
114
+ for (_, fval) in pairs (x. fields)
115
+ contains_possibly_uninit_any (fval) && return true
116
+ end
117
+ elseif x isa MutableTangent && hasfield (typeof (x), :fields )
118
+ hasfield (typeof (x. fields), :contents ) &&
119
+ x. fields. contents isa Mooncake. PossiblyUninitTangent{Any} &&
120
+ return true
121
+ end
122
+
123
+ return false
124
+ end
125
+
92
126
function Mooncake. set_to_zero!! (x)
127
+ # Always use caching if we detect PossiblyUninitTangent{Any} anywhere
128
+ if contains_possibly_uninit_any (x)
129
+ return set_to_zero_internal!! (IdDict {Any,Bool} (), x)
130
+ end
131
+
93
132
# Check for DynamicPPL types and use NoCache for better performance
94
133
if is_dppl_ldf_tangent (x)
95
134
# Special handling for LogDensityFunction to detect closures
96
135
model_f_tangent = x. fields. model. fields. f
97
136
cache = is_closure_model (model_f_tangent) ? IdDict {Any,Bool} () : NoCache ()
98
137
return set_to_zero_internal!! (cache, x)
138
+ elseif is_dppl_varinfo_tangent (x) && needs_caching_for_varinfo (x)
139
+ # Use IdDict for SimpleVarInfo with Ref fields to avoid circular references
140
+ return set_to_zero_internal!! (IdDict {Any,Bool} (), x)
99
141
elseif is_dppl_varinfo_tangent (x) ||
100
142
is_dppl_model_tangent (x) ||
101
143
is_dppl_metadata_tangent (x)
0 commit comments