Skip to content

Commit 76622ae

Browse files
committed
use Mooncake.requires_cache function
1 parent de57edd commit 76622ae

File tree

2 files changed

+18
-250
lines changed

2 files changed

+18
-250
lines changed

ext/DynamicPPLMooncakeExt.jl

Lines changed: 18 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,184 +1,35 @@
11
module DynamicPPLMooncakeExt
22

3-
__precompile__(false)
4-
53
using DynamicPPL: DynamicPPL, istrans
64
using Mooncake: Mooncake
7-
import Mooncake: set_to_zero!!
8-
using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!!
95

106
# This is purely an optimisation.
117
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
128

13-
# =======================
14-
# `Mooncake.set_to_zero!!` optimization with `NoCache`
15-
# =======================
16-
17-
"""
18-
Check if a tangent has the expected structure for a given type.
19-
"""
20-
function has_expected_structure(
21-
x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields
22-
)
23-
x isa expected_type || return false
24-
hasfield(typeof(x), :fields) || return false
9+
@static if isdefined(Mooncake, :requires_cache)
10+
import Mooncake: requires_cache
2511

26-
fields = x.fields
27-
if expected_fields isa Tuple
28-
# Exact match required
29-
propertynames(fields) == expected_fields || return false
30-
else
31-
# All expected fields must be present
32-
all(f in propertynames(fields) for f in expected_fields) || return false
12+
function Mooncake.requires_cache(::Type{<:DynamicPPL.Metadata})
13+
return Val(false)
3314
end
34-
35-
return true
36-
end
37-
38-
function is_dppl_ldf_tangent(x)
39-
has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) ||
40-
return false
41-
42-
fields = x.fields
43-
is_dppl_varinfo_tangent(fields.varinfo) || return false
44-
is_dppl_model_tangent(fields.model) || return false
45-
46-
return true
47-
end
48-
49-
function is_dppl_varinfo_tangent(x)
50-
return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce))
51-
end
52-
53-
function is_dppl_model_tangent(x)
54-
return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context))
55-
end
56-
57-
function is_dppl_metadata_tangent(x)
58-
# Metadata can be either:
59-
# 1. A MutableTangent with the expected fields (for single metadata)
60-
# 2. A NamedTuple where each value is a Tangent with the expected fields
61-
62-
# Check for MutableTangent case
63-
if has_expected_structure(
64-
x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
65-
)
66-
return true
15+
16+
function Mooncake.requires_cache(::Type{<:DynamicPPL.TypedVarInfo})
17+
return Val(false)
6718
end
68-
69-
# Check for NamedTuple case (multiple metadata)
70-
if x isa NamedTuple
71-
# Each value should be a Tangent with metadata fields
72-
for var_metadata in values(x)
73-
if !has_expected_structure(
74-
var_metadata,
75-
Tangent,
76-
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags),
77-
)
78-
return false
79-
end
80-
end
81-
return true
19+
20+
function Mooncake.requires_cache(::Type{<:DynamicPPL.Model})
21+
# Model has f (function/closure), args, defaults, context
22+
# Closures can have circular references
23+
return Val(false)
8224
end
83-
84-
return false
85-
end
86-
87-
"""
88-
has_circular_reference_risk(x)
89-
90-
Main entry point for detecting circular reference patterns that require caching.
91-
"""
92-
function has_circular_reference_risk(x)
93-
if is_dppl_ldf_tangent(x)
94-
# Check model function for closure patterns with circular refs
95-
model_f = x.fields.model.fields.f
96-
return is_closure_with_circular_refs(model_f)
97-
elseif is_dppl_varinfo_tangent(x)
98-
return check_for_ref_fields(x)
25+
26+
function Mooncake.requires_cache(::Type{<:DynamicPPL.LogDensityFunction})
27+
return Val(false)
9928
end
100-
101-
# For unknown types, do a shallow check for PossiblyUninitTangent{Any}
102-
return x isa Mooncake.PossiblyUninitTangent{Any}
103-
end
104-
105-
function is_closure_with_circular_refs(x)
106-
# Check if MutableTangent contains PossiblyUninitTangent{Any}
107-
if x isa MutableTangent && hasfield(typeof(x), :fields)
108-
hasfield(typeof(x.fields), :contents) &&
109-
x.fields.contents isa Mooncake.PossiblyUninitTangent{Any} &&
110-
return true
29+
30+
function Mooncake.requires_cache(::Type{<:DynamicPPL.AbstractContext})
31+
return Val(false)
11132
end
112-
113-
# For Tangent, only check immediate fields (no deep recursion)
114-
if x isa Tangent && hasfield(typeof(x), :fields)
115-
for (_, fval) in pairs(x.fields)
116-
if fval isa MutableTangent &&
117-
hasfield(typeof(fval), :fields) &&
118-
hasfield(typeof(fval.fields), :contents) &&
119-
fval.fields.contents isa Mooncake.PossiblyUninitTangent{Any}
120-
return true
121-
end
122-
end
123-
end
124-
125-
return false
126-
end
127-
128-
function check_for_ref_fields(x)
129-
# Check if it's a VarInfo tangent
130-
is_dppl_varinfo_tangent(x) || return false
131-
132-
# Check if the logp field contains a Ref-like tangent structure
133-
hasfield(typeof(x.fields), :logp) || return false
134-
logp_tangent = x.fields.logp
135-
136-
# Ref types in tangents often appear as MutableTangent with circular references
137-
return logp_tangent isa MutableTangent
138-
end
139-
140-
function is_safe_dppl_type(x)
141-
# Metadata is always safe
142-
is_dppl_metadata_tangent(x) && return true
143-
144-
# Model tangents without closures are safe
145-
if is_dppl_model_tangent(x)
146-
!is_closure_with_circular_refs(x.fields.f) && return true
147-
end
148-
149-
# VarInfo without Ref fields is safe
150-
if is_dppl_varinfo_tangent(x)
151-
!check_for_ref_fields(x) && return true
152-
end
153-
154-
return false
155-
end
156-
157-
"""
158-
determine_cache_strategy(x)
159-
160-
Determines the appropriate caching strategy for a given tangent.
161-
Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk.
162-
"""
163-
function determine_cache_strategy(x)
164-
# Fast path: check for known circular reference patterns
165-
has_circular_reference_risk(x) && return IdDict{Any,Bool}()
166-
167-
# Check for DynamicPPL types that can safely use NoCache
168-
is_safe_dppl_type(x) && return NoCache()
169-
170-
# Special case: LogDensityFunction without problematic patterns can use NoCache
171-
if is_dppl_ldf_tangent(x)
172-
return NoCache()
173-
end
174-
175-
# Default to safe caching for unknown types
176-
return IdDict{Any,Bool}()
177-
end
178-
179-
function Mooncake.set_to_zero!!(x)
180-
cache = determine_cache_strategy(x)
181-
return set_to_zero_internal!!(cache, x)
18233
end
18334

