
Turing models as distributions #4

@mohamed82008

Description

I am opening this issue to propose a way to treat Turing models as distributions. After TuringLang/Turing.jl#997, we will have a nice syntax to query the logpdf of one NamedTuple of variables given another, e.g. logprob"x = [1, 2] | y = [1, 1], model = mymodel". This takes care of the logpdf part of the Distributions API. So if we can also sample a NamedTuple from the probability distribution P(x | y = [1, 1], model = mymodel) and the like, we can define a simple syntax that generates such distribution structs from Turing models, something like prob"x | y = [1, 1], model = mymodel". Defining the posterior predictive distribution from there is also fairly easy. The question is how to sample NamedTuples from Turing models.
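To make the target interface concrete, here is a toy sketch in plain Julia with Distributions.jl (all names here are hypothetical and not part of the proposal's implementation) of the kind of object prob"x | y = [1, 1], model = mymodel" could return: something supporting rand and logpdf over NamedTuples.

using Distributions

# Toy stand-in for the object the proposed prob"..." macro could return.
# rand produces a NamedTuple of variable values, logpdf scores one.
struct NamedTupleDistribution{D<:NamedTuple}
    dists::D   # one distribution per random variable, keyed by variable name
end

Base.rand(d::NamedTupleDistribution) = map(rand, d.dists)
Distributions.logpdf(d::NamedTupleDistribution, nt::NamedTuple) =
    sum(logpdf(getfield(d.dists, k), getfield(nt, k)) for k in keys(nt))

d = NamedTupleDistribution((s = InverseGamma(2, 3), m = Normal(0, 1)))
nt = rand(d)     # e.g. (s = 1.23, m = -0.45)
logpdf(d, nt)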

Initial proposal

Let the inner function invoked when calling an instance of Turing.Model be:

function inner_function(vi, sampler, ctx, model)
     #function body
end

Let's assume we have random variables s and m in the model. The idea I initially had was to replace the above expression with:

function inner_function(vi, sampler, ctx, model)
    if ctx isa DistributionContext
        s = missing
        m = missing
    end
    @inline function f()
        #function body
    end
    model_output = f()
    if ctx isa DistributionContext
        return (s = s, m = m)
    else
        return model_output
    end
end

Semantics

Calling the model in the DistributionContext will therefore return the NamedTuple of sampled random variables. The DistributionContext can wrap another context and forward the tilde and dot_tilde functions to the wrapped context. This gives us a way to return the sampled random variables from the model. If a random variable wasn't sampled, missing will be returned for it. If a vector or matrix random variable can be partially sampled, i.e. the number of random variables is itself random, we can tell users to initialize such variables with:

x = Vector{Union{Missing, T}}(undef, 10)

for example inside the model body. The unsampled random variables will therefore be missing. These semantics for "sampling from Turing models" seem reasonable to me even in the case when the number of random variables is itself random.
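As a plain-Julia illustration of these semantics (no Turing involved, and using the missing-filled constructor rather than undef purely for clarity), entries that are never assigned simply stay missing:

n = rand(1:10)                                    # the number of sampled entries is itself random
x = Vector{Union{Missing, Float64}}(missing, 10)  # start with every entry missing
for i in 1:n
    x[i] = randn()                                # "sample" only the first n entries
end
x   # e.g. [0.31, -1.2, missing, missing, ...] when n = 2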

Performance and alternative implementation

Given that f is inlined, I expected the Julia compiler to be able to optimize out s = missing and m = missing whenever s and m are redefined inside #function body, since the compiler does optimize these away when no closure is involved. Unfortunately, the following minimal example doesn't infer properly:

function f(a)
    # initialize
    if a isa Int
        b = missing
    end

    @inline function g()
        # model body starts
        b = [1, 1]
        # Can have more stuff here and multiple return statements in branches
        return 2*b
        # model body ends
    end
    out = g() # overwrite `b` if `a isa Int`

    # Return a namedtuple or the output of the model depending on the type of the input `a` (or ctx in the real example)
    if a isa Int
        return (b = b,)
    else
        return out
    end
end
@code_warntype f(1) # f(1.0)

According to Keno Fischer on Slack, this is "a limitation that needs to be addressed properly at some point". The problem is basically the closure.
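More specifically, this is the well-known captured-variable issue: a variable that is assigned both outside and inside a closure gets wrapped in a Core.Box, and reads of it are then typed Any. A stripped-down demonstration, independent of Turing:

function boxed_example()
    b = missing
    g() = (b = [1, 1]; 2b)   # the closure reassigns b, so Julia boxes it
    out = g()
    return b                 # inferred as Any because b now lives in a Core.Box
end
@code_warntype boxed_example()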

So my alternative solution is to do:

function inner_function(vi, sampler, ctx, model)
    if ctx isa DistributionContext
        s = missing
        m = missing
    end
    #function body
    model_output = nothing
    @label end_of_func
    if ctx isa DistributionContext
        return (s = s, m = m)
    else
        return model_output
    end
end

and inside #function body, we replace every return statement that returns a value, i.e. return ..., with:

model_output = ...
@goto end_of_func

and every bare return statement (one without a value) with:

@goto end_of_func
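For concreteness, here is a toy before/after of that rewrite (illustrative only; in the real setting the transformation would presumably happen at @model macro-expansion time):

# Model body as the user writes it:
function body_before()
    s = 1 / rand()
    m = rand()
    return s + m              # a return statement with a value
end

# The same body after the proposed rewrite:
function body_after()
    s = 1 / rand()
    m = rand()
    model_output = s + m      # return s + m  becomes an assignment ...
    @goto end_of_func         # ... plus a jump to the common exit point
    model_output = nothing    # only reached if the body falls through without returning
    @label end_of_func
    return model_output
end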

This eliminates the use of a closure and therefore should infer properly. At least, the following minimal example does:

function f(a)
    # initialize
    if a isa Int
        b = missing
    end

    # model body starts
    # can have arbitrary code here with multiple return points, some or all the random variables can be defined
    b = [1, 1]
    if sum(b) > 0.5
        out = 2*b
        @goto end_of_func
    else
        out = 3*b
        @goto end_of_func
    end
    # model body ends
    out = nothing # the function output is nothing if we made it here

    @label end_of_func

    # Return a namedtuple or the output of the model depending on the type of the input `a` (or ctx in the real example)    
    if a isa Int
        return (b = b,)
    else
        return out
    end
end
@code_warntype f(1)
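If you prefer a programmatic check over reading @code_warntype output, Test.@inferred errors whenever the inferred return type does not match the concrete type of the result, so it distinguishes the two versions (just a quick sanity check, not part of the proposal):

using Test

@inferred f(1)   # passes for this @goto version; the earlier closure version errors here,
                 # since its return type is not inferred concretely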
