Skip to content

Commit 5c289c5

Browse files
committed
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-macro' into torfjelde/returned-quantities-macro
2 parents b421687 + c71242f commit 5c289c5

File tree

14 files changed

+1265
-1338
lines changed

14 files changed

+1265
-1338
lines changed

.github/workflows/CI.yml

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,53 @@ permissions:
1919

2020
jobs:
2121
test:
22-
runs-on: ${{ matrix.os }}
22+
runs-on: ${{ matrix.runner.os }}
2323
strategy:
2424
matrix:
25-
version:
26-
- 'min' # minimum supported version
27-
- '1' # current stable version
28-
os:
29-
- ubuntu-latest
30-
arch:
31-
- x64
32-
num_threads:
33-
- 1
34-
- 2
35-
include:
25+
runner:
26+
# Current stable version
27+
- version: '1'
28+
os: ubuntu-latest
29+
arch: x64
30+
num_threads: 2
31+
# Minimum supported version
32+
- version: 'min'
33+
os: ubuntu-latest
34+
arch: x64
35+
num_threads: 2
36+
# Single-threaded
37+
- version: '1'
38+
os: ubuntu-latest
39+
arch: x64
40+
num_threads: 1
41+
# Minimum supported version, single-threaded
42+
- version: 'min'
43+
os: ubuntu-latest
44+
arch: x64
45+
num_threads: 1
46+
# x86
3647
- version: '1'
3748
os: ubuntu-latest
3849
arch: x86
3950
num_threads: 2
51+
# Windows
4052
- version: '1'
4153
os: windows-latest
4254
arch: x64
4355
num_threads: 2
56+
# macOS
4457
- version: '1'
45-
os: macOS-latest
46-
arch: x64
58+
os: macos-latest
59+
arch: aarch64
4760
num_threads: 2
61+
4862
steps:
4963
- uses: actions/checkout@v4
5064

5165
- uses: julia-actions/setup-julia@v2
5266
with:
53-
version: ${{ matrix.version }}
54-
arch: ${{ matrix.arch }}
67+
version: ${{ matrix.runner.version }}
68+
arch: ${{ matrix.runner.arch }}
5569

5670
- uses: julia-actions/cache@v2
5771

@@ -60,7 +74,7 @@ jobs:
6074
- uses: julia-actions/julia-runtest@v1
6175
env:
6276
GROUP: All
63-
JULIA_NUM_THREADS: ${{ matrix.num_threads }}
77+
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
6478

6579
- uses: julia-actions/julia-processcoverage@v1
6680

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30.5"
3+
version = "0.30.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/context_implementations.jl

Lines changed: 7 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -77,49 +77,11 @@ function tilde_assume(
7777
return tilde_assume(rng, childcontext(context), args...)
7878
end
7979

80-
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
81-
if haskey(context.vars, getsym(vn))
82-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
83-
settrans!!(vi, false, vn)
84-
end
85-
return tilde_assume(PriorContext(), right, vn, vi)
86-
end
87-
function tilde_assume(
88-
rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi
89-
)
90-
if haskey(context.vars, getsym(vn))
91-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
92-
settrans!!(vi, false, vn)
93-
end
94-
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
95-
end
96-
97-
function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
98-
if haskey(context.vars, getsym(vn))
99-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
100-
settrans!!(vi, false, vn)
101-
end
102-
return tilde_assume(LikelihoodContext(), right, vn, vi)
103-
end
104-
function tilde_assume(
105-
rng::Random.AbstractRNG,
106-
context::LikelihoodContext{<:NamedTuple},
107-
sampler,
108-
right,
109-
vn,
110-
vi,
111-
)
112-
if haskey(context.vars, getsym(vn))
113-
vi = setindex!!(vi, tovec(get(context.vars, vn)), vn)
114-
settrans!!(vi, false, vn)
115-
end
116-
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
117-
end
11880
function tilde_assume(::LikelihoodContext, right, vn, vi)
119-
return assume(NoDist(right), vn, vi)
81+
return assume(nodist(right), vn, vi)
12082
end
12183
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
122-
return assume(rng, sampler, NoDist(right), vn, vi)
84+
return assume(rng, sampler, nodist(right), vn, vi)
12385
end
12486

12587
function tilde_assume(context::PrefixContext, right, vn, vi)
@@ -342,37 +304,6 @@ function dot_tilde_assume(
342304
end
343305

344306
# `LikelihoodContext`
345-
function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi)
346-
return if haskey(context.vars, getsym(vn))
347-
var = get(context.vars, vn)
348-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
349-
set_val!(vi, _vns, _right, _left)
350-
settrans!!.((vi,), false, _vns)
351-
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
352-
else
353-
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
354-
end
355-
end
356-
function dot_tilde_assume(
357-
rng::Random.AbstractRNG,
358-
context::LikelihoodContext{<:NamedTuple},
359-
sampler,
360-
right,
361-
left,
362-
vn,
363-
vi,
364-
)
365-
return if haskey(context.vars, getsym(vn))
366-
var = get(context.vars, vn)
367-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
368-
set_val!(vi, _vns, _right, _left)
369-
settrans!!.((vi,), false, _vns)
370-
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
371-
else
372-
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
373-
end
374-
end
375-
376307
function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi)
377308
return dot_assume(nodist(right), left, vn, vi)
378309
end
@@ -382,46 +313,16 @@ function dot_tilde_assume(
382313
return dot_assume(rng, sampler, nodist(right), vn, left, vi)
383314
end
384315

385-
# `PriorContext`
386-
function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi)
387-
return if haskey(context.vars, getsym(vn))
388-
var = get(context.vars, vn)
389-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
390-
set_val!(vi, _vns, _right, _left)
391-
settrans!!.((vi,), false, _vns)
392-
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
393-
else
394-
dot_tilde_assume(PriorContext(), right, left, vn, vi)
395-
end
396-
end
397-
function dot_tilde_assume(
398-
rng::Random.AbstractRNG,
399-
context::PriorContext{<:NamedTuple},
400-
sampler,
401-
right,
402-
left,
403-
vn,
404-
vi,
405-
)
406-
return if haskey(context.vars, getsym(vn))
407-
var = get(context.vars, vn)
408-
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
409-
set_val!(vi, _vns, _right, _left)
410-
settrans!!.((vi,), false, _vns)
411-
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
412-
else
413-
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
414-
end
415-
end
416-
417316
# `PrefixContext`
418317
function dot_tilde_assume(context::PrefixContext, right, left, vn, vi)
419-
return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi)
318+
return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi)
420319
end
421320