18435
end # module

test/ext/DynamicPPLMooncakeExt.jl

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -164,87 +164,4 @@ end
164164
# Global should be faster (uses NoCache)
165165
@test time_global < time_closure
166166
end
167-
168-
@testset "Struct field assumptions" begin
169-
# Test that our assumptions about DynamicPPL struct fields are correct
170-
# These tests will fail if DynamicPPL changes its internal structure
171-
172-
@testset "LogDensityFunction tangent structure" begin
173-
model = test_model1([1.0, 2.0, 3.0])
174-
vi = VarInfo(Random.default_rng(), model)
175-
ldf = LogDensityFunction(model, vi, DefaultContext())
176-
tangent = zero_tangent(ldf)
177-
178-
# Test expected fields exist
179-
@test hasfield(typeof(tangent), :fields)
180-
@test hasfield(typeof(tangent.fields), :model)
181-
@test hasfield(typeof(tangent.fields), :varinfo)
182-
@test hasfield(typeof(tangent.fields), :context)
183-
@test hasfield(typeof(tangent.fields), :adtype)
184-
@test hasfield(typeof(tangent.fields), :prep)
185-
186-
# Test exact field names match
187-
@test propertynames(tangent.fields) ==
188-
(:model, :varinfo, :context, :adtype, :prep)
189-
end
190-
191-
@testset "VarInfo tangent structure" begin
192-
model = test_model1([1.0, 2.0, 3.0])
193-
vi = VarInfo(Random.default_rng(), model)
194-
tangent_vi = zero_tangent(vi)
195-
196-
# Test expected fields exist
197-
@test hasfield(typeof(tangent_vi), :fields)
198-
@test hasfield(typeof(tangent_vi.fields), :metadata)
199-
@test hasfield(typeof(tangent_vi.fields), :logp)
200-
@test hasfield(typeof(tangent_vi.fields), :num_produce)
201-
202-
# Test exact field names match
203-
@test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce)
204-
end
205-
206-
@testset "Model tangent structure" begin
207-
model = test_model1([1.0, 2.0, 3.0])
208-
tangent_model = zero_tangent(model)
209-
210-
# Test expected fields exist
211-
@test hasfield(typeof(tangent_model), :fields)
212-
@test hasfield(typeof(tangent_model.fields), :f)
213-
@test hasfield(typeof(tangent_model.fields), :args)
214-
@test hasfield(typeof(tangent_model.fields), :defaults)
215-
@test hasfield(typeof(tangent_model.fields), :context)
216-
217-
# Test exact field names match
218-
@test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context)
219-
end
220-
221-
@testset "Metadata tangent structure" begin
222-
model = test_model1([1.0, 2.0, 3.0])
223-
vi = VarInfo(Random.default_rng(), model)
224-
tangent_vi = zero_tangent(vi)
225-
metadata = tangent_vi.fields.metadata
226-
227-
# Metadata is a NamedTuple with variable names as keys
228-
@test metadata isa NamedTuple
229-
230-
# Each variable's metadata should be a Tangent with the expected fields
231-
for (varname, var_metadata) in pairs(metadata)
232-
@test var_metadata isa Mooncake.Tangent
233-
@test hasfield(typeof(var_metadata), :fields)
234-
235-
# Test expected fields exist
236-
@test hasfield(typeof(var_metadata.fields), :idcs)
237-
@test hasfield(typeof(var_metadata.fields), :vns)
238-
@test hasfield(typeof(var_metadata.fields), :ranges)
239-
@test hasfield(typeof(var_metadata.fields), :vals)
240-
@test hasfield(typeof(var_metadata.fields), :dists)
241-
@test hasfield(typeof(var_metadata.fields), :orders)
242-
@test hasfield(typeof(var_metadata.fields), :flags)
243-
244-
# Test exact field names match
245-
@test propertynames(var_metadata.fields) ==
246-
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
247-
end
248-
end
249-
end
250167
end

0 commit comments

Comments
 (0)