|
1 | 1 | module DynamicPPLMooncakeExt
|
2 | 2 |
|
3 |
| -__precompile__(false) |
4 |
| - |
5 | 3 | using DynamicPPL: DynamicPPL, istrans
|
6 | 4 | using Mooncake: Mooncake
|
7 |
| -import Mooncake: set_to_zero!! |
8 |
| -using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!! |
9 | 5 |
|
10 | 6 | # This is purely an optimisation.
|
11 | 7 | Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
|
12 | 8 |
|
13 |
| -# ======================= |
14 |
| -# `Mooncake.set_to_zero!!` optimization with `NoCache` |
15 |
| -# ======================= |
16 |
| - |
17 |
| -""" |
18 |
| -Check if a tangent has the expected structure for a given type. |
19 |
| -""" |
20 |
| -function has_expected_structure( |
21 |
| - x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields |
22 |
| -) |
23 |
| - x isa expected_type || return false |
24 |
| - hasfield(typeof(x), :fields) || return false |
| 9 | +@static if isdefined(Mooncake, :requires_cache) |
| 10 | + import Mooncake: requires_cache |
25 | 11 |
|
26 |
| - fields = x.fields |
27 |
| - if expected_fields isa Tuple |
28 |
| - # Exact match required |
29 |
| - propertynames(fields) == expected_fields || return false |
30 |
| - else |
31 |
| - # All expected fields must be present |
32 |
| - all(f in propertynames(fields) for f in expected_fields) || return false |
| 12 | + function Mooncake.requires_cache(::Type{<:DynamicPPL.Metadata}) |
| 13 | + return Val(false) |
33 | 14 | end
|
34 |
| - |
35 |
| - return true |
36 |
| -end |
37 |
| - |
38 |
| -function is_dppl_ldf_tangent(x) |
39 |
| - has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) || |
40 |
| - return false |
41 |
| - |
42 |
| - fields = x.fields |
43 |
| - is_dppl_varinfo_tangent(fields.varinfo) || return false |
44 |
| - is_dppl_model_tangent(fields.model) || return false |
45 |
| - |
46 |
| - return true |
47 |
| -end |
48 |
| - |
49 |
| -function is_dppl_varinfo_tangent(x) |
50 |
| - return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce)) |
51 |
| -end |
52 |
| - |
53 |
| -function is_dppl_model_tangent(x) |
54 |
| - return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context)) |
55 |
| -end |
56 |
| - |
57 |
| -function is_dppl_metadata_tangent(x) |
58 |
| - # Metadata can be either: |
59 |
| - # 1. A MutableTangent with the expected fields (for single metadata) |
60 |
| - # 2. A NamedTuple where each value is a Tangent with the expected fields |
61 |
| - |
62 |
| - # Check for MutableTangent case |
63 |
| - if has_expected_structure( |
64 |
| - x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) |
65 |
| - ) |
66 |
| - return true |
| 15 | + |
| 16 | + function Mooncake.requires_cache(::Type{<:DynamicPPL.TypedVarInfo}) |
| 17 | + return Val(false) |
67 | 18 | end
|
68 |
| - |
69 |
| - # Check for NamedTuple case (multiple metadata) |
70 |
| - if x isa NamedTuple |
71 |
| - # Each value should be a Tangent with metadata fields |
72 |
| - for var_metadata in values(x) |
73 |
| - if !has_expected_structure( |
74 |
| - var_metadata, |
75 |
| - Tangent, |
76 |
| - (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags), |
77 |
| - ) |
78 |
| - return false |
79 |
| - end |
80 |
| - end |
81 |
| - return true |
| 19 | + |
| 20 | + function Mooncake.requires_cache(::Type{<:DynamicPPL.Model}) |
| 21 | + # Model has f (function/closure), args, defaults, context |
| 22 | + # Closures can have circular references |
| 23 | + return Val(false) |
82 | 24 | end
|
83 |
| - |
84 |
| - return false |
85 |
| -end |
86 |
| - |
87 |
| -""" |
88 |
| - has_circular_reference_risk(x) |
89 |
| -
|
90 |
| -Main entry point for detecting circular reference patterns that require caching. |
91 |
| -""" |
92 |
| -function has_circular_reference_risk(x) |
93 |
| - if is_dppl_ldf_tangent(x) |
94 |
| - # Check model function for closure patterns with circular refs |
95 |
| - model_f = x.fields.model.fields.f |
96 |
| - return is_closure_with_circular_refs(model_f) |
97 |
| - elseif is_dppl_varinfo_tangent(x) |
98 |
| - return check_for_ref_fields(x) |
| 25 | + |
| 26 | + function Mooncake.requires_cache(::Type{<:DynamicPPL.LogDensityFunction}) |
| 27 | + return Val(false) |
99 | 28 | end
|
100 |
| - |
101 |
| - # For unknown types, do a shallow check for PossiblyUninitTangent{Any} |
102 |
| - return x isa Mooncake.PossiblyUninitTangent{Any} |
103 |
| -end |
104 |
| - |
105 |
| -function is_closure_with_circular_refs(x) |
106 |
| - # Check if MutableTangent contains PossiblyUninitTangent{Any} |
107 |
| - if x isa MutableTangent && hasfield(typeof(x), :fields) |
108 |
| - hasfield(typeof(x.fields), :contents) && |
109 |
| - x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} && |
110 |
| - return true |
| 29 | + |
| 30 | + function Mooncake.requires_cache(::Type{<:DynamicPPL.AbstractContext}) |
| 31 | + return Val(false) |
111 | 32 | end
|
112 |
| - |
113 |
| - # For Tangent, only check immediate fields (no deep recursion) |
114 |
| - if x isa Tangent && hasfield(typeof(x), :fields) |
115 |
| - for (_, fval) in pairs(x.fields) |
116 |
| - if fval isa MutableTangent && |
117 |
| - hasfield(typeof(fval), :fields) && |
118 |
| - hasfield(typeof(fval.fields), :contents) && |
119 |
| - fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any} |
120 |
| - return true |
121 |
| - end |
122 |
| - end |
123 |
| - end |
124 |
| - |
125 |
| - return false |
126 |
| -end |
127 |
| - |
128 |
| -function check_for_ref_fields(x) |
129 |
| - # Check if it's a VarInfo tangent |
130 |
| - is_dppl_varinfo_tangent(x) || return false |
131 |
| - |
132 |
| - # Check if the logp field contains a Ref-like tangent structure |
133 |
| - hasfield(typeof(x.fields), :logp) || return false |
134 |
| - logp_tangent = x.fields.logp |
135 |
| - |
136 |
| - # Ref types in tangents often appear as MutableTangent with circular references |
137 |
| - return logp_tangent isa MutableTangent |
138 |
| -end |
139 |
| - |
140 |
| -function is_safe_dppl_type(x) |
141 |
| - # Metadata is always safe |
142 |
| - is_dppl_metadata_tangent(x) && return true |
143 |
| - |
144 |
| - # Model tangents without closures are safe |
145 |
| - if is_dppl_model_tangent(x) |
146 |
| - !is_closure_with_circular_refs(x.fields.f) && return true |
147 |
| - end |
148 |
| - |
149 |
| - # VarInfo without Ref fields is safe |
150 |
| - if is_dppl_varinfo_tangent(x) |
151 |
| - !check_for_ref_fields(x) && return true |
152 |
| - end |
153 |
| - |
154 |
| - return false |
155 |
| -end |
156 |
| - |
157 |
| -""" |
158 |
| - determine_cache_strategy(x) |
159 |
| -
|
160 |
| -Determines the appropriate caching strategy for a given tangent. |
161 |
| -Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk. |
162 |
| -""" |
163 |
| -function determine_cache_strategy(x) |
164 |
| - # Fast path: check for known circular reference patterns |
165 |
| - has_circular_reference_risk(x) && return IdDict{Any,Bool}() |
166 |
| - |
167 |
| - # Check for DynamicPPL types that can safely use NoCache |
168 |
| - is_safe_dppl_type(x) && return NoCache() |
169 |
| - |
170 |
| - # Special case: LogDensityFunction without problematic patterns can use NoCache |
171 |
| - if is_dppl_ldf_tangent(x) |
172 |
| - return NoCache() |
173 |
| - end |
174 |
| - |
175 |
| - # Default to safe caching for unknown types |
176 |
| - return IdDict{Any,Bool}() |
177 |
| -end |
178 |
| - |
179 |
| -function Mooncake.set_to_zero!!(x) |
180 |
| - cache = determine_cache_strategy(x) |
181 |
| - return set_to_zero_internal!!(cache, x) |
182 | 33 | end
|
183 | 34 |
|
184 | 35 | end # module
|
0 commit comments