Skip to content

Commit 66f453c

Browse files
committed
resolve CI error
1 parent 7b99643 commit 66f453c

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

ext/DynamicPPLMooncakeExt.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,55 @@ function is_closure_model(model_f_tangent)
8989
return false
9090
end
9191

92+
"""
93+
Check if a VarInfo tangent needs caching due to circular references (e.g., Ref fields).
94+
"""
95+
function needs_caching_for_varinfo(x)
96+
# Check if it's a VarInfo tangent
97+
is_dppl_varinfo_tangent(x) || return false
98+
99+
# Check if the logp field contains a Ref-like tangent structure
100+
hasfield(typeof(x.fields), :logp) || return false
101+
logp_tangent = x.fields.logp
102+
103+
# Ref types in tangents often appear as MutableTangent with circular references
104+
return logp_tangent isa MutableTangent
105+
end
106+
107+
"""
108+
Check if a tangent contains PossiblyUninitTangent{Any} which can cause infinite recursion.
109+
"""
110+
function contains_possibly_uninit_any(x)
111+
x isa Mooncake.PossiblyUninitTangent{Any} && return true
112+
113+
if x isa Tangent && hasfield(typeof(x), :fields)
114+
for (_, fval) in pairs(x.fields)
115+
contains_possibly_uninit_any(fval) && return true
116+
end
117+
elseif x isa MutableTangent && hasfield(typeof(x), :fields)
118+
hasfield(typeof(x.fields), :contents) &&
119+
x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} &&
120+
return true
121+
end
122+
123+
return false
124+
end
125+
92126
function Mooncake.set_to_zero!!(x)
127+
# Always use caching if we detect PossiblyUninitTangent{Any} anywhere
128+
if contains_possibly_uninit_any(x)
129+
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
130+
end
131+
93132
# Check for DynamicPPL types and use NoCache for better performance
94133
if is_dppl_ldf_tangent(x)
95134
# Special handling for LogDensityFunction to detect closures
96135
model_f_tangent = x.fields.model.fields.f
97136
cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache()
98137
return set_to_zero_internal!!(cache, x)
138+
elseif is_dppl_varinfo_tangent(x) && needs_caching_for_varinfo(x)
139+
# Use IdDict for SimpleVarInfo with Ref fields to avoid circular references
140+
return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
99141
elseif is_dppl_varinfo_tangent(x) ||
100142
is_dppl_model_tangent(x) ||
101143
is_dppl_metadata_tangent(x)

0 commit comments

Comments
 (0)