Skip to content

Commit 6a307cd

Browse files
torfjeldeyebai
andauthored
Bump DynamicPPL to 0.23 (#2001)
* bump dppl test versions * also bump bijectors * bump AdvancedVI versions * revert Bijectors bump * bumped vi and bijectors too * breaking change * removed refernce to Bijectors.setadbackend * make use of DynamicPPL.make_evaluate_args_and_kwargs * bump DPPL version * bump DPPL version for tests * fixed bug in TracedModel * forgot to remove some lines * just drop the kwargs completely :( * Update container.jl * Update container.jl * will now error if we're using a model with kwargs and SMC * added reference to issue * added test for keyword models failing * make this a breaking change * made error message more informative * makde it slightly less informative * fixed typo in checking for TRaceModel * finally fixed the if-statement.. * Fix test error * fixed tests maybe * now fixed maybe * Update test/inference/Inference.jl --------- Co-authored-by: Hong Ge <[email protected]>
1 parent a38c709 commit 6a307cd

File tree

5 files changed

+46
-36
lines changed

5 files changed

+46
-36
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.25.3"
3+
version = "0.26.0"
4+
45

56
[deps]
67
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -47,7 +48,7 @@ DataStructures = "0.18"
4748
Distributions = "0.23.3, 0.24, 0.25"
4849
DistributionsAD = "0.6"
4950
DocStringExtensions = "0.8, 0.9"
50-
DynamicPPL = "0.21.5, 0.22"
51+
DynamicPPL = "0.23"
5152
EllipticalSliceSampling = "0.5, 1"
5253
ForwardDiff = "0.10.3"
5354
Libtask = "0.7, 0.8"

src/essential/container.jl

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,21 @@ function TracedModel(
99
model::Model,
1010
sampler::AbstractSampler,
1111
varinfo::AbstractVarInfo,
12-
rng::Random.AbstractRNG
13-
)
12+
rng::Random.AbstractRNG,
13+
)
1414
context = SamplingContext(rng, sampler, DefaultContext())
15-
evaluator = _get_evaluator(model, varinfo, context)
16-
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(model, sampler, varinfo, evaluator)
17-
end
18-
19-
# TODO: maybe move to DynamicPPL
20-
@generated function _get_evaluator(
21-
model::Model{_F,argnames}, varinfo, context
22-
) where {_F,argnames}
23-
unwrap_args = [
24-
:($DynamicPPL.matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames
25-
]
26-
# We want to give `context` precedence over `model.context` while also
27-
# preserving the leaf context of `context`. We can do this by
28-
# 1. Set the leaf context of `model.context` to `leafcontext(context)`.
29-
# 2. Set leaf context of `context` to the context resulting from (1).
30-
# The result is:
31-
# `context` -> `childcontext(context)` -> ... -> `model.context`
32-
# -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
33-
return quote
34-
context_new = DynamicPPL.setleafcontext(
35-
context, DynamicPPL.setleafcontext(model.context, DynamicPPL.leafcontext(context))
36-
)
37-
(model.f, model, DynamicPPL.resetlogp!!(varinfo), context_new, $(unwrap_args...))
15+
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
16+
if kwargs !== nothing && !isempty(kwargs)
17+
error("Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.")
3818
end
19+
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
20+
model,
21+
sampler,
22+
varinfo,
23+
(model.f, args...)
24+
)
3925
end
4026

41-
4227
function Base.copy(model::AdvancedPS.GenericModel{<:TracedModel})
4328
newtask = copy(model.ctask)
4429
newmodel = TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(deepcopy(model.f.model), deepcopy(model.f.sampler), deepcopy(model.f.varinfo), deepcopy(model.f.evaluator))
@@ -73,10 +58,12 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
7358
return trace
7459
end
7560

76-
function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
61+
function AdvancedPS.update_rng!(trace::AdvancedPS.Trace{AdvancedPS.GenericModel{TracedModel{M,S,V,E}, F}, R}) where {M,S,V,E,F,R}
62+
# Extract the `args`.
7763
args = trace.model.ctask.args
78-
_, _, container, = args
79-
rng = container.rng
64+
# From `args`, extract the `SamplingContext`, which contains the RNG.
65+
sampling_context = args[3]
66+
rng = sampling_context.rng
8067
trace.rng = rng
8168
return trace
8269
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
4040
Distributions = "0.25"
4141
DistributionsAD = "0.6.3"
4242
DynamicHMC = "2.1.6, 3.0"
43-
DynamicPPL = "0.21.5, 0.22"
43+
DynamicPPL = "0.23"
4444
FiniteDifferences = "0.10.8, 0.11, 0.12"
4545
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
4646
LogDensityProblems = "2"

test/inference/AdvancedSMC.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,12 @@ end
173173
@test length(unique(c[:m])) == 1
174174
@test length(unique(c[:s])) == 1
175175
end
176+
177+
# https://github.com/TuringLang/Turing.jl/issues/2007
178+
@turing_testset "keyword arguments not supported" begin
179+
@model kwarg_demo(; x = 2) = return x
180+
@test_throws ErrorException sample(kwarg_demo(), PG(1), 10)
181+
end
176182
end
177183

178184
# @testset "pmmh.jl" begin

test/inference/Inference.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,27 @@
259259
return priors
260260
end
261261

262-
chain = sample(gauss2(; x=x), PG(10), 10)
263-
chain = sample(gauss2(; x=x), SMC(), 10)
262+
@test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10)
263+
@test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10)
264264

265-
chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
266-
chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)
265+
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), PG(10), 10)
266+
@test_throws ErrorException chain = sample(gauss2(Vector{Float64}; x=x), SMC(), 10)
267+
268+
@model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV}
269+
priors = TV(undef, 2)
270+
priors[1] ~ InverseGamma(2, 3) # s
271+
priors[2] ~ Normal(0, sqrt(priors[1])) # m
272+
for i in 1:length(x)
273+
x[i] ~ Normal(priors[2], sqrt(priors[1]))
274+
end
275+
return priors
276+
end
277+
278+
chain = sample(gauss3(x), PG(10), 10)
279+
chain = sample(gauss3(x), SMC(), 10)
280+
281+
chain = sample(gauss3(x, Vector{Real}), PG(10), 10)
282+
chain = sample(gauss3(x, Vector{Real}), SMC(), 10)
267283
end
268284
@testset "new interface" begin
269285
obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1]

0 commit comments

Comments
 (0)