Skip to content

Commit e93458c

Browse files
committed
use NoCache to improve set_to_zero!! performance with Mooncake
1 parent acac44d commit e93458c

File tree

3 files changed

+298
-3
lines changed

3 files changed

+298
-3
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## Unreleased
4+
5+
Improved performance for some models with Mooncake.jl by using `NoCache` with `Mooncake.set_to_zero!!` for DynamicPPL types.
6+
37
## 0.36.14
48

59
Added compatibility with AbstractPPL@0.12.

ext/DynamicPPLMooncakeExt.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,110 @@
11
module DynamicPPLMooncakeExt
22

3+
__precompile__(false)
4+
35
using DynamicPPL: DynamicPPL, istrans
46
using Mooncake: Mooncake
7+
import Mooncake: set_to_zero!!
8+
using Mooncake: NoTangent, Tangent, MutableTangent, NoCache, set_to_zero_internal!!
59

610
# This is purely an optimisation.
711
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
812

"""
    has_expected_structure(x, expected_type, expected_fields)

Return `true` when `x` is an instance of `expected_type` (a `Tangent` or
`MutableTangent` type), has a `fields` field, and the property names of `x.fields`
agree with `expected_fields`.

If `expected_fields` is a `Tuple`, the property names must equal it exactly (same names,
same order). Otherwise every element of `expected_fields` must appear among the property
names.
"""
function has_expected_structure(
    x, expected_type::Type{<:Union{Tangent,MutableTangent}}, expected_fields
)
    # Reject anything that is not a tangent of the requested kind, or that does not
    # carry a `fields` container at all.
    if !(x isa expected_type) || !hasfield(typeof(x), :fields)
        return false
    end

    actual_names = propertynames(x.fields)

    # A tuple of names demands an exact match.
    if expected_fields isa Tuple
        return actual_names == expected_fields
    end

    # Any other collection only needs to be contained in the actual names.
    return all(in(actual_names), expected_fields)
end
33+
"""
    is_dppl_ldf_tangent(x)

Return `true` when `x` looks like the tangent of a `DynamicPPL.LogDensityFunction`: a
`Tangent` whose fields are exactly `(:model, :varinfo, :context, :adtype, :prep)`, and
whose `varinfo` and `model` entries pass `is_dppl_varinfo_tangent` and
`is_dppl_model_tangent` respectively.
"""
function is_dppl_ldf_tangent(x)
    ldf_field_names = (:model, :varinfo, :context, :adtype, :prep)
    if !has_expected_structure(x, Tangent, ldf_field_names)
        return false
    end

    # The outer shape matches; now require the nested tangents to match as well.
    inner = x.fields
    return is_dppl_varinfo_tangent(inner.varinfo) && is_dppl_model_tangent(inner.model)
end
47+
"""
    is_dppl_varinfo_tangent(x)

Return `true` when `x` looks like the tangent of a `DynamicPPL.VarInfo`, i.e. a `Tangent`
whose fields are exactly `(:metadata, :logp, :num_produce)`.
"""
function is_dppl_varinfo_tangent(x)
    varinfo_field_names = (:metadata, :logp, :num_produce)
    return has_expected_structure(x, Tangent, varinfo_field_names)
end
54+
"""
    is_dppl_model_tangent(x)

Return `true` when `x` looks like the tangent of a `DynamicPPL.Model`, i.e. a `Tangent`
whose fields are exactly `(:f, :args, :defaults, :context)`.
"""
function is_dppl_model_tangent(x)
    model_field_names = (:f, :args, :defaults, :context)
    return has_expected_structure(x, Tangent, model_field_names)
end
61+
"""
    is_dppl_metadata_tangent(x)

Return `true` when `x` looks like the tangent of a `DynamicPPL.Metadata`, i.e. a
`MutableTangent` whose fields are exactly
`(:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)`.
"""
function is_dppl_metadata_tangent(x)
    metadata_field_names = (:idcs, :vns, :ranges, :vals, :dists, :orders, :flags)
    return has_expected_structure(x, MutableTangent, metadata_field_names)
end
70+
"""
    is_closure_model(model_f_tangent)

Heuristically decide whether the tangent of a model's function belongs to a closure.

Returns `true` if `model_f_tangent` is a `MutableTangent`, or if it is a `Tangent` with at
least one field that is a `MutableTangent` whose `fields.contents` is a
`Mooncake.PossiblyUninitTangent{Any}`. Returns `false` in every other case.
"""
function is_closure_model(model_f_tangent)
    model_f_tangent isa MutableTangent && return true

    # Only a `Tangent` that actually carries `fields` can be inspected further.
    if !(model_f_tangent isa Tangent && hasfield(typeof(model_f_tangent), :fields))
        return false
    end

    # Look for a field holding a mutable tangent with untyped `contents`.
    return any(values(model_f_tangent.fields)) do field_tangent
        field_tangent isa MutableTangent &&
            hasfield(typeof(field_tangent), :fields) &&
            hasfield(typeof(field_tangent.fields), :contents) &&
            field_tangent.fields.contents isa Mooncake.PossiblyUninitTangent{Any}
    end
