Skip to content

Commit 4896793

Browse files
authored
Merge branch 'master' into torfjelde/returned-quantities-macro
2 parents d477137 + d6e2147 commit 4896793

File tree

8 files changed

+43
-89
lines changed

8 files changed

+43
-89
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ on:
44
push:
55
branches:
66
- master
7+
- backport-*
78
pull_request:
89
branches:
910
- master
11+
- backport-*
1012
merge_group:
1113
types: [checks_requested]
1214

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.30.1"
3+
version = "0.30.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -46,7 +46,7 @@ AbstractMCMC = "5"
4646
AbstractPPL = "0.8.4, 0.9"
4747
Accessors = "0.1"
4848
BangBang = "0.4.1"
49-
Bijectors = "0.13.18"
49+
Bijectors = "0.13.18, 0.14"
5050
ChainRulesCore = "1"
5151
Compat = "4"
5252
ConstructionBase = "1.5.4"
@@ -65,7 +65,7 @@ Requires = "1"
6565
ReverseDiff = "1"
6666
Test = "1.6"
6767
ZygoteRules = "0.2"
68-
julia = "~1.6.6, 1.7.3"
68+
julia = "1.10"
6969

7070
[extras]
7171
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 7 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -108,86 +108,14 @@ function DynamicPPL.returned_quantities(
108108
varinfo = DynamicPPL.VarInfo(model)
109109
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
110110
return map(iters) do (sample_idx, chain_idx)
111-
if DynamicPPL.supports_varname_indexing(chain)
112-
varname_pairs = _varname_pairs_with_varname_indexing(
113-
chain, varinfo, sample_idx, chain_idx
114-
)
115-
else
116-
varname_pairs = _varname_pairs_without_varname_indexing(
117-
chain, varinfo, sample_idx, chain_idx
118-
)
119-
end
120-
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
121-
return fixed_model()
111+
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
112+
# Update the varinfo with the current sample and make variables not present in `chain`
113+
# to be sampled.
114+
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
115+
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
116+
# `deepcopy` the `varinfo` before passing it to the `model`.
117+
model(deepcopy(varinfo))
122118
end
123119
end
124120

125-
"""
126-
_varname_pairs_with_varname_indexing(
127-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
128-
)
129-
130-
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
131-
from the chain.
132-
133-
This implementation assumes `chain` can be indexed using variable names, and is the
134-
preffered implementation.
135-
"""
136-
function _varname_pairs_with_varname_indexing(
137-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
138-
)
139-
vns = DynamicPPL.varnames(chain)
140-
vn_parents = Iterators.map(vns) do vn
141-
# The call nested_setindex_maybe! is used to handle cases where vn is not
142-
# the variable name used in the model, but rather subsumed by one. Except
143-
# for the subsumption part, this could be
144-
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
145-
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
146-
DynamicPPL.nested_setindex_maybe!(
147-
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
148-
)
149-
end
150-
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
151-
vn_parent => varinfo[vn_parent]
152-
end
153-
return varname_pairs
154-
end
155-
156-
"""
157-
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.
158-
159-
The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
160-
won't catch all cases. We should get rid of this if we can.
161-
"""
162-
# TODO(mhauru) See docstring above.
163-
function _vcat_subsumed_values(vn_string, values, key_strings)
164-
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
165-
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
166-
end
167-
168-
"""
169-
_varname_pairs_without_varname_indexing(
170-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
171-
)
172-
173-
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
174-
from the chain.
175-
176-
This implementation does not assume that `chain` can be indexed using variable names. It is
177-
thus not guaranteed to work in cases where the variable names have complex subsumption
178-
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
179-
"""
180-
function _varname_pairs_without_varname_indexing(
181-
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
182-
)
183-
values = chain.value[sample_idx, :, chain_idx]
184-
keys = Base.keys(chain)
185-
keys_strings = map(string, keys)
186-
varname_pairs = [
187-
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
188-
vn in Base.keys(varinfo)
189-
]
190-
return varname_pairs
191-
end
192-
193121
end

src/utils.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,15 @@ function (f::ReshapeTransform)(x)
286286
if size(x) != f.input_size
287287
throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))"))
288288
end
289-
# The call to `tovec` is only needed in case `x` is a scalar.
290-
return reshape(tovec(x), f.output_size)
289+
if f.output_size == ()
290+
# Specially handle the case where x is a singleton array, see
291+
# https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and
292+
# https://github.com/TuringLang/DynamicPPL.jl/issues/698
293+
return fill(x[], ())
294+
else
295+
# The call to `tovec` is only needed in case `x` is a scalar.
296+
return reshape(tovec(x), f.output_size)
297+
end
291298
end
292299

293300
function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
@@ -934,10 +941,10 @@ end
934941
"""
935942
float_type_with_fallback(x)
936943
937-
Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`.
944+
Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`.
938945
"""
939-
float_type_with_fallback(::Type) = Real
940-
float_type_with_fallback(::Type{Union{}}) = Real
946+
float_type_with_fallback(::Type) = float(Real)
947+
float_type_with_fallback(::Type{Union{}}) = float(Real)
941948
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)
942949

943950
"""

src/varinfo.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,8 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`.
16941694
16951695
The value(s) may or may not be transformed to Euclidean space.
16961696
"""
1697+
getindex(vi::UntypedVarInfo, spl::Sampler) =
1698+
copy(getindex(vi.metadata.vals, _getranges(vi, spl)))
16971699
getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl)))
16981700
function getindex(vi::TypedVarInfo, spl::Sampler)
16991701
# Gets the ranges as a NamedTuple

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ADTypes = "1"
3131
AbstractMCMC = "5"
3232
AbstractPPL = "0.8.4, 0.9"
3333
Accessors = "0.1"
34-
Bijectors = "0.13.9"
34+
Bijectors = "0.13.9, 0.14"
3535
Combinatorics = "1"
3636
Compat = "4.3.0"
3737
Distributions = "0.25"

test/turing/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30"
1515
HypothesisTests = "0.11"
1616
MCMCChains = "6"
1717
ReverseDiff = "1.15"
18-
Turing = "0.33, 0.34"
18+
Turing = "0.33, 0.34, 0.35"
1919
julia = "1.7"

test/turing/varinfo.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,4 +342,19 @@
342342
model = state_space(y, length(t))
343343
@test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n
344344
end
345+
346+
if Threads.nthreads() > 1
347+
@testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin
348+
@model function f(x)
349+
ns ~ filldist(Normal(0, 2.0), 3)
350+
m ~ Uniform(0, 1)
351+
return x ~ Normal(m, 1)
352+
end
353+
model = f(1)
354+
chain = sample(model, NUTS(), MCMCThreads(), 10, 2)
355+
loglikelihood(model, chain)
356+
logprior(model, chain)
357+
logjoint(model, chain)
358+
end
359+
end
345360
end

0 commit comments

Comments
 (0)