Skip to content

Commit e381f3e

Browse files
jonasmac16torfjeldeyebai
committed
Extended generated_quantities function (#299)
Extended `generated_quantities` function to also accept a named tuple or keys+values Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent ecc3c02 commit e381f3e

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

src/model.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,50 @@ function generated_quantities(model::Model, chain::AbstractChains)
590590
model(varinfo)
591591
end
592592
end
593+
594+
"""
595+
generated_quantities(model::Model, parameters::NamedTuple)
596+
generated_quantities(model::Model, values, keys)
597+
generated_quantities(model::Model, values, keys)
598+
599+
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
600+
601+
If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameters)`.
602+
603+
# Example
604+
```jldoctest
605+
julia> using DynamicPPL, Distributions
606+
607+
julia> @model function demo(xs)
608+
s ~ InverseGamma(2, 3)
609+
m_shifted ~ Normal(10, √s)
610+
m = m_shifted - 10
611+
for i in eachindex(xs)
612+
xs[i] ~ Normal(m, √s)
613+
end
614+
return (m, )
615+
end
616+
demo (generic function with 1 method)
617+
618+
julia> model = demo(randn(10));
619+
620+
julia> parameters = (; s = 1.0, m_shifted=10);
621+
622+
julia> generated_quantities(model, parameters)
623+
(0.0,)
624+
625+
julia> generated_quantities(model, values(parameters), keys(parameters))
626+
(0.0,)
627+
```
628+
"""
629+
function generated_quantities(model::Model, parameters::NamedTuple)
630+
varinfo = VarInfo(model)
631+
setval_and_resample!(varinfo, values(parameters), keys(parameters))
632+
return model(varinfo)
633+
end
634+
635+
function generated_quantities(model::Model, values, keys)
636+
varinfo = VarInfo(model)
637+
setval_and_resample!(varinfo, values, keys)
638+
return model(varinfo)
639+
end

0 commit comments

Comments
 (0)