end
91+
"""
    Mooncake.set_to_zero!!(x::Union{Tangent,MutableTangent})

Zero out the tangent `x`, choosing the cheapest safe cache for
`set_to_zero_internal!!`:

- If `x` looks like a `DynamicPPL.LogDensityFunction` tangent, use `NoCache()` unless the
  model function's tangent is detected as a closure by `is_closure_model`, in which case
  an `IdDict{Any,Bool}` cache is used.
- If `x` looks like a `DynamicPPL.VarInfo`, `DynamicPPL.Model` or `DynamicPPL.Metadata`
  tangent, use `NoCache()`.
- Otherwise use an `IdDict{Any,Bool}` cache.

The method is restricted to `Tangent`/`MutableTangent` on purpose: every structural check
above already requires one of those types, and an untyped `set_to_zero!!(x)` method would
replace Mooncake's own generic definition (method overwriting) instead of adding a more
specific one.
"""
function Mooncake.set_to_zero!!(x::Union{Tangent,MutableTangent})
    if is_dppl_ldf_tangent(x)
        # Closures get the IdDict-based cache; everything else can skip caching.
        model_f_tangent = x.fields.model.fields.f
        cache = is_closure_model(model_f_tangent) ? IdDict{Any,Bool}() : NoCache()
        return set_to_zero_internal!!(cache, x)
    elseif is_dppl_varinfo_tangent(x) ||
        is_dppl_model_tangent(x) ||
        is_dppl_metadata_tangent(x)
        # These types can always use NoCache.
        return set_to_zero_internal!!(NoCache(), x)
    else
        # Same IdDict-based call used for all other tangents.
        # NOTE(review): tangents of other types now dispatch to Mooncake's generic
        # method, which is presumed to make this same call — confirm against the
        # Mooncake version in use.
        return set_to_zero_internal!!(IdDict{Any,Bool}(), x)
    end
end
109+
9110
end # module

test/ext/DynamicPPLMooncakeExt.jl

Lines changed: 193 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,195 @@
1+
using DynamicPPL
2+
using Distributions
3+
using Random
4+
using Test
5+
using StableRNGs
6+
using Mooncake: NoCache, set_to_zero!!, set_to_zero_internal!!, zero_tangent
7+
using DynamicPPL.TestUtils.AD: @be
8+
using Statistics: median
9+
10+
# Define models globally to avoid closure issues
# Test model: scalar variables `s` and `m`, with every element of the argument `x`
# distributed as `Normal(m, sqrt(s))`.
@model function test_model1(x)
    s ~ InverseGamma(2, 3)
    σ = sqrt(s)
    m ~ Normal(0, σ)
    return x .~ Normal(m, σ)
end
16+
# Test model with two arguments, `x` and `y`, whose elements all share the same
# `Normal(μ, σ)` distribution.
@model function test_model2(x, y)
    τ ~ Gamma(1, 1)
    σ ~ InverseGamma(2, 3)
    μ ~ Normal(0, τ)
    likelihood = Normal(μ, σ)
    x .~ likelihood
    return y .~ likelihood
end
24+
125
@testset "DynamicPPLMooncakeExt" begin
2-
Mooncake.TestUtils.test_rule(
3-
StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true
4-
)
@testset "istrans rule" begin
    # Exercise the zero-adjoint rule registered for `istrans` through Mooncake's own
    # rule-testing utility, using a fixed-seed RNG.
    rng = StableRNG(123456)
    Mooncake.TestUtils.test_rule(
        rng, istrans, VarInfo(); unsafe_perturb=true, interface_only=true
    )
end
31+
@testset "set_to_zero!! optimization" begin
    # Build the tangent of a LogDensityFunction for a real DynamicPPL model.
    model = test_model1([1.0, 2.0, 3.0])
    varinfo = VarInfo(Random.default_rng(), model)
    tangent = zero_tangent(LogDensityFunction(model, varinfo, DefaultContext()))

    # Zeroing a copy must hand back a tangent of the same type.
    zeroed = set_to_zero!!(deepcopy(tangent))
    @test zeroed isa typeof(tangent)

    # When the VarInfo tangent exposes `metadata`, it must not be `nothing`.
    varinfo_fields = tangent.fields.varinfo.fields
    if hasfield(typeof(varinfo_fields), :metadata)
        @test !isnothing(varinfo_fields.metadata)
    end
