Skip to content

Commit 92f935d

Browse files
committed
refactor more; add additional test
1 parent aabc844 commit 92f935d

File tree

2 files changed

+135
-90
lines changed

2 files changed

+135
-90
lines changed

ext/DynamicPPLMooncakeExt.jl

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,7 @@ using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_interna
1111
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
1212

1313
# =======================
14-
# Cache Strategy System
15-
# =======================
16-
17-
"""
18-
determine_cache_strategy(x)
19-
20-
Determines the appropriate caching strategy for a given tangent.
21-
Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk.
22-
"""
23-
function determine_cache_strategy(x)
24-
# Fast path: check for known circular reference patterns
25-
has_circular_reference_risk(x) && return IdDict{Any,Bool}()
26-
27-
# Check for DynamicPPL types that can safely use NoCache
28-
is_safe_dppl_type(x) && return NoCache()
29-
30-
# Special case: LogDensityFunction without problematic patterns can use NoCache
31-
if is_dppl_ldf_tangent(x)
32-
return NoCache()
33-
end
34-
35-
# Default to safe caching for unknown types
36-
return IdDict{Any,Bool}()
37-
end
38-
39-
# =======================
40-
# Type Recognition
14+
# `Mooncake.set_to_zero!!` optimization with `NoCache`
4115
# =======================
4216

