Skip to content

Commit 8520ec3

Browse files
committed
Un-fix predict on varinfo
1 parent 2edcd10 commit 8520ec3

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/model.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,15 +1200,24 @@ function predict(
12001200
varinfo = DynamicPPL.VarInfo(model)
12011201
return map(chain) do params_varinfo
12021202
vi = deepcopy(varinfo)
1203-
# TODO(penelopeysm): Requires two model evaluations, one to extract the
1204-
# parameters and one to set them. The reason why we need values_as_in_model
1205-
# is because `params_varinfo` may well have some weird combination of
1206-
# linked/unlinked, whereas `varinfo` is always unlinked since it is
1207-
# freshly constructed.
1208-
# This is quite inefficient. It would of course be alright if
1209-
# ValuesAsInModelAccumulator was a default acc.
1210-
values_nt = values_as_in_model(model, false, params_varinfo)
1211-
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit()))
1203+
# TODO(penelopeysm): BEWARE - Because `ParamsInit` expects unlinked
1204+
# values, this is bugged in the case where `params_varinfo` is linked.
1205+
# This could be solved by using `values_as_in_model`. However, we run
1206+
# into another problem:
1207+
# - The `model` passed into this function will have the target
1208+
# prediction variables (say `y`) set to `missing`.
1209+
# - Calling `values_as_in_model` on this new model will lead to
1210+
# DynamicPPL attempting to extract the values of `y` from the
1211+
# `params_varinfo`, which will fail since `y` will not have been
1212+
# present.
1213+
# I think that the solution is to pass in the ORIGINAL model, but ALSO
1214+
# pass in a set / iterable of VarNames that are to be predicted against.
1215+
# That way, we can:
1216+
# - Safely use `values_as_in_model` to extract a dict of values.
1217+
# - Drop the values of the prediction variables from the dict.
1218+
# - Then use `ParamsInit` to generate the predictions.
1219+
values_dict = values_as(params_varinfo, Dict{VarName,Any})
1220+
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_dict, PriorInit()))
12121221
return vi
12131222
end
12141223
end

0 commit comments

Comments
 (0)