
Commit 7ca59ce

Merge branch 'breaking' into mhauru/dppl-0.37

2 parents: 11a2a31 + 966e17b

6 files changed: +37 −21 lines changed

HISTORY.md
Lines changed: 6 additions & 0 deletions

````diff
@@ -2,6 +2,12 @@
 
 [...]
 
+# 0.39.8
+
+MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.
+This patch carries out some internal refactoring to avoid splitting up VarNames until absolutely necessary.
+There are no user-facing changes in this patch.
+
 # 0.39.7
 
 Update compatibility to AdvancedPS 0.7 and Libtask 0.9.
````
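To make the changelog entry above concrete, here is a hedged sketch (not part of this commit) of the per-component splitting it refers to. `DynamicPPL.varname_and_value_leaves` is the helper the diffs below rely on; the variable and value here are made up for illustration.

```julia
using DynamicPPL

# A made-up vector-valued variable and its value.
vn = @varname(x)
val = [1.0, 2.0]

# Split into scalar "leaves"; this is expected to yield one (VarName, value)
# entry per component, e.g. for x[1] and x[2], which is the representation
# MCMCChains.jl can work with.
leaves = collect(DynamicPPL.varname_and_value_leaves(vn, val))
```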

README.md
Lines changed: 1 addition & 1 deletion

````diff
@@ -11,7 +11,7 @@
 
 ## 🚀 Get started
 
-Install Julia (see [the official Julia website](https://julialang.org/install/); you will need at least Julia 1.10 will be required for the latest version of Turing.jl.
+Install Julia (see [the official Julia website](https://julialang.org/install/); you will need at least Julia 1.10 for the latest version of Turing.jl).
 Then, launch a Julia REPL and run:
 
 ```julia
````
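The hunk's context cuts off just as the README's code block begins. Purely as a hedged illustration, and not necessarily the README's exact snippet nor part of this commit, installing and loading Turing from a fresh REPL typically looks like:

```julia
# Not taken from this commit: a typical install-and-load snippet, assuming
# the standard Pkg workflow the README describes.
using Pkg
Pkg.add("Turing")  # install the latest registered release
using Turing       # load the package
```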

ext/TuringOptimExt.jl
Lines changed: 3 additions & 1 deletion

````diff
@@ -185,7 +185,9 @@ function _optimize(
     logdensity_optimum = Optimisation.OptimLogDensity(
         f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
     )
-    vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
+    vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
+    vns_vals_iter = mapreduce(collect, vcat, iters)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     vmat = NamedArrays.NamedArray(vals, varnames)
````
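A hypothetical, self-contained sketch (not from this commit) of the pattern the hunk above introduces: `Turing.Inference.getparams` now returns a dict of `VarName => value`, which is split into scalar leaves with `DynamicPPL.varname_and_value_leaves` before the named vector of optimum values is assembled. The toy `vals_dict` below stands in for real `getparams` output.

```julia
using DynamicPPL
using NamedArrays

# Stand-in for Turing.Inference.getparams output: one vector-valued variable.
vals_dict = Dict(@varname(x) => [1.0, 2.0])

# Split each (VarName, value) pair into scalar leaves (e.g. x[1], x[2]).
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
vns_vals_iter = mapreduce(collect, vcat, iters)

# Mirror the remainder of _optimize above: a named vector of optimum values.
varnames = map(Symbol ∘ first, vns_vals_iter)
vals = map(last, vns_vals_iter)
vmat = NamedArrays.NamedArray(vals, varnames)
```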

src/mcmc/Inference.jl
Lines changed: 15 additions & 12 deletions

````diff
@@ -174,13 +174,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
     # this means that the code below will work both of linked and invlinked `vi`.
     # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
     # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-    vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
-
-    # Obtain an iterator over the flattened parameter names and values.
-    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
-
-    # Materialize the iterators and concatenate.
-    return mapreduce(collect, vcat, iters)
+    return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
 end
 function getparams(
     model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -191,14 +185,25 @@ function getparams(
     return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
 end
 function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
-    return float(Real)[]
+    return Dict{VarName,Any}()
 end
 
 function _params_to_array(model::DynamicPPL.Model, ts::Vector)
     names_set = OrderedSet{VarName}()
     # Extract the parameter names and values from each transition.
     dicts = map(ts) do t
-        nms_and_vs = getparams(model, t)
+        # In general getparams returns a dict of VarName => values. We need to also
+        # split it up into constituent elements using
+        # `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
+        # won't understand it.
+        vals = getparams(model, t)
+        nms_and_vs = if isempty(vals)
+            Tuple{VarName,Any}[]
+        else
+            iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+            mapreduce(collect, vcat, iters)
+        end
+
         nms = map(first, nms_and_vs)
         vs = map(last, nms_and_vs)
         for nm in nms
@@ -208,9 +213,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
         return OrderedDict(zip(nms, vs))
     end
     names = collect(names_set)
-    vals = [
-        get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
-    ]
+    vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
 
     return names, vals
 end
````
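A hypothetical toy example (not from this commit) of the final step in `_params_to_array` above: each transition contributes a dict of leaf `VarName => value` pairs, and parameters absent from a given transition are filled with `missing` in the resulting matrix. The dicts below are made up for illustration.

```julia
using DynamicPPL
using OrderedCollections: OrderedDict

# Made-up per-transition dicts of leaf VarName => value pairs.
dicts = [
    OrderedDict(@varname(x) => 1.0, @varname(y) => 2.0),
    OrderedDict(@varname(x) => 3.0),   # no `y` in this transition
]
names = [@varname(x), @varname(y)]

# Same comprehension as in the diff: rows are transitions, columns are names.
vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
# vals is a 2×2 matrix; vals[2, 2] === missing
```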

src/optimisation/Optimisation.jl
Lines changed: 7 additions & 4 deletions

````diff
@@ -426,9 +426,10 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
     # Get all the variable names in the model. This is the same as the list of keys in
     # m.values, but they are more convenient to filter when they are VarNames rather than
     # Symbols.
-    varnames = collect(
-        map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
-    )
+    vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
+    vns_and_vals = mapreduce(collect, vcat, iters)
+    varnames = collect(map(first, vns_and_vals))
     # For each symbol s in var_symbols, pick all the values from m.values for which the
     # variable name has that symbol.
     et = eltype(m.values)
@@ -456,7 +457,9 @@ parameter space in case the optimization was done in a transformed space.
 function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
     varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
     # `getparams` performs invlinking if needed
-    vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
+    vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
+    iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
+    vns_vals_iter = mapreduce(collect, vcat, iters)
     syms = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
     return ModeResult(
````

test/mcmc/Inference.jl
Lines changed: 5 additions & 3 deletions

````diff
@@ -585,16 +585,18 @@ using Turing
         return x ~ Normal()
     end
     fvi = DynamicPPL.VarInfo(f())
-    @test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
+    fparams = Turing.Inference.getparams(f(), fvi)
+    @test fparams[@varname(x)] == fvi[@varname(x)]
+    @test length(fparams) == 1
 
     @model function g()
         x ~ Normal()
         return y ~ Poisson()
     end
     gvi = DynamicPPL.VarInfo(g())
     gparams = Turing.Inference.getparams(g(), gvi)
-    @test gparams[1] == (@varname(x), gvi[@varname(x)])
-    @test gparams[2] == (@varname(y), gvi[@varname(y)])
+    @test gparams[@varname(x)] == gvi[@varname(x)]
+    @test gparams[@varname(y)] == gvi[@varname(y)]
     @test length(gparams) == 2
 end
 
````
