Skip to content

Commit 66423bf

Browse files
committed
Implement ParamsWithStats from FastLDF
1 parent 4ec0c72 commit 66423bf

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

src/chains.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,42 @@ function ParamsWithStats(
9090
return ParamsWithStats(params, stats)
9191
end
9292

93+
function ParamsWithStats(
94+
param_vector::AbstractVector,
95+
ldf::DynamicPPL.Experimental.FastLDF,
96+
stats::NamedTuple=NamedTuple();
97+
include_colon_eq::Bool=true,
98+
include_log_probs::Bool=true,
99+
)
100+
ctx = DynamicPPL.Experimental.FastEvalVectorContext(
101+
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
102+
)
103+
accs = if include_log_probs
104+
(
105+
DynamicPPL.LogPriorAccumulator(),
106+
DynamicPPL.LogLikelihoodAccumulator(),
107+
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
108+
)
109+
else
110+
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
111+
end
112+
return _, varinfo = DynamicPPL.Experimental.fast_evaluate!!(
113+
ldf.model, ctx, AccumulatorTuple(accs)
114+
)
115+
params = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
116+
if include_log_probs
117+
stats = merge(
118+
stats,
119+
(
120+
logprior=DynamicPPL.getlogprior(varinfo),
121+
loglikelihood=DynamicPPL.getloglikelihood(varinfo),
122+
lp=DynamicPPL.getlogjoint(varinfo),
123+
),
124+
)
125+
end
126+
return ParamsWithStats(params, stats)
127+
end
128+
93129
# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to
94130
# convert it to a typed varinfo first, hence this method.
95131
# https://github.com/TuringLang/Turing.jl/issues/2604

src/fasteval.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -352,16 +352,9 @@ end
352352
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
353353
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
354354

355-
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
356-
_model::M
357-
_getlogdensity::F
358-
_iden_varname_ranges::N
359-
_varname_ranges::Dict{VarName,RangeAndLinked}
360-
end
361-
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
362-
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
363-
model = DynamicPPL.setleafcontext(f._model, ctx)
364-
accs = fast_ldf_accs(f._getlogdensity)
355+
function fast_evaluate!!(model::Model, ctx::FastEvalVectorContext, accs::AccumulatorTuple)
356+
model = DynamicPPL.setleafcontext(model, ctx)
357+
vi = OnlyAccsVarInfo(accs)
365358
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
366359
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
367360
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
@@ -378,7 +371,18 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
378371
else
379372
OnlyAccsVarInfo(accs)
380373
end
381-
_, vi = DynamicPPL._evaluate!!(model, vi)
374+
return DynamicPPL._evaluate!!(model, vi)
375+
end
376+
377+
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
378+
_model::M
379+
_getlogdensity::F
380+
_iden_varname_ranges::N
381+
_varname_ranges::Dict{VarName,RangeAndLinked}
382+
end
383+
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
384+
ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params)
385+
_, vi = fast_evaluate!!(f._model, ctx, fast_ldf_accs(f._getlogdensity))
382386
return f._getlogdensity(vi)
383387
end
384388

0 commit comments

Comments
 (0)