
Commit 7b2869f

Add Base.get method for ModeResult (#2269)
* Add Base.get method for ModeResult
* Make get(::ModeResult, itr) work for any iterator
* Fix array type in get(::ModeResult, ...)
1 parent fcb4ca7 commit 7b2869f

2 files changed: +57 -1 lines changed

src/optimisation/Optimisation.jl

Lines changed: 31 additions & 0 deletions
@@ -273,6 +273,37 @@ StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
 StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
 StatsBase.loglikelihood(m::ModeResult) = m.lp

+"""
+    Base.get(m::ModeResult, var_symbol::Symbol)
+    Base.get(m::ModeResult, var_symbols)
+
+Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
+`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
+argument should be either a `Symbol` or an iterator of `Symbol`s.
+"""
+function Base.get(m::ModeResult, var_symbols)
+    log_density = m.f
+    # Get all the variable names in the model. This is the same as the list of keys in
+    # m.values, but they are more convenient to filter when they are VarNames rather than
+    # Symbols.
+    varnames = collect(
+        map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
+    )
+    # For each symbol s in var_symbols, pick all the values from m.values for which the
+    # variable name has that symbol.
+    et = eltype(m.values)
+    value_vectors = Vector{et}[]
+    for s in var_symbols
+        push!(
+            value_vectors,
+            [m.values[Symbol(vn)] for vn in varnames if DynamicPPL.getsym(vn) == s],
+        )
+    end
+    return (; zip(var_symbols, value_vectors)...)
+end
+
+Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
+
 """
     ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
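For orientation, a minimal usage sketch of the new method follows. It is illustrative only and not part of the commit: the model, its name, and the data are hypothetical, while arraydist, maximum_a_posteriori, and get are the Turing exports exercised in the diff above.

using Turing

@model function toy_model(y)
    # Vector-valued parameter: the entries a[1] and a[2] both carry the symbol :a.
    a ~ arraydist(LogNormal.(fill(0, 2), 1))
    b ~ Normal()
    y ~ Normal(b, 1)
    return nothing
end

result = maximum_a_posteriori(toy_model(0.5))

get(result, :a)        # NamedTuple (; a = [...]) with the mode values of a[1] and a[2]
get(result, (:a, :b))  # keys :a and :b; the second argument may be any iterator of Symbols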

test/optimisation/Optimisation.jl

Lines changed: 26 additions & 1 deletion
@@ -4,13 +4,14 @@ using ..Models: gdemo, gdemo_default
 using Distributions
 using Distributions.FillArrays: Zeros
 using DynamicPPL: DynamicPPL
-using LinearAlgebra: I
+using LinearAlgebra: Diagonal, I
 using Random: Random
 using Optimization
 using Optimization: Optimization
 using OptimizationBBO: OptimizationBBO
 using OptimizationNLopt: OptimizationNLopt
 using OptimizationOptimJL: OptimizationOptimJL
+using ReverseDiff: ReverseDiff
 using StatsBase: StatsBase
 using StatsBase: coef, coefnames, coeftable, informationmatrix, stderror, vcov
 using Test: @test, @testset, @test_throws
@@ -591,6 +592,30 @@ using Turing
         @test result.values[:x] ≈ 0 atol = 1e-1
         @test result.values[:y] ≈ 100 atol = 1e-1
     end
+
+    @testset "get ModeResult" begin
+        @model function demo_model(N)
+            half_N = N ÷ 2
+            a ~ arraydist(LogNormal.(fill(0, half_N), 1))
+            b ~ arraydist(LogNormal.(fill(0, N - half_N), 1))
+            covariance_matrix = Diagonal(vcat(a, b))
+            x ~ MvNormal(covariance_matrix)
+            return nothing
+        end
+
+        N = 12
+        m = demo_model(N) | (x=randn(N),)
+        result = maximum_a_posteriori(m)
+        get_a = get(result, :a)
+        get_b = get(result, :b)
+        get_ab = get(result, [:a, :b])
+        @assert keys(get_a) == (:a,)
+        @assert keys(get_b) == (:b,)
+        @assert keys(get_ab) == (:a, :b)
+        @assert get_b[:b] == get_ab[:b]
+        @assert vcat(get_a[:a], get_b[:b]) == result.values.array
+        @assert get(result, :c) == (; :c => Array{Float64}[])
+    end
 end

 end
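As a side note on the implementation: the return value of the new Base.get method is assembled by splatting zip(var_symbols, value_vectors) into a NamedTuple literal. A standalone sketch of that pattern in plain Julia, with illustrative values only:

# Splat an iterator of (Symbol, value) pairs into a NamedTuple literal,
# as done by `return (; zip(var_symbols, value_vectors)...)` above.
var_symbols = (:a, :b)
value_vectors = [[1.0, 2.0], [3.0]]

nt = (; zip(var_symbols, value_vectors)...)  # (a = [1.0, 2.0], b = [3.0])
nt[:b]                                       # returns [3.0]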
