Skip to content

Commit d0510b1

Browse files
authored
Avoid splitting up varnames until absolutely necessary (#2632)
* Avoid splitting up varnames until absolutely necessary
* Fix more optimisation bits
1 parent 23b92eb commit d0510b1

File tree

6 files changed

+37
-21
lines changed

6 files changed

+37
-21
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# 0.39.8
2+
3+
MCMCChains.jl doesn't understand vector- or matrix-valued variables, and in Turing we split up such values into their individual components.
4+
This patch carries out some internal refactoring to avoid splitting up VarNames until absolutely necessary.
5+
There are no user-facing changes in this patch.
6+
17
# 0.39.7
28

39
Update compatibility to AdvancedPS 0.7 and Libtask 0.9.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.39.7"
3+
version = "0.39.8"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/TuringOptimExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ function _optimize(
186186
logdensity_optimum = Optimisation.OptimLogDensity(
187187
f.ldf.model, vi_optimum, f.ldf.context
188188
)
189-
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
189+
vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
190+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
191+
vns_vals_iter = mapreduce(collect, vcat, iters)
190192
varnames = map(Symbol first, vns_vals_iter)
191193
vals = map(last, vns_vals_iter)
192194
vmat = NamedArrays.NamedArray(vals, varnames)

src/mcmc/Inference.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
176176
# this means that the code below will work both of linked and invlinked `vi`.
177177
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
178178
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
179-
vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
180-
181-
# Obtain an iterator over the flattened parameter names and values.
182-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
183-
184-
# Materialize the iterators and concatenate.
185-
return mapreduce(collect, vcat, iters)
179+
return DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
186180
end
187181
function getparams(
188182
model::DynamicPPL.Model, untyped_vi::DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
@@ -193,14 +187,25 @@ function getparams(
193187
return getparams(model, DynamicPPL.typed_varinfo(untyped_vi))
194188
end
195189
function getparams(::DynamicPPL.Model, ::DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}})
196-
return float(Real)[]
190+
return Dict{VarName,Any}()
197191
end
198192

199193
function _params_to_array(model::DynamicPPL.Model, ts::Vector)
200194
names_set = OrderedSet{VarName}()
201195
# Extract the parameter names and values from each transition.
202196
dicts = map(ts) do t
203-
nms_and_vs = getparams(model, t)
197+
# In general getparams returns a dict of VarName => values. We need to also
198+
# split it up into constituent elements using
199+
# `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
200+
# won't understand it.
201+
vals = getparams(model, t)
202+
nms_and_vs = if isempty(vals)
203+
Tuple{VarName,Any}[]
204+
else
205+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
206+
mapreduce(collect, vcat, iters)
207+
end
208+
204209
nms = map(first, nms_and_vs)
205210
vs = map(last, nms_and_vs)
206211
for nm in nms
@@ -210,9 +215,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
210215
return OrderedDict(zip(nms, vs))
211216
end
212217
names = collect(names_set)
213-
vals = [
214-
get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
215-
]
218+
vals = [get(dicts[i], key, missing) for i in eachindex(dicts), key in names]
216219

217220
return names, vals
218221
end

src/optimisation/Optimisation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,10 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
366366
# Get all the variable names in the model. This is the same as the list of keys in
367367
# m.values, but they are more convenient to filter when they are VarNames rather than
368368
# Symbols.
369-
varnames = collect(
370-
map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
371-
)
369+
vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
370+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
371+
vns_and_vals = mapreduce(collect, vcat, iters)
372+
varnames = collect(map(first, vns_and_vals))
372373
# For each symbol s in var_symbols, pick all the values from m.values for which the
373374
# variable name has that symbol.
374375
et = eltype(m.values)
@@ -396,7 +397,9 @@ parameter space in case the optimization was done in a transformed space.
396397
function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
397398
varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
398399
# `getparams` performs invlinking if needed
399-
vns_vals_iter = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
400+
vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
401+
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
402+
vns_vals_iter = mapreduce(collect, vcat, iters)
400403
syms = map(Symbol first, vns_vals_iter)
401404
vals = map(last, vns_vals_iter)
402405
return ModeResult(

test/mcmc/Inference.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -630,16 +630,18 @@ using Turing
630630
return x ~ Normal()
631631
end
632632
fvi = DynamicPPL.VarInfo(f())
633-
@test only(Turing.Inference.getparams(f(), fvi)) == (@varname(x), fvi[@varname(x)])
633+
fparams = Turing.Inference.getparams(f(), fvi)
634+
@test fparams[@varname(x)] == fvi[@varname(x)]
635+
@test length(fparams) == 1
634636

635637
@model function g()
636638
x ~ Normal()
637639
return y ~ Poisson()
638640
end
639641
gvi = DynamicPPL.VarInfo(g())
640642
gparams = Turing.Inference.getparams(g(), gvi)
641-
@test gparams[1] == (@varname(x), gvi[@varname(x)])
642-
@test gparams[2] == (@varname(y), gvi[@varname(y)])
643+
@test gparams[@varname(x)] == gvi[@varname(x)]
644+
@test gparams[@varname(y)] == gvi[@varname(y)]
643645
@test length(gparams) == 2
644646
end
645647

0 commit comments

Comments (0)