Skip to content

Commit f4d7823

Browse files
committed
Implement StatsBase.predict
1 parent f51166d commit f4d7823

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/abstractprobprog.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using AbstractMCMC
22
using DensityInterface
33
using Random
4+
using StatsBase
45

56

67
"""
@@ -80,3 +81,29 @@ end
8081
function Base.rand(model::AbstractProbabilisticProgram)
8182
return rand(Random.default_rng(), NamedTuple, model)
8283
end
84+
85+
"""
86+
predict(
87+
[rng::AbstractRNG=Random.default_rng(),]
88+
[T=NamedTuple,]
89+
model::AbstractProbabilisticProgram,
90+
params,
91+
) -> T
92+
93+
Draw a sample from the joint distribution specified by `model` conditioned on the values in
94+
`params`.
95+
96+
The sample will be returned as format specified by `T`.
97+
"""
98+
function StatsBase.predict(rand::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params)
99+
return rand(rng, T, condition(model, params))
100+
end
101+
function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params)
102+
return StatsBase.predict(Random.default_rng(), T, model, params)
103+
end
104+
function StatsBase.predict(model::AbstractProbabilisticProgram, params)
105+
return StatsBase.predict(NamedTuple, model, params)
106+
end
107+
function StatsBase.predict(rng::AbstractRNG, params)
108+
return StatsBase.predict(rng, NamedTuple, model, params)
109+
end

0 commit comments

Comments
 (0)