Description
I am opening this issue to propose a way to treat Turing models as distributions. After TuringLang/Turing.jl#997, we will have a nice syntax to query the `logpdf` of a `NamedTuple` of symbols given another `NamedTuple` of symbols, e.g. `logprob"x = [1, 2] | y = [1, 1], model = mymodel"`. This takes care of the `logpdf` function in the `Distributions` API. If we also have a way to sample a `NamedTuple` from the probability distribution `P(x | y = [1, 1], model = mymodel)` and the like, we can define a simple syntax that generates such distribution structs from Turing models, something like `prob"x | y = [1, 1], model = mymodel"`. Defining the posterior predictive distribution from there is also fairly easy. The question is how to sample `NamedTuple`s from Turing models.
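A minimal sketch of the interface described above follows. It assumes a hypothetical `ModelDistribution` wrapper, which is not part of Turing.jl or Distributions.jl. Its `rand` returns a `NamedTuple` of sampled variables, and a locally defined `logpdf` scores one. The toy model (`m ~ Normal(0, 1)`, `x ~ Normal(m, 1)`) and the hand-written `normlogpdf` stand in for whatever something like `prob"..."` might construct.

```julia
using Random

# Hypothetical wrapper, NOT Turing/Distributions API: a "model as distribution"
# whose samples and densities are NamedTuples of random variables.
struct ModelDistribution{S,L}
    sample::S   # function rng -> NamedTuple of sampled variables
    logdens::L  # function NamedTuple -> log density
end

Random.rand(rng::AbstractRNG, d::ModelDistribution) = d.sample(rng)
Random.rand(d::ModelDistribution) = rand(Random.default_rng(), d)
# Local function; a real implementation would presumably extend `Distributions.logpdf`.
logpdf(d::ModelDistribution, nt::NamedTuple) = d.logdens(nt)

# Log density of Normal(μ, σ) at x, written out to avoid dependencies.
normlogpdf(μ, σ, x) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Toy model: m ~ Normal(0, 1); x ~ Normal(m, 1).
toy = ModelDistribution(
    rng -> begin
        m = randn(rng)
        x = m + randn(rng)
        (m = m, x = x)
    end,
    nt -> normlogpdf(0.0, 1.0, nt.m) + normlogpdf(nt.m, 1.0, nt.x),
)

smp = rand(MersenneTwister(0), toy)   # a NamedTuple with fields m and x
lp = logpdf(toy, smp)                 # log density of that sample
```

Conditioning (the `| y = [1, 1]` part) is deliberately left out of the sketch. Only the shape of the interface is shown.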
Initial proposal
Let the inner function that runs when an instance of `Turing.Model` is called be:

```julia
function inner_function(vi, sampler, ctx, model)
    # function body
end
```
Let's assume we have random variables `s` and `m` in the model. The idea I initially had was to replace the above expression with:

```julia
function inner_function(vi, sampler, ctx, model)
    if ctx isa DistributionContext
        s = missing
        m = missing
    end
    @inline function f()
        # function body
    end
    model_output = f()
    if ctx isa DistributionContext
        return (s = s, m = m)
    else
        return model_output
    end
end
```
Semantics
Calling the model in the `DistributionContext` will therefore return the `NamedTuple` of sampled random variables. The `DistributionContext` can wrap another context and forward the `tilde` and `dot_tilde` functions to the wrapped context. This gives us a way to return the sampled random variables from the model. If a random variable wasn't sampled, `missing` will be returned. If a vector or matrix random variable can be partially sampled, i.e. the number of random variables is itself random, we can tell users to initialize such variables inside the model body, for example with:

```julia
x = Vector{Union{Missing, T}}(missing, 10)
```

The unsampled random variables will therefore be `missing`. Constructing the vector with `undef` instead is not guaranteed to fill it with `missing`. These semantics for "sampling from Turing models" seem reasonable to me, even when the number of random variables is itself random.
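The semantics above can be sketched with hypothetical stand-ins. `SampleContext`, `DistributionContext` and `tilde` below are not the DynamicPPL types or functions. In the sketch, the wrapper context forwards `tilde` to the context it wraps, and variables that are never sampled stay `missing`. A vector whose number of sampled entries is random is preallocated with `missing`.

```julia
using Random

# Hypothetical stand-ins, NOT DynamicPPL types.
struct SampleContext end
struct DistributionContext{C}
    ctx::C  # the wrapped context
end

# Hypothetical `tilde`: draws a standard normal value in the base context...
tilde(rng::AbstractRNG, ::SampleContext) = randn(rng)
# ...and the wrapper simply forwards to the context it wraps.
tilde(rng::AbstractRNG, ctx::DistributionContext) = tilde(rng, ctx.ctx)

# Toy model: `s` is never sampled, `m` only on one branch, and the number of
# sampled entries of `x` is itself random.
function toy_model(rng::AbstractRNG, ctx)
    s = missing
    m = missing
    x = Vector{Union{Missing, Float64}}(missing, 10)
    n = rand(rng, 1:10)
    for i in 1:n
        x[i] = tilde(rng, ctx)
    end
    if n > 5
        m = tilde(rng, ctx)
    end
    return (s = s, m = m, x = x)
end

nt = toy_model(MersenneTwister(1), DistributionContext(SampleContext()))
```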
Performance and alternative implementation
Given that we inline `f`, I thought the Julia compiler should be able to optimize out `s = missing` and `m = missing` if `s` and `m` were defined inside `# function body`, because the compiler does optimize these away when I don't use a closure. Unfortunately, the following minimal example doesn't infer properly:
```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

function f(a)
    # initialize
    if a isa Int
        b = missing
    end
    @inline function g()
        # model body starts
        b = [1, 1]
        # can have more stuff here and multiple return statements in branches
        return 2*b
        # model body ends
    end
    out = g() # overwrites `b` if `a isa Int`
    # return a NamedTuple or the output of the model depending on the type of `a` (or `ctx` in the real example)
    if a isa Int
        return (b = b,)
    else
        return out
    end
end

@code_warntype f(1) # or f(1.0)
```
According to Keno Fischer on Slack, this is "a limitation that needs to be addressed properly at some point". The problem is the closure. Because `b` is reassigned inside `g`, it is captured in a `Core.Box`, and its type can no longer be inferred.
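A small sketch to illustrate the boxing. A variable that is reassigned inside a closure is stored in a `Core.Box`, which shows up in the field types of the closure object. A variable that is only read keeps its concrete type. The two function names are illustrative only.

```julia
# `b` is assigned both in the outer function and inside the closure,
# so it is captured as a `Core.Box` (whose contents infer as `Any`).
function reassigned_capture()
    b = missing
    g() = (b = [1, 1]; 2 * b)
    return g
end

# `b` is only read inside the closure, so it is captured with its concrete type.
function read_only_capture()
    b = [1, 1]
    g() = 2 * b
    return g
end

boxed = fieldtypes(typeof(reassigned_capture()))    # contains Core.Box
unboxed = fieldtypes(typeof(read_only_capture()))   # contains Vector{Int}
```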
So my alternative solution is to do:

```julia
function inner_function(vi, sampler, ctx, model)
    if ctx isa DistributionContext
        s = missing
        m = missing
    end
    # function body
    model_output = nothing # reached only if the body falls through without returning
    @label end_of_func
    if ctx isa DistributionContext
        return (s = s, m = m)
    else
        return model_output
    end
end
```
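and inside `# function body`, we replace every `return ...` statement with:

```julia
model_output = ...
@goto end_of_func
```

and every bare `return` statement with:

```julia
model_output = nothing
@goto end_of_func
```

A bare `return` jumps over the `model_output = nothing` line, so it has to set `model_output` itself. Otherwise `model_output` would be undefined at `@label end_of_func`.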
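The `return` rewriting could be done on the model body expression roughly as follows. `replace_returns` is a hypothetical helper, not DynamicPPL's actual implementation. It maps `return x` to `model_output = x` followed by `@goto end_of_func`, and a bare `return` (which parses as `return nothing`) to `model_output = nothing`. It leaves nested `function`/`->` definitions alone, since their `return`s belong to them. Short-form `g() = ...` definitions are not special-cased in this sketch.

```julia
replace_returns(x) = x  # symbols, literals and LineNumberNodes are left as-is
function replace_returns(ex::Expr)
    if ex.head === :return
        val = isempty(ex.args) ? nothing : replace_returns(ex.args[1])
        return quote
            model_output = $val
            @goto end_of_func
        end
    elseif ex.head in (:function, :->)
        return ex  # `return`s inside nested functions are not ours to rewrite
    else
        return Expr(ex.head, map(replace_returns, ex.args)...)
    end
end

# Example "model body" with a bare and a valued return.
body = quote
    if x < 0
        return
    end
    return 2 * x
end

# Splice the rewritten body into a function with a single exit point.
eval(quote
    function h(x)
        $(replace_returns(body))
        model_output = nothing  # reached only if the body falls through
        @label end_of_func
        return model_output
    end
end)

h(3)   # 6
h(-1)  # nothing
```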
This eliminates the use of a closure and therefore should infer properly. At least, the following minimal example does:
```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

function f(a)
    # initialize
    if a isa Int
        b = missing
    end
    # model body starts
    # can have arbitrary code here with multiple return points; some or all of the random variables can be defined
    b = [1, 1]
    if sum(b) > 0.5
        out = 2*b
        @goto end_of_func
    else
        out = 3*b
        @goto end_of_func
    end
    # model body ends
    out = nothing # the function output is nothing if we made it here
    @label end_of_func
    # return a NamedTuple or the output of the model depending on the type of `a` (or `ctx` in the real example)
    if a isa Int
        return (b = b,)
    else
        return out
    end
end

@code_warntype f(1)
```