422-
function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi)
321+
function dot_tilde_assume(
322+
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi
323+
)
423324
return dot_tilde_assume(
424-
rng, context.context, sampler, right, prefix.(Ref(context), vn), vi
325+
rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi
425326
)
426327
end
427328

src/contexts.jl

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ DefaultContext()
5353
julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior
5454
5555
julia> DynamicPPL.childcontext(ctx_prior)
56-
PriorContext{Nothing}(nothing)
56+
PriorContext()
5757
```
5858
"""
5959
setchildcontext
@@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext()))
9797
9898
julia> # Replace the leaf context with another leaf.
9999
leafcontext(setleafcontext(ctx, PriorContext()))
100-
PriorContext{Nothing}(nothing)
100+
PriorContext()
101101
102102
julia> # Append another parent context.
103103
setleafcontext(ctx, ParentContext(DefaultContext()))
@@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end
195195
NodeTrait(context::DefaultContext) = IsLeaf()
196196

197197
"""
198-
struct PriorContext{Tvars} <: AbstractContext
199-
vars::Tvars
200-
end
198+
PriorContext <: AbstractContext
201199
202-
The `PriorContext` enables the computation of the log prior of the parameters `vars` when
203-
running the model.
200+
A leaf context resulting in the exclusion of likelihood terms when running the model.
204201
"""
205-
struct PriorContext{Tvars} <: AbstractContext
206-
vars::Tvars
207-
end
208-
PriorContext() = PriorContext(nothing)
202+
struct PriorContext <: AbstractContext end
209203
NodeTrait(context::PriorContext) = IsLeaf()
210204

211205
"""
212-
struct LikelihoodContext{Tvars} <: AbstractContext
213-
vars::Tvars
214-
end
206+
LikelihoodContext <: AbstractContext
215207
216-
The `LikelihoodContext` enables the computation of the log likelihood of the parameters when
217-
running the model. `vars` can be used to evaluate the log likelihood for specific values
218-
of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default.
208+
A leaf context resulting in the exclusion of prior terms when running the model.
219209
"""
220-
struct LikelihoodContext{Tvars} <: AbstractContext
221-
vars::Tvars
222-
end
223-
LikelihoodContext() = LikelihoodContext(nothing)
210+
struct LikelihoodContext <: AbstractContext end
224211
NodeTrait(context::LikelihoodContext) = IsLeaf()
225212

226213
"""

src/debug_utils.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
332332
record_post_tilde_assume!(context, vn, right, value, logp, vi)
333333
return value, logp, vi
334334
end
335-
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
335+
function DynamicPPL.tilde_assume(
336+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
337+
)
336338
record_pre_tilde_assume!(context, vn, right, vi)
337339
value, logp, vi = DynamicPPL.tilde_assume(
338340
rng, childcontext(context), sampler, right, vn, vi
@@ -425,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
425427
end
426428

427429
function DynamicPPL.dot_tilde_assume(
428-
rng, context::DebugContext, sampler, right, left, vn, vi
430+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
429431
)
430432
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
431433
value, logp, vi = DynamicPPL.dot_tilde_assume(

src/distribution_wrappers.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
4242
Base.size(dist::NoDist) = Base.size(dist.dist)
4343

4444
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
45+
# NOTE(torfjelde): Need this to avoid stack overflow.
46+
function Distributions.rand!(
47+
rng::Random.AbstractRNG,
48+
d::NoDist{Distributions.ArrayLikeVariate{N}},
49+
x::AbstractArray{<:Real,N},
50+
) where {N}
51+
return Distributions.rand!(rng, d.dist, x)
52+
end
4553
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
4654
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
4755
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})

0 commit comments

Comments
 (0)