Skip to content

Commit 7f12c3e

Browse files
committed
Fix a bunch of tests; I'll let CI tell me what's still broken...
1 parent 27b0096 commit 7f12c3e

File tree

11 files changed

+48
-32
lines changed

11 files changed

+48
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,4 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
9393

9494
[sources]
95-
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
95+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/setleafcontext-model"}

src/mcmc/hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ function find_initial_params(
161161
@warn "failed to find valid initial parameters in $(attempts) tries; consider providing a different initialisation strategy with the `initial_params` keyword"
162162

163163
# Resample and try again.
164-
varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
164+
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, init_strategy)
165165
end
166166

167167
# if we failed to find valid initial parameters, error

src/mcmc/prior.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function AbstractMCMC.step(
1313
kwargs...,
1414
)
1515
vi = DynamicPPL.setaccs!!(
16-
vi,
16+
DynamicPPL.VarInfo(),
1717
(
1818
DynamicPPL.ValuesAsInModelAccumulator(true),
1919
DynamicPPL.LogPriorAccumulator(),

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ Combinatorics = "1"
5353
Distributions = "0.25"
5454
DistributionsAD = "0.6.3"
5555
DynamicHMC = "2.1.6, 3.0"
56-
DynamicPPL = "0.37.2"
5756
FiniteDifferences = "0.10.8, 0.11, 0.12"
5857
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
5958
HypothesisTests = "0.11"
@@ -77,3 +76,6 @@ StatsBase = "0.33, 0.34"
7776
StatsFuns = "0.9.5, 1"
7877
TimerOutputs = "0.5"
7978
julia = "1.10"
79+
80+
[sources]
81+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/setleafcontext-model"}

test/essential/container.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using Turing
2222
vi = DynamicPPL.setacc!!(vi, Turing.Inference.ProduceLogLikelihoodAccumulator())
2323
sampler = Sampler(PG(10))
2424
model = test()
25-
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
25+
trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false)
2626

2727
# Make sure the backreference from taped_globals to the trace is in place.
2828
@test trace.model.ctask.taped_globals.other === trace
@@ -48,7 +48,7 @@ using Turing
4848
sampler = Sampler(PG(10))
4949
model = normal()
5050

51-
trace = AdvancedPS.Trace(model, sampler, vi, AdvancedPS.TracedRNG())
51+
trace = AdvancedPS.Trace(model, vi, AdvancedPS.TracedRNG(), false)
5252

5353
newtrace = AdvancedPS.forkr(trace)
5454
# Catch broken replay mechanism

test/mcmc/ess.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,12 @@ using Turing
108108
spl_x = Gibbs(@varname(z) => NUTS(), @varname(x) => ESS())
109109
spl_xy = Gibbs(@varname(z) => NUTS(), (@varname(x), @varname(y)) => ESS())
110110

111-
@test sample(StableRNG(23), xy(), spl_xy, num_samples).value
112-
sample(StableRNG(23), x12(), spl_x, num_samples).value
111+
chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples)
112+
chn2 = sample(StableRNG(23), x12(), spl_x, num_samples)
113+
114+
@test mean(chn1[:z]) mean(chn2[:z]) atol = 0.05
115+
@test mean(chn1[:x]) mean(chn2["x[1]"]) atol = 0.05
116+
@test mean(chn1[:y]) mean(chn2["x[2]"]) atol = 0.05
113117
end
114118
end
115119

test/mcmc/external_sampler.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ using Turing.Inference: AdvancedHMC
4545
rng::Random.AbstractRNG,
4646
model::AbstractMCMC.LogDensityModel,
4747
sampler::MySampler;
48+
# This initial_params should be an AbstractVector because the model is just a
49+
# LogDensityModel, not a DynamicPPL.Model
4850
initial_params::AbstractVector,
4951
kwargs...,
5052
)
@@ -82,7 +84,10 @@ using Turing.Inference: AdvancedHMC
8284
model = test_external_sampler()
8385
a, b = 0.5, 0.0
8486

85-
chn = sample(model, externalsampler(MySampler()), 10; initial_params=[a, b])
87+
# This `initial_params` should be an InitStrategy
88+
chn = sample(
89+
model, externalsampler(MySampler()), 10; initial_params=InitFromParams((a=a, b=b))
90+
)
8691
@test chn isa MCMCChains.Chains
8792
@test all(chn[:a] .== a)
8893
@test all(chn[:b] .== b)
@@ -167,9 +172,7 @@ function initialize_mh_with_prior_proposal(model)
167172
)
168173
end
169174

170-
function test_initial_params(
171-
model, sampler, initial_params=DynamicPPL.VarInfo(model)[:]; kwargs...
172-
)
175+
function test_initial_params(model, sampler, initial_params=InitFromPrior(); kwargs...)
173176
# Execute the transition with two different RNGs and check that the resulting
174177
# parameter values are the same.
175178
rng1 = Random.MersenneTwister(42)
@@ -204,7 +207,7 @@ end
204207
n_adapts=1_000,
205208
discard_initial=1_000,
206209
# FIXME: Remove this once we can run `test_initial_params` above.
207-
initial_params=DynamicPPL.VarInfo(model)[:],
210+
initial_params=InitFromPrior(),
208211
)
209212