end
49+
@testset "NoCache optimization correctness" begin
    # Tangent of a LogDensityFunction wrapping a DynamicPPL model.
    model = test_model1([1.0, 2.0, 3.0])
    varinfo = VarInfo(Random.default_rng(), model)
    tangent = zero_tangent(LogDensityFunction(model, varinfo, DefaultContext()))

    # The checks below only apply when the model tangent exposes `args.x`.
    model_fields = tangent.fields.model.fields
    has_x_arg =
        hasfield(typeof(model_fields), :args) && hasfield(typeof(model_fields.args), :x)

    # Perturb the first entry so that zeroing has something to undo.
    if has_x_arg
        perturbed = tangent.fields.model.fields.args.x
        isempty(perturbed) || (perturbed[1] = 5.0)
    end

    set_to_zero!!(tangent)

    # The perturbed entry must be back at zero.
    if has_x_arg
        zeroed = tangent.fields.model.fields.args.x
        if !isempty(zeroed)
            @test zeroed[1] == 0.0
        end
    end
end
78+
@testset "Performance improvement" begin
    # Prefer the first demo model shipped with DynamicPPL; fall back to the local model.
    demo_available =
        isdefined(DynamicPPL.TestUtils, :DEMO_MODELS) &&
        !isempty(DynamicPPL.TestUtils.DEMO_MODELS)
    model = if demo_available
        DynamicPPL.TestUtils.DEMO_MODELS[1]
    else
        test_model1([1.0, 2.0, 3.0, 4.0])
    end

    varinfo = VarInfo(Random.default_rng(), model)
    tangent = zero_tangent(LogDensityFunction(model, varinfo, DefaultContext()))

    # Baseline: always allocate and use an IdDict cache.
    result_iddict = @be begin
        cache = IdDict{Any,Bool}()
        set_to_zero_internal!!(cache, tangent)
    end

    # Candidate: the public entry point defined by the extension.
    result_nocache = @be set_to_zero!!(tangent)

    # Compare median timings.
    time_iddict = median(result_iddict).time
    time_nocache = median(result_nocache).time
    speedup = time_iddict / time_nocache

    # NOTE(review): this asserts on wall-clock measurements and may be flaky on
    # loaded CI machines.
    @test speedup > 1.5 # Conservative expectation - should be ~4x

    @info "Performance improvement" speedup time_iddict_μs = time_iddict / 1000 time_nocache_μs =
        time_nocache / 1000
end
112+
@testset "Aliasing safety" begin
    # Pass the very same array as both `x` and `y`.
    shared_data = [1.0, 2.0, 3.0]
    model = test_model2(shared_data, shared_data)
    varinfo = VarInfo(Random.default_rng(), model)
    tangent = zero_tangent(LogDensityFunction(model, varinfo, DefaultContext()))

    model_fields = tangent.fields.model.fields
    if hasfield(typeof(model_fields), :args)
        args = model_fields.args
        if hasfield(typeof(args), :x) && hasfield(typeof(args), :y)
            # The tangent must alias exactly like the primal arguments do.
            @test args.x === args.y

            # A write through `x` must be visible through `y`.
            if !isempty(args.x)
                args.x[1] = 10.0
                @test args.y[1] == 10.0
            end

            # Zeroing the whole tangent must leave both views at zero.
            set_to_zero!!(tangent)
            if !isempty(args.x)
                @test args.x[1] == 0.0
                @test args.y[1] == 0.0
            end
        end
    end
end
143+
@testset "Closure handling" begin
    # Define a model inside another function so that the model function is a closure.
    function create_closure_model()
        local_var = 42
        @model function closure_model(x)
            s ~ InverseGamma(2, 3)
            m ~ Normal(0, sqrt(s))
            return x .~ Normal(m, sqrt(s))
        end
        return closure_model
    end

    # Tangent of a LogDensityFunction wrapping the closure-based model.
    closure_fn = create_closure_model()
    model_closure = closure_fn([1.0, 2.0, 3.0])
    vi_closure = VarInfo(Random.default_rng(), model_closure)
    tangent_closure = zero_tangent(
        LogDensityFunction(model_closure, vi_closure, DefaultContext())
    )

    # Zeroing must complete without warnings (and without overflowing the stack).
    @test_nowarn set_to_zero!!(deepcopy(tangent_closure))

    # Same construction for a model defined at global scope.
    model_global = test_model1([1.0, 2.0, 3.0])
    vi_global = VarInfo(Random.default_rng(), model_global)
    tangent_global = zero_tangent(
        LogDensityFunction(model_global, vi_global, DefaultContext())
    )

    # The tangent of `model.f` distinguishes the two kinds of model.
    f_tangent_closure = tangent_closure.fields.model.fields.f
    f_tangent_global = tangent_global.fields.model.fields.f
    @test f_tangent_global isa Mooncake.NoTangent # Global function
    @test f_tangent_closure isa Mooncake.Tangent # Closure function

    # Time 100 zeroing calls for each tangent.
    time_global = @elapsed for _ in 1:100
        set_to_zero!!(tangent_global)
    end
    time_closure = @elapsed for _ in 1:100
        set_to_zero!!(tangent_closure)
    end

    # Global should be faster (uses NoCache).
    # NOTE(review): wall-clock comparison; may be flaky on noisy machines.
    @test time_global < time_closure

    @info "Closure handling" time_global_ms = time_global * 1000 time_closure_ms =
        time_closure * 1000
end
5195
end

0 commit comments

Comments
 (0)