Skip to content

Commit c6ccb08

Browse files
Carlos Paradagithub-actions[bot]devmotion
authored
Extra context constructors (#374)
* Extra context constructors * Simplify * Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add test * fix * formatting * Update test/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/contexts.jl Co-authored-by: David Widmann <[email protected]> * Apply suggestions from review * Update test/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * More readable * Remove inner constructor * remove brackets * Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/contexts.jl * Update src/contexts.jl Co-authored-by: David Widmann <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
1 parent aeb5e03 commit c6ccb08

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.17.5"
3+
version = "0.17.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/contexts.jl

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right
120120

121121
# Contexts
122122
"""
123-
SamplingContext(rng, sampler, context)
123+
SamplingContext(
124+
[rng::Random.AbstractRNG=Random.GLOBAL_RNG],
125+
[sampler::AbstractSampler=SampleFromPrior()],
126+
[context::AbstractContext=DefaultContext()],
127+
)
124128
125129
Create a context that allows you to sample parameters with the `sampler` when running the model.
126130
The `context` determines how the returned log density is computed when running the model.
@@ -132,10 +136,26 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
132136
sampler::S
133137
context::C
134138
end
135-
SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context)
136-
SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context)
137-
SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext())
138-
SamplingContext() = SamplingContext(SampleFromPrior())
139+
140+
function SamplingContext(
141+
rng::Random.AbstractRNG=Random.GLOBAL_RNG, sampler::AbstractSampler=SampleFromPrior()
142+
)
143+
return SamplingContext(rng, sampler, DefaultContext())
144+
end
145+
146+
function SamplingContext(
147+
sampler::AbstractSampler, context::AbstractContext=DefaultContext()
148+
)
149+
return SamplingContext(Random.GLOBAL_RNG, sampler, context)
150+
end
151+
152+
function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext)
153+
return SamplingContext(rng, SampleFromPrior(), context)
154+
end
155+
156+
function SamplingContext(context::AbstractContext)
157+
return SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), context)
158+
end
139159

140160
NodeTrait(context::SamplingContext) = IsParent()
141161
childcontext(context::SamplingContext) = context.context

test/contexts.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,19 @@ end
256256
@test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x")
257257
@test getlens(vn_prefixed) === getlens(vn)
258258
end
259+
260+
@testset "SamplingContext" begin
261+
context = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext())
262+
@test context isa SamplingContext
263+
264+
# convenience constructors
265+
@test SamplingContext() == context
266+
@test SamplingContext(Random.GLOBAL_RNG) == context
267+
@test SamplingContext(SampleFromPrior()) == context
268+
@test SamplingContext(DefaultContext()) == context
269+
@test SamplingContext(Random.GLOBAL_RNG, SampleFromPrior()) == context
270+
@test SamplingContext(Random.GLOBAL_RNG, DefaultContext()) == context
271+
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
272+
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
273+
end
259274
end

0 commit comments

Comments
 (0)