|
28 | 28 | )
|
29 | 29 | end
|
30 | 30 |
|
31 |
| - @testset "set_to_zero!! optimization" begin |
32 |
| - # Test with a real DynamicPPL model |
33 |
| - model = test_model1([1.0, 2.0, 3.0]) |
34 |
| - vi = VarInfo(Random.default_rng(), model) |
35 |
| - ldf = LogDensityFunction(model, vi, DefaultContext()) |
36 |
| - tangent = zero_tangent(ldf) |
37 |
| - |
38 |
| - # Test that set_to_zero!! works correctly |
39 |
| - result = set_to_zero!!(deepcopy(tangent)) |
40 |
| - @test result isa typeof(tangent) |
41 |
| - |
42 |
| - # Test with metadata - verify structure exists |
43 |
| - if hasfield(typeof(tangent.fields.varinfo.fields), :metadata) |
44 |
| - metadata = tangent.fields.varinfo.fields.metadata |
45 |
| - @test !isnothing(metadata) |
46 |
| - end |
47 |
| - end |
48 |
| - |
49 |
| - @testset "NoCache optimization correctness" begin |
50 |
| - # Test that set_to_zero!! uses NoCache for DynamicPPL types |
| 31 | + @testset "set_to_zero!! correctness" begin |
| 32 | + # Test that set_to_zero!! works correctly for DynamicPPL types |
51 | 33 | model = test_model1([1.0, 2.0, 3.0])
|
52 | 34 | vi = VarInfo(Random.default_rng(), model)
|
53 | 35 | ldf = LogDensityFunction(model, vi, DefaultContext())
|
|
63 | 45 | end
|
64 | 46 |
|
65 | 47 | # Call set_to_zero!! and verify it works
|
66 |
| - set_to_zero!!(tangent) |
| 48 | + result = set_to_zero!!(tangent) |
| 49 | + @test result isa typeof(tangent) |
67 | 50 |
|
68 | 51 | # Check that values are zeroed
|
69 | 52 | if hasfield(typeof(tangent.fields.model.fields), :args) &&
|
|
76 | 59 | end
|
77 | 60 |
|
78 | 61 | @testset "Performance improvement" begin
|
79 |
| - # Test with DEMO_MODELS if available |
80 |
| - if isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) && |
81 |
| - !isempty(DynamicPPL.TestUtils.DEMO_MODELS) |
82 |
| - model = DynamicPPL.TestUtils.DEMO_MODELS[1] |
83 |
| - else |
84 |
| - # Fallback to our test model |
85 |
| - model = test_model1([1.0, 2.0, 3.0, 4.0]) |
86 |
| - end |
87 |
| - |
| 62 | + model = DynamicPPL.TestUtils.DEMO_MODELS[1] |
88 | 63 | vi = VarInfo(Random.default_rng(), model)
|
89 | 64 | ldf = LogDensityFunction(model, vi, DefaultContext())
|
90 | 65 | tangent = zero_tangent(ldf)
|
|
189 | 164 | # Global should be faster (uses NoCache)
|
190 | 165 | @test time_global < time_closure
|
191 | 166 | end
|
| 167 | + |
| 168 | + @testset "Struct field assumptions" begin |
| 169 | + # Test that our assumptions about DynamicPPL struct fields are correct |
| 170 | + # These tests will fail if DynamicPPL changes its internal structure |
| 171 | + |
| 172 | + @testset "LogDensityFunction tangent structure" begin |
| 173 | + model = test_model1([1.0, 2.0, 3.0]) |
| 174 | + vi = VarInfo(Random.default_rng(), model) |
| 175 | + ldf = LogDensityFunction(model, vi, DefaultContext()) |
| 176 | + tangent = zero_tangent(ldf) |
| 177 | + |
| 178 | + # Test expected fields exist |
| 179 | + @test hasfield(typeof(tangent), :fields) |
| 180 | + @test hasfield(typeof(tangent.fields), :model) |
| 181 | + @test hasfield(typeof(tangent.fields), :varinfo) |
| 182 | + @test hasfield(typeof(tangent.fields), :context) |
| 183 | + @test hasfield(typeof(tangent.fields), :adtype) |
| 184 | + @test hasfield(typeof(tangent.fields), :prep) |
| 185 | + |
| 186 | + # Test exact field names match |
| 187 | + @test propertynames(tangent.fields) == |
| 188 | + (:model, :varinfo, :context, :adtype, :prep) |
| 189 | + end |
| 190 | + |
| 191 | + @testset "VarInfo tangent structure" begin |
| 192 | + model = test_model1([1.0, 2.0, 3.0]) |
| 193 | + vi = VarInfo(Random.default_rng(), model) |
| 194 | + tangent_vi = zero_tangent(vi) |
| 195 | + |
| 196 | + # Test expected fields exist |
| 197 | + @test hasfield(typeof(tangent_vi), :fields) |
| 198 | + @test hasfield(typeof(tangent_vi.fields), :metadata) |
| 199 | + @test hasfield(typeof(tangent_vi.fields), :logp) |
| 200 | + @test hasfield(typeof(tangent_vi.fields), :num_produce) |
| 201 | + |
| 202 | + # Test exact field names match |
| 203 | + @test propertynames(tangent_vi.fields) == (:metadata, :logp, :num_produce) |
| 204 | + end |
| 205 | + |
| 206 | + @testset "Model tangent structure" begin |
| 207 | + model = test_model1([1.0, 2.0, 3.0]) |
| 208 | + tangent_model = zero_tangent(model) |
| 209 | + |
| 210 | + # Test expected fields exist |
| 211 | + @test hasfield(typeof(tangent_model), :fields) |
| 212 | + @test hasfield(typeof(tangent_model.fields), :f) |
| 213 | + @test hasfield(typeof(tangent_model.fields), :args) |
| 214 | + @test hasfield(typeof(tangent_model.fields), :defaults) |
| 215 | + @test hasfield(typeof(tangent_model.fields), :context) |
| 216 | + |
| 217 | + # Test exact field names match |
| 218 | + @test propertynames(tangent_model.fields) == (:f, :args, :defaults, :context) |
| 219 | + end |
| 220 | + |
| 221 | + @testset "Metadata tangent structure" begin |
| 222 | + model = test_model1([1.0, 2.0, 3.0]) |
| 223 | + vi = VarInfo(Random.default_rng(), model) |
| 224 | + tangent_vi = zero_tangent(vi) |
| 225 | + metadata = tangent_vi.fields.metadata |
| 226 | + |
| 227 | + # Metadata is a NamedTuple with variable names as keys |
| 228 | + @test metadata isa NamedTuple |
| 229 | + |
| 230 | + # Each variable's metadata should be a Tangent with the expected fields |
| 231 | + for (varname, var_metadata) in pairs(metadata) |
| 232 | + @test var_metadata isa Mooncake.Tangent |
| 233 | + @test hasfield(typeof(var_metadata), :fields) |
| 234 | + |
| 235 | + # Test expected fields exist |
| 236 | + @test hasfield(typeof(var_metadata.fields), :idcs) |
| 237 | + @test hasfield(typeof(var_metadata.fields), :vns) |
| 238 | + @test hasfield(typeof(var_metadata.fields), :ranges) |
| 239 | + @test hasfield(typeof(var_metadata.fields), :vals) |
| 240 | + @test hasfield(typeof(var_metadata.fields), :dists) |
| 241 | + @test hasfield(typeof(var_metadata.fields), :orders) |
| 242 | + @test hasfield(typeof(var_metadata.fields), :flags) |
| 243 | + |
| 244 | + # Test exact field names match |
| 245 | + @test propertynames(var_metadata.fields) == |
| 246 | + (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags) |
| 247 | + end |
| 248 | + end |
| 249 | + end |
192 | 250 | end
|
0 commit comments