Skip to content

Commit aabc844

Browse files
committed
refactor
1 parent 66f453c commit aabc844

File tree

1 file changed

+87
-42
lines changed

1 file changed

+87
-42
lines changed

ext/DynamicPPLMooncakeExt.jl

Lines changed: 87 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,36 @@ using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_interna
1010
# This is purely an optimisation.
1111
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
1212

13+
# =======================
14+
# Cache Strategy System
15+
# =======================
16+
17+
"""
18+
determine_cache_strategy(x)
19+
20+
Determines the appropriate caching strategy for a given tangent.
21+
Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk.
22+
"""
23+
function determine_cache_strategy(x)
24+
# Fast path: check for known circular reference patterns
25+
has_circular_reference_risk(x) && return IdDict{Any,Bool}()
26+
27+
# Check for DynamicPPL types that can safely use NoCache
28+
is_safe_dppl_type(x) && return NoCache()
29+
30+
# Special case: LogDensityFunction without problematic patterns can use NoCache
31+
if is_dppl_ldf_tangent(x)
32+
return NoCache()
33+
end
34+
35+
# Default to safe caching for unknown types
36+
return IdDict{Any,Bool}()
37+
end
38+
39+
# =======================
40+
# Type Recognition
41+
# =======================
42+
1343
"""
1444
Check if a tangent has the expected structure for a given type.
1545
"""
@@ -68,15 +98,46 @@ function is_dppl_metadata_tangent(x)
6898
)
6999
end
70100

101+
# =======================
102+
# Circular Reference Detection
103+
# =======================
104+
71105
"""
72-
Check if a model function tangent represents a closure.
106+
has_circular_reference_risk(x)
107+
108+
Main entry point for detecting circular reference patterns that require caching.
109+
Optimized for performance with targeted checks instead of recursive traversal.
73110
"""
74-
function is_closure_model(model_f_tangent)
75-
model_f_tangent isa MutableTangent && return true
111+
function has_circular_reference_risk(x)
112+
# Type-specific targeted checks only
113+
if is_dppl_ldf_tangent(x)
114+
# Check model function for closure patterns with circular refs
115+
model_f = x.fields.model.fields.f
116+
return is_closure_with_circular_refs(model_f)
117+
elseif is_dppl_varinfo_tangent(x)
118+
# Check for Ref fields in VarInfo
119+
return check_for_ref_fields(x)
120+
end
76121

77-
if model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields)
78-
# Check if any field is a MutableTangent with PossiblyUninitTangent{Any}
79-
for (_, fval) in pairs(model_f_tangent.fields)
122+
# For unknown types, do a shallow check for PossiblyUninitTangent{Any}
123+
return x isa Mooncake.PossiblyUninitTangent{Any}
124+
end
125+
126+
"""
127+
Check if a tangent represents a closure with circular reference patterns.
128+
Only returns true for actual problematic patterns, not all MutableTangents.
129+
"""
130+
function is_closure_with_circular_refs(x)
131+
# Check if MutableTangent contains PossiblyUninitTangent{Any}
132+
if x isa MutableTangent && hasfield(typeof(x), :fields)
133+
hasfield(typeof(x.fields), :contents) &&
134+
x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} &&
135+
return true
136+
end
137+
138+
# For Tangent, only check immediate fields (no deep recursion)
139+
if x isa Tangent && hasfield(typeof(x), :fields)
140+
for (_, fval) in pairs(x.fields)
80141
if fval isa MutableTangent &&
81142
hasfield(typeof(fval), :fields) &&
82143
hasfield(typeof(fval.fields), :contents) &&
@@ -90,9 +151,9 @@ function is_closure_model(model_f_tangent)
90151
end
91152

92153
"""
93-
Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields).
154+
Check if a VarInfo tangent has Ref fields that need caching.
94155
"""
95-
function needs_caching_for_varinfo(x)
156+
function check_for_ref_fields(x)
96157
# Check if it's a VarInfo tangent
97158
is_dppl_varinfo_tangent(x) || return false
98159

@@ -105,48 +166,32 @@ function needs_caching_for_varinfo(x)
105166
end
106167

107168
"""
108-
Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion.
169+
Check if a tangent is a safe DynamicPPL type that can use NoCache.
109170
"""
110-
function contains_possibly_uninit_any(x)
111-
x isa Mooncake.PossiblyUninitTangent{Any} && return true
171+
function is_safe_dppl_type(x)
172+
# Metadata is always safe
173+
is_dppl_metadata_tangent(x) && return true
112174

113-
if x isa Tangent && hasfield(typeof(x), :fields)
114-
for (_, fval) in pairs(x.fields)
115-
contains_possibly_uninit_any(fval) && return true
116-
end
117-
elseif x isa MutableTangent && hasfield(typeof(x), :fields)
118-
hasfield(typeof(x.fields), :contents) &&
119-
x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} &&
120-
return true
175+
# Model tangents without closures are safe
176+
if is_dppl_model_tangent(x)
177+
!is_closure_with_circular_refs(x.fields.f) && return true
178+
end
179+
180+
# VarInfo without Ref fields is safe
181+
if is_dppl_varinfo_tangent(x)
182+
!check_for_ref_fields(x) && return true
121183
end
122184

123185
return false
124186
end
125187

126-
function Mooncake.set_to_zero!!(x)
127-
# Always use caching if we detect PossiblyUninitTangent{Any} anywhere
128-
if contains_possibly_uninit_any(x)
129-
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
130-
end
188+
# =======================
189+
# Main Entry Point
190+
# =======================
131191

132-
# Check for DynamicPPL types and use NoCache for better performance
133-
if is_dppl_ldf_tangent(x)
134-
# Special handling for LogDensityFunction to detect closures
135-
model_f_tangent = x.fields.model.fields.f
136-
cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache()
137-
return set_to_zero_internal!!(cache, x)
138-
elseif is_dppl_varinfo_tangent(x) && needs_caching_for_varinfo(x)
139-
# Use IdDict for SimpleVarInfo with Ref fields to avoid circular references
140-
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
141-
elseif is_dppl_varinfo_tangent(x) ||
142-
is_dppl_model_tangent(x) ||
143-
is_dppl_metadata_tangent(x)
144-
# These types can always use NoCache
145-
return set_to_zero_internal!!(NoCache(), x)
146-
else
147-
# Use the original implementation with IdDict for all other types
148-
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
149-
end
192+
function Mooncake.set_to_zero!!(x)
193+
cache = determine_cache_strategy(x)
194+
return set_to_zero_internal!!(cache, x)
150195
end
151196

152197
end # module

0 commit comments

Comments
 (0)