@@ -795,7 +795,7 @@ fixed(model::Model) = fixed(model.context)
795
795
796
796
"""
797
797
(model::Model)()
798
- (model::Model)(rng[, varinfo, sampler, context ])
798
+ (model::Model)(rng[, varinfo])
799
799
800
800
Sample from the `model` using the `sampler` with random number generator `rng`
801
801
and the `context`, and store the sample and log joint probability in `varinfo`.
@@ -806,13 +806,8 @@ If no arguments are provided, uses the default random number generator and
806
806
samples from the prior.
807
807
"""
808
808
(model:: Model )() = model (Random. default_rng ())
809
- function (model:: Model )(
810
- rng:: Random.AbstractRNG ,
811
- varinfo:: AbstractVarInfo = VarInfo (),
812
- sampler:: AbstractSampler = SampleFromPrior (),
813
- )
814
- spl_ctx = SamplingContext (rng, sampler, DefaultContext ())
815
- return first (evaluate!! (model, varinfo, spl_ctx))
809
+ function (model:: Model )(rng:: Random.AbstractRNG , varinfo:: AbstractVarInfo = VarInfo ())
810
+ return first (sample!! (rng, model, varinfo))
816
811
end
817
812
818
813
"""
@@ -1016,7 +1011,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
1016
1011
Generate a sample of type `T` from the prior distribution of the `model`.
1017
1012
"""
1018
1013
function Base. rand (rng:: Random.AbstractRNG , :: Type{T} , model:: Model ) where {T}
1019
- x = last (sample!! (model, SimpleVarInfo {Float64} (OrderedDict ())))
1014
+ x = last (sample!! (rng, model, SimpleVarInfo {Float64} (OrderedDict ())))
1020
1015
return values_as (x, T)
1021
1016
end
1022
1017
@@ -1087,7 +1082,7 @@ function logprior(model::Model, varinfo::AbstractVarInfo)
1087
1082
LogPriorAccumulator ()
1088
1083
end
1089
1084
varinfo = setaccs!! (deepcopy (varinfo), (logprioracc,))
1090
- return getlogprior (last (evaluate!! (model, varinfo, DefaultContext () )))
1085
+ return getlogprior (last (evaluate!! (model, varinfo)))
1091
1086
end
1092
1087
1093
1088
"""
@@ -1141,7 +1136,7 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo)
1141
1136
LogLikelihoodAccumulator ()
1142
1137
end
1143
1138
varinfo = setaccs!! (deepcopy (varinfo), (loglikelihoodacc,))
1144
- return getloglikelihood (last (evaluate!! (model, varinfo, DefaultContext () )))
1139
+ return getloglikelihood (last (evaluate!! (model, varinfo)))
1145
1140
end
1146
1141
1147
1142
"""
@@ -1195,7 +1190,7 @@ function predict(
1195
1190
return map (chain) do params_varinfo
1196
1191
vi = deepcopy (varinfo)
1197
1192
DynamicPPL. setval_and_resample! (vi, values_as (params_varinfo, NamedTuple))
1198
- model (rng, vi, SampleFromPrior () )
1193
+ model (rng, vi)
1199
1194
return vi
1200
1195
end
1201
1196
end
0 commit comments