210213
@testset "inference" begin

test/mcmc/gibbs.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -693,13 +693,9 @@ end
693693
num_chains = 4
694694

695695
# Determine initial parameters to make comparison as fair as possible.
696+
# posterior_mean returns a NamedTuple so we can plug it in directly.
696697
posterior_mean = DynamicPPL.TestUtils.posterior_mean(model)
697-
initial_params = DynamicPPL.TestUtils.update_values!!(
698-
DynamicPPL.VarInfo(model),
699-
posterior_mean,
700-
DynamicPPL.TestUtils.varnames(model),
701-
)[:]
702-
initial_params = fill(initial_params, num_chains)
698+
initial_params = fill(InitFromParams(initial_params), num_chains)
703699

704700
# Sampler to use for Gibbs components.
705701
hmc = HMC(0.1, 32)

test/mcmc/hmc.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,11 @@ using Turing
177177
@testset "$spl_name" for (spl_name, spl) in
178178
(("HMC", HMC(0.1, 10)), ("NUTS", NUTS()))
179179
chain = sample(
180-
demo_norm(), spl, 5; discard_adapt=false, initial_params=(x=init_x,)
180+
demo_norm(),
181+
spl,
182+
5;
183+
discard_adapt=false,
184+
initial_params=InitFromParams((x=init_x,)),
181185
)
182186
@test chain[:x][1] == init_x
183187
chain = sample(
@@ -187,7 +191,7 @@ using Turing
187191
5,
188192
5;
189193
discard_adapt=false,
190-
initial_params=(fill((x=init_x,), 5)),
194+
initial_params=(fill(InitFromParams((x=init_x,)), 5)),
191195
)
192196
@test all(chain[:x][1, :] .== init_x)
193197
end
@@ -202,12 +206,11 @@ using Turing
202206
end
203207
end
204208

205-
@test_logs (
206-
:warn,
207-
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword",
208-
) (:info,) match_mode = :any begin
209-
sample(demo_warn_initial_params(), NUTS(), 5)
210-
end
209+
# verbose=false to suppress the initial step size notification, which messes with
210+
# the test
211+
@test_logs (:warn, r"consider providing a different initialisation strategy") sample(
212+
demo_warn_initial_params(), NUTS(), 5; verbose=false
213+
)
211214
end
212215

213216
@testset "error for impossible model" begin
@@ -253,7 +256,8 @@ using Turing
253256
model = buggy_model()
254257
num_samples = 1_000
255258

256-
chain = sample(model, NUTS(), num_samples; initial_params=[0.5, 1.75, 1.0])
259+
initial_params = InitFromParams((lb=0.5, ub=1.75, x=1.0))
260+
chain = sample(model, NUTS(), num_samples; initial_params=initial_params)
257261
chain_prior = sample(model, Prior(), num_samples)
258262

259263
# Extract the `x` like this because running `generated_quantities` was how
@@ -275,7 +279,11 @@ using Turing
275279
# Construct a HMC state by taking a single step
276280
spl = Sampler(alg)
277281
hmc_state = DynamicPPL.initialstep(
278-
Random.default_rng(), gdemo_default, spl, DynamicPPL.VarInfo(gdemo_default)
282+
Random.default_rng(),
283+
gdemo_default,
284+
spl,
285+
DynamicPPL.VarInfo(gdemo_default);
286+
initial_params=InitFromUniform(),
279287
)[2]
280288
# Check that we can obtain the current step size
281289
@test Turing.Inference.getstepsize(spl, hmc_state) isa Float64

test/mcmc/mh.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
4949
# Set the initial parameters, because if we get unlucky with the initial state,
5050
# these chains are too short to converge to reasonable numbers.
5151
discard_initial = 1_000
52-
initial_params = [1.0, 1.0]
52+
initial_params = InitFromParams((s=1.0, m=1.0))
5353

5454
@testset "gdemo_default" begin
5555
alg = MH()
@@ -81,13 +81,16 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
8181
@varname(mu1) => MH((:mu1, GKernel(1))),
8282
@varname(mu2) => MH((:mu2, GKernel(1))),
8383
)
84+
initial_params = InitFromParams((
85+
mu1=1.0, mu2=1.0, z1=0.0, z2=0.0, z3=1.0, z4=1.0
86+
))
8487
chain = sample(
8588
StableRNG(seed),
8689
MoGtest_default,
8790
gibbs,
8891
500;
8992
discard_initial=100,
90-
initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0],
93+
initial_params=initial_params,
9194
)
9295
check_MoGtest_default(chain; atol=0.2)
9396
end

0 commit comments

Comments
 (0)