4317
"""
@@ -61,9 +35,6 @@ function has_expected_structure(
6135
return true
6236
end
6337

64-
"""
65-
Check if a tangent corresponds to a DynamicPPL.LogDensityFunction
66-
"""
6738
function is_dppl_ldf_tangent(x)
6839
has_expected_structure(x, Tangent, (:model, :varinfo, :context, :adtype, :prep)) ||
6940
return false
@@ -75,58 +46,62 @@ function is_dppl_ldf_tangent(x)
7546
return true
7647
end
7748

78-
"""
79-
Check if a tangent corresponds to a DynamicPPL.VarInfo
80-
"""
8149
function is_dppl_varinfo_tangent(x)
8250
return has_expected_structure(x, Tangent, (:metadata, :logp, :num_produce))
8351
end
8452

85-
"""
86-
Check if a tangent corresponds to a DynamicPPL.Model
87-
"""
8853
function is_dppl_model_tangent(x)
8954
return has_expected_structure(x, Tangent, (:f, :args, :defaults, :context))
9055
end
9156

92-
"""
93-
Check if a MutableTangent corresponds to DynamicPPL.Metadata
94-
"""
9557
function is_dppl_metadata_tangent(x)
96-
return has_expected_structure(
58+
# Metadata can be either:
59+
# 1. A MutableTangent with the expected fields (for single metadata)
60+
# 2. A NamedTuple where each value is a Tangent with the expected fields
61+
62+
# Check for MutableTangent case
63+
if has_expected_structure(
9764
x, MutableTangent, (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
9865
)
99-
end
66+
return true
67+
end
10068

101-
# =======================
102-
# Circular Reference Detection
103-
# =======================
69+
# Check for NamedTuple case (multiple metadata)
70+
if x isa NamedTuple
71+
# Each value should be a Tangent with metadata fields
72+
for var_metadata in values(x)
73+
if !has_expected_structure(
74+
var_metadata,
75+
Tangent,
76+
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags),
77+
)
78+
return false
79+
end
80+
end
81+
return true
82+
end
83+
84+
return false
85+
end
10486

10587
"""
10688
has_circular_reference_risk(x)
10789
10890
Main entry point for detecting circular reference patterns that require caching.
109-
Optimized for performance with targeted checks instead of recursive traversal.
11091
"""
11192
function has_circular_reference_risk(x)
112-
# Type-specific targeted checks only
11393
if is_dppl_ldf_tangent(x)
11494
# Check model function for closure patterns with circular refs
11595
model_f = x.fields.model.fields.f
11696
return is_closure_with_circular_refs(model_f)
11797
elseif is_dppl_varinfo_tangent(x)
118-
# Check for Ref fields in VarInfo
11998
return check_for_ref_fields(x)
12099
end
121100

122101
# For unknown types, do a shallow check for PossiblyUninitTangent{Any}
123102
return x isa Mooncake.PossiblyUninitTangent{Any}
124103
end
125104

126-
"""
127-
Check if a tangent represents a closure with circular reference patterns.
128-
Only returns true for actual problematic patterns, not all MutableTangents.
129-
"""
130105
function is_closure_with_circular_refs(x)
131106
# Check if MutableTangent contains PossiblyUninitTangent{Any}
132107
if x isa MutableTangent && hasfield(typeof(x), :fields)
@@ -150,9 +125,6 @@ function is_closure_with_circular_refs(x)
150125
return false
151126
end
152127

153-
"""
154-
Check if a VarInfo tangent has Ref fields that need caching.
155-
"""
156128
function check_for_ref_fields(x)
157129
# Check if it's a VarInfo tangent
158130
is_dppl_varinfo_tangent(x) || return false
@@ -165,9 +137,6 @@ function check_for_ref_fields(x)
165137
return logp_tangent isa MutableTangent
166138
end
167139

168-
"""
169-
Check if a tangent is a safe DynamicPPL type that can use NoCache.
170-
"""
171140
function is_safe_dppl_type(x)
172141
# Metadata is always safe
173142
is_dppl_metadata_tangent(x) && return true
@@ -185,9 +154,27 @@ function is_safe_dppl_type(x)
185154
return false
186155
end
187156

188-
# =======================
189-
# Main Entry Point
190-
# =======================
157+
"""
158+
determine_cache_strategy(x)
159+
160+
Determines the appropriate caching strategy for a given tangent.
161+
Returns either `NoCache()` for safe types or `IdDict{Any,Bool}()` for types with circular reference risk.
162+
"""
163+
function determine_cache_strategy(x)
164+
# Fast path: check for known circular reference patterns
165+
has_circular_reference_risk(x) && return IdDict{Any,Bool}()
166+
167+
# Check for DynamicPPL types that can safely use NoCache
168+
is_safe_dppl_type(x) && return NoCache()
169+
170+
# Special case: LogDensityFunction without problematic patterns can use NoCache
171+
if is_dppl_ldf_tangent(x)
172+
return NoCache()
173+
end
174+
175+
# Default to safe caching for unknown types
176+
return IdDict{Any,Bool}()
177+
end
191178

192179
function Mooncake.set_to_zero!!(x)
193180
cache = determine_cache_strategy(x)

test/ext/DynamicPPLMooncakeExt.jl

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,8 @@ end
2828
)
2929
end
3030

31-
@testset "set_to_zero!! optimization" begin
32-
# Test with a real DynamicPPL model
33-
model = test_model1([1.0, 2.0, 3.0])
34-
vi = VarInfo(Random.default_rng(), model)
35-
ldf = LogDensityFunction(model, vi, DefaultContext())
36-
tangent = zero_tangent(ldf)
37-
38-
# Test that set_to_zero!! works correctly
39-
result = set_to_zero!!(deepcopy(tangent))
40-
@test result isa typeof(tangent)
41-
42-
# Test with metadata - verify structure exists
43-
if hasfield(typeof(tangent.fields.varinfo.fields), :metadata)
44-
metadata = tangent.fields.varinfo.fields.metadata
45-
@test !isnothing(metadata)
46-
end
47-
end
48-
49-
@testset "NoCache optimization correctness" begin
50-
# Test that set_to_zero!! uses NoCache for DynamicPPL types
31+
@testset "set_to_zero!! correctness" begin
32+
# Test that set_to_zero!! works correctly for DynamicPPL types
5133
model = test_model1([1.0, 2.0, 3.0])
5234
vi = VarInfo(Random.default_rng(), model)
5335
ldf = LogDensityFunction(model, vi, DefaultContext())
@@ -63,7 +45,8 @@ end
6345
end
6446

6547
# Call set_to_zero!! and verify it works
66-
set_to_zero!!(tangent)
48+
result = set_to_zero!!(tangent)
49+
@test result isa typeof(tangent)
6750

6851
# Check that values are zeroed
6952
if hasfield(typeof(tangent.fields.model.fields), :args) &&
@@ -76,15 +59,7 @@ end
7659
end
7760

7861
@testset "Performance improvement" begin
79-
# Test with DEMO_MODELS if available
80-
if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) &&
81-
!isempty(DynamicPPL.TestUtils.DEMO_MODELS)
82-
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
83-
else
84-
# Fallback to our test model
85-
model = test_model1([1.0, 2.0, 3.0, 4.0])
86-
end
87-
62+
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
8863
vi = VarInfo(Random.default_rng(), model)
8964
ldf = LogDensityFunction(model, vi, DefaultContext())
9065
tangent = zero_tangent(ldf)
@@ -189,4 +164,87 @@ end
189164
# Global should be faster (uses NoCache)
190165
@test time_global < time_closure
191166
end
167+
168+
@testset "Struct field assumptions" begin
169+
# Test that our assumptions about DynamicPPL struct fields are correct
170+
# These tests will fail if DynamicPPL changes its internal structure
171+
172+
@testset "LogDensityFunction tangent structure" begin
173+
model = test_model1([1.0, 2.0, 3.0])
174+
vi = VarInfo(Random.default_rng(), model)
175+
ldf = LogDensityFunction(model, vi, DefaultContext())
176+
tangent = zero_tangent(ldf)
177+
178+
# Test expected fields exist
179+
@test hasfield(typeof(tangent), :fields)
180+
@test hasfield(typeof(tangent.fields), :model)
181+
@test hasfield(typeof(tangent.fields), :varinfo)
182+
@test hasfield(typeof(tangent.fields), :context)
183+
@test hasfield(typeof(tangent.fields), :adtype)
184+
@test hasfield(typeof(tangent.fields), :prep)
185+
186+
# Test exact field names match
187+
@test propertynames(tangent.fields) ==
188+
(:model, :varinfo, :context, :adtype, :prep)
189+
end
190+
191+
@testset "VarInfo tangent structure" begin
192+
model = test_model1([1.0, 2.0, 3.0])
193+
vi = VarInfo(Random.default_rng(), model)
194+
tangent_vi = zero_tangent(vi)
195+
196+
# Test expected fields exist
197+
@test hasfield(typeof(tangent_vi), :fields)
198+
@test hasfield(typeof(tangent_vi.fields), :metadata)
199+
@test hasfield(typeof(tangent_vi.fields), :logp)
200+
@test hasfield(typeof(tangent_vi.fields), :num_produce)
201+
202+
# Test exact field names match
203+
@test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce)
204+
end
205+
206+
@testset "Model tangent structure" begin
207+
model = test_model1([1.0, 2.0, 3.0])
208+
tangent_model = zero_tangent(model)
209+
210+
# Test expected fields exist
211+
@test hasfield(typeof(tangent_model), :fields)
212+
@test hasfield(typeof(tangent_model.fields), :f)
213+
@test hasfield(typeof(tangent_model.fields), :args)
214+
@test hasfield(typeof(tangent_model.fields), :defaults)
215+
@test hasfield(typeof(tangent_model.fields), :context)
216+
217+
# Test exact field names match
218+
@test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context)
219+
end
220+
221+
@testset "Metadata tangent structure" begin
222+
model = test_model1([1.0, 2.0, 3.0])
223+
vi = VarInfo(Random.default_rng(), model)
224+
tangent_vi = zero_tangent(vi)
225+
metadata = tangent_vi.fields.metadata
226+
227+
# Metadata is a NamedTuple with variable names as keys
228+
@test metadata isa NamedTuple
229+
230+
# Each variable's metadata should be a Tangent with the expected fields
231+
for (varname, var_metadata) in pairs(metadata)
232+
@test var_metadata isa Mooncake.Tangent
233+
@test hasfield(typeof(var_metadata), :fields)
234+
235+
# Test expected fields exist
236+
@test hasfield(typeof(var_metadata.fields), :idcs)
237+
@test hasfield(typeof(var_metadata.fields), :vns)
238+
@test hasfield(typeof(var_metadata.fields), :ranges)
239+
@test hasfield(typeof(var_metadata.fields), :vals)
240+
@test hasfield(typeof(var_metadata.fields), :dists)
241+
@test hasfield(typeof(var_metadata.fields), :orders)
242+
@test hasfield(typeof(var_metadata.fields), :flags)
243+
244+
# Test exact field names match
245+
@test propertynames(var_metadata.fields) ==
246+
(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
247+
end
248+
end
249+
end
192250
end

0 commit comments

Comments
 (0)