Skip to content

Commit 34fbc54

Browse files
committed
Throw an error in predict if any variable is transformed
1 parent 509f50e commit 34fbc54

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/model.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,14 @@ function predict(
12161216
# - Safely use `values_as_in_model` to extract a dict of values.
12171217
# - Drop the values of the prediction variables from the dict.
12181218
# - Then use `ParamsInit` to generate the predictions.
1219+
for vn in keys(vi)
1220+
# note that istrans(vi) checks that ALL variables are transformed,
1221+
# whereas the failure here happens if ANY variable is transformed.
1222+
# so we have to manually loop over keys
1223+
istrans(vi, vn) && error(
1224+
"`predict(rng, model, ::AbstractArray{<:AbstractVarInfo})` will give you wrong results if the `VarInfo` contains transformed variables and is therefore currently forbidden. Please provide an unlinked `VarInfo` instead.",
1225+
)
1226+
end
12191227
values_dict = values_as(params_varinfo, Dict{VarName,Any})
12201228
_, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_dict, PriorInit()))
12211229
return vi

0 commit comments

Comments
 (0)