Skip to content

Commit 10f960e

Browse files
committed
Import varname_leaves etc from AbstractPPL instead
1 parent ff8d01e commit 10f960e

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

ext/TuringOptimExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TuringOptimExt
22

33
using Turing: Turing
4-
import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
4+
import Turing: AbstractPPL, DynamicPPL, NamedArrays, Accessors, Optimisation
55
using Optim: Optim
66

77
####################
@@ -186,7 +186,7 @@ function _optimize(
186186
f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
187187
)
188188
vals_dict = Turing.Inference.getparams(f.ldf.model, vi_optimum)
189-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
189+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
190190
vns_vals_iter = mapreduce(collect, vcat, iters)
191191
varnames = map(Symbol first, vns_vals_iter)
192192
vals = map(last, vns_vals_iter)

src/mcmc/Inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,13 +262,13 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
262262
dicts = map(ts) do t
263263
# In general getparams returns a dict of VarName => values. We need to also
264264
# split it up into constituent elements using
265-
# `DynamicPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
265+
# `AbstractPPL.varname_and_value_leaves` because otherwise MCMCChains.jl
266266
# won't understand it.
267267
vals = getparams(model, t)
268268
nms_and_vs = if isempty(vals)
269269
Tuple{VarName,Any}[]
270270
else
271-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
271+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
272272
mapreduce(collect, vcat, iters)
273273
end
274274
nms = map(first, nms_and_vs)

src/optimisation/Optimisation.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Optimisation
22

33
using ..Turing
44
using NamedArrays: NamedArrays
5+
using AbstractPPL: AbstractPPL
56
using DynamicPPL: DynamicPPL
67
using LogDensityProblems: LogDensityProblems
78
using Optimization: Optimization
@@ -320,7 +321,7 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
320321
# m.values, but they are more convenient to filter when they are VarNames rather than
321322
# Symbols.
322323
vals_dict = Turing.Inference.getparams(log_density.model, log_density.varinfo)
323-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
324+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals_dict), values(vals_dict))
324325
vns_and_vals = mapreduce(collect, vcat, iters)
325326
varnames = collect(map(first, vns_and_vals))
326327
# For each symbol s in var_symbols, pick all the values from m.values for which the
@@ -351,7 +352,7 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
351352
varinfo_new = DynamicPPL.unflatten(log_density.ldf.varinfo, solution.u)
352353
# `getparams` performs invlinking if needed
353354
vals = Turing.Inference.getparams(log_density.ldf.model, varinfo_new)
354-
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
355+
iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
355356
vns_vals_iter = mapreduce(collect, vcat, iters)
356357
syms = map(Symbol first, vns_vals_iter)
357358
vals = map(last, vns_vals_iter)

test/ext/OptimInterface.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module OptimInterfaceTests
22

33
using ..Models: gdemo_default
44
using Distributions.FillArrays: Zeros
5+
using AbstractPPL: AbstractPPL
56
using DynamicPPL: DynamicPPL
67
using LinearAlgebra: I
78
using Optim: Optim
@@ -124,7 +125,7 @@ using Turing
124125
vals = result.values
125126

126127
for vn in DynamicPPL.TestUtils.varnames(model)
127-
for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
128+
for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn))
128129
@test get(result_true, vn_leaf) vals[Symbol(vn_leaf)] atol = 0.05
129130
end
130131
end
@@ -159,7 +160,7 @@ using Turing
159160
vals = result.values
160161

161162
for vn in DynamicPPL.TestUtils.varnames(model)
162-
for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
163+
for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn))
163164
if model.f in allowed_incorrect_mle
164165
@test isfinite(get(result_true, vn_leaf))
165166
else

test/optimisation/Optimisation.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4+
using AbstractPPL: AbstractPPL
45
using Distributions
56
using Distributions.FillArrays: Zeros
67
using DynamicPPL: DynamicPPL
@@ -495,7 +496,7 @@ using Turing
495496
vals = result.values
496497

497498
for vn in DynamicPPL.TestUtils.varnames(model)
498-
for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
499+
for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn))
499500
@test get(result_true, vn_leaf) vals[Symbol(vn_leaf)] atol = 0.05
500501
end
501502
end
@@ -534,7 +535,7 @@ using Turing
534535
vals = result.values
535536

536537
for vn in DynamicPPL.TestUtils.varnames(model)
537-
for vn_leaf in DynamicPPL.TestUtils.varname_leaves(vn, get(result_true, vn))
538+
for vn_leaf in AbstractPPL.varname_leaves(vn, get(result_true, vn))
538539
if model.f in allowed_incorrect_mle
539540
@test isfinite(get(result_true, vn_leaf))
540541
else

0 commit comments

Comments
 (0)