@@ -2,6 +2,7 @@ module Optimisation
2
2
3
3
using .. Turing
4
4
using NamedArrays: NamedArrays
5
+ using AbstractPPL: AbstractPPL
5
6
using DynamicPPL: DynamicPPL
6
7
using LogDensityProblems: LogDensityProblems
7
8
using Optimization: Optimization
@@ -320,7 +321,7 @@ function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
320
321
# m.values, but they are more convenient to filter when they are VarNames rather than
321
322
# Symbols.
322
323
vals_dict = Turing. Inference. getparams (log_density. model, log_density. varinfo)
323
- iters = map (DynamicPPL . varname_and_value_leaves, keys (vals_dict), values (vals_dict))
324
+ iters = map (AbstractPPL . varname_and_value_leaves, keys (vals_dict), values (vals_dict))
324
325
vns_and_vals = mapreduce (collect, vcat, iters)
325
326
varnames = collect (map (first, vns_and_vals))
326
327
# For each symbol s in var_symbols, pick all the values from m.values for which the
@@ -351,7 +352,7 @@ function ModeResult(log_density::OptimLogDensity, solution::SciMLBase.Optimizati
351
352
varinfo_new = DynamicPPL. unflatten (log_density. ldf. varinfo, solution. u)
352
353
# `getparams` performs invlinking if needed
353
354
vals = Turing. Inference. getparams (log_density. ldf. model, varinfo_new)
354
- iters = map (DynamicPPL . varname_and_value_leaves, keys (vals), values (vals))
355
+ iters = map (AbstractPPL . varname_and_value_leaves, keys (vals), values (vals))
355
356
vns_vals_iter = mapreduce (collect, vcat, iters)
356
357
syms = map (Symbol ∘ first, vns_vals_iter)
357
358
vals = map (last, vns_vals_iter)
0 commit comments