Skip to content

Commit d6e9639

Browse files
committed
Merge branch 'master' into torfjelde/fix-fixes
2 parents da6f9a0 + 2344689 commit d6e9639

20 files changed

+1395
-1299
lines changed

.github/workflows/CI.yml

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ on:
44
push:
55
branches:
66
- master
7+
- backport-*
78
pull_request:
89
branches:
910
- master
11+
- backport-*
1012
merge_group:
1113
types: [checks_requested]
1214

@@ -17,39 +19,53 @@ permissions:
1719

1820
jobs:
1921
test:
20-
runs-on: ${{ matrix.os }}
22+
runs-on: ${{ matrix.runner.os }}
2123
strategy:
2224
matrix:
23-
version:
24-
- 'min' # minimum supported version
25-
- '1' # current stable version
26-
os:
27-
- ubuntu-latest
28-
arch:
29-
- x64
30-
num_threads:
31-
- 1
32-
- 2
33-
include:
25+
runner:
26+
# Current stable version
27+
- version: '1'
28+
os: ubuntu-latest
29+
arch: x64
30+
num_threads: 2
31+
# Minimum supported version
32+
- version: 'min'
33+
os: ubuntu-latest
34+
arch: x64
35+
num_threads: 2
36+
# Single-threaded
37+
- version: '1'
38+
os: ubuntu-latest
39+
arch: x64
40+
num_threads: 1
41+
# Minimum supported version, single-threaded
42+
- version: 'min'
43+
os: ubuntu-latest
44+
arch: x64
45+
num_threads: 1
46+
# x86
3447
- version: '1'
3548
os: ubuntu-latest
3649
arch: x86
3750
num_threads: 2
51+
# Windows
3852
- version: '1'
3953
os: windows-latest
4054
arch: x64
4155
num_threads: 2
56+
# macOS
4257
- version: '1'
43-
os: macOS-latest
44-
arch: x64
58+
os: macos-latest
59+
arch: aarch64
4560
num_threads: 2
61+
4662
steps:
4763
- uses: actions/checkout@v4
4864

4965
- uses: julia-actions/setup-julia@v2
5066
with:
51-
version: ${{ matrix.version }}
52-
arch: ${{ matrix.arch }}
67+
version: ${{ matrix.runner.version }}
68+
arch: ${{ matrix.runner.arch }}
5369

5470
- uses: julia-actions/cache@v2
5571

@@ -58,7 +74,7 @@ jobs:
5874
- uses: julia-actions/julia-runtest@v1
5975
env:
6076
GROUP: All
61-
JULIA_NUM_THREADS: ${{ matrix.num_threads }}
77+
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
6278

6379
- uses: julia-actions/julia-processcoverage@v1
6480

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30.2"
3+
version = "0.30.6"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -14,6 +14,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1414
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1515
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1616
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
17+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1920
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
@@ -46,7 +47,7 @@ AbstractMCMC = "5"
4647
AbstractPPL = "0.8.4, 0.9"
4748
Accessors = "0.1"
4849
BangBang = "0.4.1"
49-
Bijectors = "0.13.18"
50+
Bijectors = "0.13.18, 0.14"
5051
ChainRulesCore = "1"
5152
Compat = "4"
5253
ConstructionBase = "1.5.4"
@@ -65,7 +66,7 @@ Requires = "1"
6566
ReverseDiff = "1"
6667
Test = "1.6"
6768
ZygoteRules = "0.2"
68-
julia = "~1.6.6, 1.7.3"
69+
julia = "1.10"
6970

