Skip to content

Commit c4235d1

Browse files
committed
Remove predict on vector of VarInfo
1 parent a019495 commit c4235d1

File tree

2 files changed

+2
-60
lines changed

2 files changed

+2
-60
lines changed

src/model.jl

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,32 +1195,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
11951195
end
11961196
end
11971197

1198-
"""
1199-
predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
1200-
1201-
Generate samples from the posterior predictive distribution by evaluating `model` at each set
1202-
of parameter values provided in `chain`. The number of posterior predictive samples matches
1203-
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
1204-
and the predicted values.
1205-
"""
1206-
function predict(
1207-
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
1208-
)
1209-
varinfo = DynamicPPL.VarInfo(model)
1210-
return map(chain) do params_varinfo
1211-
vi = deepcopy(varinfo)
1212-
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1213-
# parameters and one to set them. The reason why we need values_as_in_model
1214-
# is because `params_varinfo` may well have some weird combination of
1215-
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1216-
# freshly constructed.
1217-
# This is quite inefficient. It would of course be alright if
1218-
# ValuesAsInModelAccumulator was a default acc.
1219-
values_nt = values_as_in_model(model, false, params_varinfo)
1220-
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
1221-
return vi
1222-
end
1223-
end
1198+
# Implemented & documented in DynamicPPLMCMCChainsExt
1199+
function predict end
12241200

12251201
"""
12261202
returned(model::Model, parameters::NamedTuple)

test/model.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -566,40 +566,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
566566
end
567567
end
568568
end
569-
570-
@testset "with AbstractVector{<:AbstractVarInfo}" begin
571-
@model function linear_reg(x, y, σ=0.1)
572-
β ~ Normal(1, 1)
573-
for i in eachindex(y)
574-
y[i] ~ Normal(β * x[i], σ)
575-
end
576-
end
577-
578-
ground_truth_β = 2.0
579-
# the data will be ignored, as we are generating samples from the prior
580-
xs_train = 1:0.1:10
581-
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
582-
m_lin_reg = linear_reg(xs_train, ys_train)
583-
chain = [VarInfo(m_lin_reg) for _ in 1:10000]
584-
585-
# chain is generated from the prior
586-
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1
587-
588-
xs_test = [10 + 0.1, 10 + 2 * 0.1]
589-
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
590-
predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)
591-
592-
@test size(predicted_vis) == size(chain)
593-
@test Set(keys(predicted_vis[1])) ==
594-
Set([@varname(β), @varname(y[1]), @varname(y[2])])
595-
# because β samples are from the prior, the std will be larger
596-
@test mean([
597-
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
598-
]) ≈ 1.0 * xs_test[1] rtol = 0.1
599-
@test mean([
600-
predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
601-
]) ≈ 1.0 * xs_test[2] rtol = 0.1
602-
end
603569
end
604570

605571
@testset "ProductNamedTupleDistribution sampling" begin

0 commit comments

Comments
 (0)