@@ -10,6 +10,36 @@ using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_interna
10
10
# This is purely an optimisation.
11
11
Mooncake. @zero_adjoint Mooncake. DefaultCtx Tuple{typeof (istrans),Vararg}
12
12
13
+ # =======================
14
+ # Cache Strategy System
15
+ # =======================
16
+
17
+ """
18
+ determine_cache_strategy(x)
19
+
20
+ Determines the appropriate caching strategy for a given tangent.
21
+ Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk.
22
+ """
23
+ function determine_cache_strategy (x)
24
+ # Fast path: check for known circular reference patterns
25
+ has_circular_reference_risk (x) && return IdDict {Any,Bool} ()
26
+
27
+ # Check for DynamicPPL types that can safely use NoCache
28
+ is_safe_dppl_type (x) && return NoCache ()
29
+
30
+ # Special case: LogDensityFunction without problematic patterns can use NoCache
31
+ if is_dppl_ldf_tangent (x)
32
+ return NoCache ()
33
+ end
34
+
35
+ # Default to safe caching for unknown types
36
+ return IdDict {Any,Bool} ()
37
+ end
38
+
39
+ # =======================
40
+ # Type Recognition
41
+ # =======================
42
+
13
43
"""
14
44
Check if a tangent has the expected structure for a given type.
15
45
"""
@@ -68,15 +98,46 @@ function is_dppl_metadata_tangent(x)
68
98
)
69
99
end
70
100
101
+ # =======================
102
+ # Circular Reference Detection
103
+ # =======================
104
+
71
105
"""
72
- Check if a model function tangent represents a closure.
106
+ has_circular_reference_risk(x)
107
+
108
+ Main entry point for detecting circular reference patterns that require caching.
109
+ Optimized for performance with targeted checks instead of recursive traversal.
73
110
"""
74
- function is_closure_model (model_f_tangent)
75
- model_f_tangent isa MutableTangent && return true
111
+ function has_circular_reference_risk (x)
112
+ # Type-specific targeted checks only
113
+ if is_dppl_ldf_tangent (x)
114
+ # Check model function for closure patterns with circular refs
115
+ model_f = x. fields. model. fields. f
116
+ return is_closure_with_circular_refs (model_f)
117
+ elseif is_dppl_varinfo_tangent (x)
118
+ # Check for Ref fields in VarInfo
119
+ return check_for_ref_fields (x)
120
+ end
76
121
77
- if model_f_tangent isa Tangent && hasfield (typeof (model_f_tangent), :fields )
78
- # Check if any field is a MutableTangent with PossiblyUninitTangent{Any}
79
- for (_, fval) in pairs (model_f_tangent. fields)
122
+ # For unknown types, do a shallow check for PossiblyUninitTangent{Any}
123
+ return x isa Mooncake. PossiblyUninitTangent{Any}
124
+ end
125
+
126
+ """
127
+ Check if a tangent represents a closure with circular reference patterns.
128
+ Only returns true for actual problematic patterns, not all MutableTangents.
129
+ """
130
+ function is_closure_with_circular_refs (x)
131
+ # Check if MutableTangent contains PossiblyUninitTangent{Any}
132
+ if x isa MutableTangent && hasfield (typeof (x), :fields )
133
+ hasfield (typeof (x. fields), :contents ) &&
134
+ x. fields. contents isa Mooncake. PossiblyUninitTangent{Any} &&
135
+ return true
136
+ end
137
+
138
+ # For Tangent, only check immediate fields (no deep recursion)
139
+ if x isa Tangent && hasfield (typeof (x), :fields )
140
+ for (_, fval) in pairs (x. fields)
80
141
if fval isa MutableTangent &&
81
142
hasfield (typeof (fval), :fields ) &&
82
143
hasfield (typeof (fval. fields), :contents ) &&
@@ -90,9 +151,9 @@ function is_closure_model(model_f_tangent)
90
151
end
91
152
92
153
"""
93
- Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields) .
154
+ Check if a VarInfo tangent has Ref fields that need caching .
94
155
"""
95
- function needs_caching_for_varinfo (x)
156
+ function check_for_ref_fields (x)
96
157
# Check if it's a VarInfo tangent
97
158
is_dppl_varinfo_tangent (x) || return false
98
159
@@ -105,48 +166,32 @@ function needs_caching_for_varinfo(x)
105
166
end
106
167
107
168
"""
108
- Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion .
169
+ Check if a tangent is a safe DynamicPPL type that can use NoCache .
109
170
"""
110
- function contains_possibly_uninit_any (x)
111
- x isa Mooncake. PossiblyUninitTangent{Any} && return true
171
+ function is_safe_dppl_type (x)
172
+ # Metadata is always safe
173
+ is_dppl_metadata_tangent (x) && return true
112
174
113
- if x isa Tangent && hasfield ( typeof (x), :fields )
114
- for (_, fval) in pairs (x . fields )
115
- contains_possibly_uninit_any (fval ) && return true
116
- end
117
- elseif x isa MutableTangent && hasfield ( typeof (x), :fields )
118
- hasfield ( typeof (x . fields), :contents ) &&
119
- x . fields . contents isa Mooncake . PossiblyUninitTangent{Any} &&
120
- return true
175
+ # Model tangents without closures are safe
176
+ if is_dppl_model_tangent (x )
177
+ ! is_closure_with_circular_refs (x . fields . f ) && return true
178
+ end
179
+
180
+ # VarInfo without Ref fields is safe
181
+ if is_dppl_varinfo_tangent (x)
182
+ ! check_for_ref_fields (x) && return true
121
183
end
122
184
123
185
return false
124
186
end
125
187
126
- function Mooncake. set_to_zero!! (x)
127
- # Always use caching if we detect PossiblyUninitTangent{Any} anywhere
128
- if contains_possibly_uninit_any (x)
129
- return set_to_zero_internal!! (IdDict {Any,Bool} (), x)
130
- end
188
+ # =======================
189
+ # Main Entry Point
190
+ # =======================
131
191
132
- # Check for DynamicPPL types and use NoCache for better performance
133
- if is_dppl_ldf_tangent (x)
134
- # Special handling for LogDensityFunction to detect closures
135
- model_f_tangent = x. fields. model. fields. f
136
- cache = is_closure_model (model_f_tangent) ? IdDict {Any,Bool} () : NoCache ()
137
- return set_to_zero_internal!! (cache, x)
138
- elseif is_dppl_varinfo_tangent (x) && needs_caching_for_varinfo (x)
139
- # Use IdDict for SimpleVarInfo with Ref fields to avoid circular references
140
- return set_to_zero_internal!! (IdDict {Any,Bool} (), x)
141
- elseif is_dppl_varinfo_tangent (x) ||
142
- is_dppl_model_tangent (x) ||
143
- is_dppl_metadata_tangent (x)
144
- # These types can always use NoCache
145
- return set_to_zero_internal!! (NoCache (), x)
146
- else
147
- # Use the original implementation with IdDict for all other types
148
- return set_to_zero_internal!! (IdDict {Any,Bool} (), x)
149
- end
192
+ function Mooncake. set_to_zero!! (x)
193
+ cache = determine_cache_strategy (x)
194
+ return set_to_zero_internal!! (cache, x)
150
195
end
151
196
152
197
end # module
0 commit comments