7071
[extras]
7172
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/api.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,19 @@ And some which might be useful to determine certain properties of the model base
212212
DynamicPPL.has_static_constraints
213213
```
214214

215+
For determining whether one might have type instabilities in the model, the following can be useful
216+
217+
```@docs
218+
DynamicPPL.DebugUtils.model_warntype
219+
DynamicPPL.DebugUtils.model_typed
220+
```
221+
222+
Interally, the type-checking methods make use of the following method for construction of the call with the argument types:
223+
224+
```@docs
225+
DynamicPPL.DebugUtils.gen_evaluator_call_with_types
226+
```
227+
215228
## Advanced
216229

217230
### Variable names

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 7 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities(
108108
varinfo = DynamicPPL.VarInfo(model)
109109
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
110110
return map(iters) do (sample_idx, chain_idx)
111-
if DynamicPPL.supports_varname_indexing(chain)
112-
varname_pairs = _varname_pairs_with_varname_indexing(
113-
chain, varinfo, sample_idx, chain_idx
114-
)
115-
else
116-
varname_pairs = _varname_pairs_without_varname_indexing(
117-
chain, varinfo, sample_idx, chain_idx
118-
)
119-
end
120-
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
121-
return fixed_model()
111+
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
112+
# Update the varinfo with the current sample and make variables not present in `chain`
113+
# to be sampled.
114+
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
115+
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
116+
# `deepcopy` the `varinfo` before passing it to the `model`.
117+
model(deepcopy(varinfo))
122118
end
123119
end
124120

125-
"""
126-
_varname_pairs_with_varname_indexing(
127-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
128-
)
129-
130-
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
131-
from the chain.
132-
133-
This implementation assumes `chain` can be indexed using variable names, and is the
134-
preffered implementation.
135-
"""
136-
function _varname_pairs_with_varname_indexing(
137-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
138-
)
139-
vns = DynamicPPL.varnames(chain)
140-
vn_parents = Iterators.map(vns) do vn
141-
# The call nested_setindex_maybe! is used to handle cases where vn is not
142-
# the variable name used in the model, but rather subsumed by one. Except
143-
# for the subsumption part, this could be
144-
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145-
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
146-
DynamicPPL.nested_setindex_maybe!(
147-
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
148-
)
149-
end
150-
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
151-
vn_parent => varinfo[vn_parent]
152-
end
153-
return varname_pairs
154-
end
155-
156-
"""
157-
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
158-
159-
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
160-
won't catch all cases. We should get rid of this if we can.
161-
"""
162-
# TODO(mhauru) See docstring above.
163-
function _vcat_subsumed_values(vn_string, values, key_strings)
164-
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
165-
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
166-
end
167-
168-
"""
169-
_varname_pairs_without_varname_indexing(
170-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
171-
)
172-
173-
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
174-
from the chain.
175-
176-
This implementation does not assume that `chain` can be indexed using variable names. It is
177-
thus not guaranteed to work in cases where the variable names have complex subsumption
178-
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
179-
"""
180-
function _varname_pairs_without_varname_indexing(
181-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
182-
)
183-
values = chain.value[sample_idx, :, chain_idx]
184-
keys = Base.keys(chain)
185-
keys_strings = map(string, keys)
186-
varname_pairs = [
187-
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
188-
vn in Base.keys(varinfo)
189-
]
190-
return varname_pairs
191-
end
192-
193121
end

src/context_implementations.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ function tilde_assume(
7878
end
7979

8080
function tilde_assume(::LikelihoodContext, right, vn, vi)
81-
return assume(NoDist(right), vn, vi)
81+
return assume(nodist(right), vn, vi)
8282
end
8383
function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi)
84-
return assume(rng, sampler, NoDist(right), vn, vi)
84+
return assume(rng, sampler, nodist(right), vn, vi)
8585
end
8686

8787
function tilde_assume(context::PrefixContext, right, vn, vi)

src/debug_utils.jl

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ..DynamicPPL: broadcast_safe, AbstractContext, childcontext
55

66
using Random: Random
77
using Accessors: Accessors
8+
using InteractiveUtils: InteractiveUtils
89

910
using DocStringExtensions
1011
using Distributions
@@ -331,7 +332,9 @@ function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi)
331332
record_post_tilde_assume!(context, vn, right, value, logp, vi)
332333
return value, logp, vi
333334
end
334-
function DynamicPPL.tilde_assume(rng, context::DebugContext, sampler, right, vn, vi)
335+
function DynamicPPL.tilde_assume(
336+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi
337+
)
335338
record_pre_tilde_assume!(context, vn, right, vi)
336339
value, logp, vi = DynamicPPL.tilde_assume(
337340
rng, childcontext(context), sampler, right, vn, vi
@@ -424,7 +427,7 @@ function DynamicPPL.dot_tilde_assume(context::DebugContext, right, left, vn, vi)
424427
end
425428

426429
function DynamicPPL.dot_tilde_assume(
427-
rng, context::DebugContext, sampler, right, left, vn, vi
430+
rng::Random.AbstractRNG, context::DebugContext, sampler, right, left, vn, vi
428431
)
429432
record_pre_dot_tilde_assume!(context, vn, left, right, vi)
430433
value, logp, vi = DynamicPPL.dot_tilde_assume(
@@ -678,4 +681,83 @@ function has_static_constraints(
678681
return all_the_same(transforms)
679682
end
680683

684+
"""
685+
gen_evaluator_call_with_types(model[, varinfo, context])
686+
687+
Generate the evaluator call and the types of the arguments.
688+
689+
# Arguments
690+
- `model::Model`: The model whose evaluator is of interest.
691+
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
692+
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
693+
694+
# Returns
695+
A 2-tuple with the following elements:
696+
- `f`: This is either `model.f` or `Core.kwcall`, depending on whether
697+
the model has keyword arguments.
698+
- `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator.
699+
"""
700+
function gen_evaluator_call_with_types(
701+
model::Model,
702+
varinfo::AbstractVarInfo=VarInfo(model),
703+
context::AbstractContext=DefaultContext(),
704+
)
705+
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
706+
return if isempty(kwargs)
707+
(model.f, Base.typesof(args...))
708+
else
709+
(Core.kwcall, Tuple{typeof(kwargs),Core.Typeof(model.f),map(Core.Typeof, args)...})
710+
end
711+
end
712+
713+
"""
714+
model_warntype(model[, varinfo, context]; optimize=true)
715+
716+
Check the type stability of the model's evaluator, warning about any potential issues.
717+
718+
This simply calls `@code_warntype` on the model's evaluator, filling in internal arguments where needed.
719+
720+
# Arguments
721+
- `model::Model`: The model to check.
722+
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
723+
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
724+
725+
# Keyword Arguments
726+
- `optimize::Bool`: Whether to generate optimized code. Default: `false`.
727+
"""
728+
function model_warntype(
729+
model::Model,
730+
varinfo::AbstractVarInfo=VarInfo(model),
731+
context::AbstractContext=DefaultContext();
732+
optimize::Bool=false,
733+
)
734+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
735+
return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize)
736+
end
737+
738+
"""
739+
model_typed(model[, varinfo, context]; optimize=true)
740+
741+
Return the type inference for the model's evaluator.
742+
743+
This simply calls `@code_typed` on the model's evaluator, filling in internal arguments where needed.
744+
745+
# Arguments
746+
- `model::Model`: The model to check.
747+
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
748+
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
749+
750+
# Keyword Arguments
751+
- `optimize::Bool`: Whether to generate optimized code. Default: `true`.
752+
"""
753+
function model_typed(
754+
model::Model,
755+
varinfo::AbstractVarInfo=VarInfo(model),
756+
context::AbstractContext=DefaultContext();
757+
optimize::Bool=true,
758+
)
759+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
760+
return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize))
761+
end
762+
681763
end

src/distribution_wrappers.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ Base.length(dist::NoDist) = Base.length(dist.dist)
4242
Base.size(dist::NoDist) = Base.size(dist.dist)
4343

4444
Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
45+
# NOTE(torfjelde): Need this to avoid stack overflow.
46+
function Distributions.rand!(
47+
rng::Random.AbstractRNG,
48+
d::NoDist{Distributions.ArrayLikeVariate{N}},
49+
x::AbstractArray{<:Real,N},
50+
) where {N}
51+
return Distributions.rand!(rng, d.dist, x)
52+
end
4553
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
4654
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
4755
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})

0 commit comments

Comments
 (0)