Skip to content

[Merged by Bors] - Remove ModelGen #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.8.1"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "1"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
Distributions = "0.22, 0.23"
ExprTools = "0.1.1"
MacroTools = "0.5.1"
ZygoteRules = "0.2"
julia = "1"
Expand Down
6 changes: 1 addition & 5 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Bijectors
using MacroTools

import AbstractMCMC
import ExprTools
import ZygoteRules

import Random
Expand Down Expand Up @@ -51,25 +52,20 @@ export AbstractVarInfo,
inspace,
subsumes,
# Compiler
ModelGen,
@model,
@varname,
# Utilities
vectorize,
reconstruct,
reconstruct!,
Sample,
Chain,
init,
vectorize,
set_resume!,
# Model
ModelGen,
Model,
getmissings,
getargnames,
getdefaults,
getgenerator,
# Samplers
Sampler,
SampleFromPrior,
Expand Down
111 changes: 71 additions & 40 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,27 @@ end
Builds the `model_info` dictionary from the model's expression.
"""
function build_model_info(input_expr)
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
modeldef = MacroTools.splitdef(input_expr)
# Function body of the model is empty
# Break up the model definition and extract its name, arguments, and function body
modeldef = ExprTools.splitdef(input_expr)

# Print a warning if function body of the model is empty
warn_empty(modeldef[:body])
# Construct model_info dictionary

## Construct model_info dictionary

# Shortcut if the model does not have any arguments
if !haskey(modeldef, :args)
modelinfo = Dict(
:name => modeldef[:name],
:main_body => modeldef[:body],
:arg_syms => [],
:args_nt => NamedTuple(),
:defaults_nt => NamedTuple(),
:args => [],
:modeldef => modeldef,
)
return modelinfo
end

# Extracting the argument symbols from the model definition
arg_syms = map(modeldef[:args]) do arg
Expand Down Expand Up @@ -158,7 +174,7 @@ function build_model_info(input_expr)
:args_nt => args_nt,
:defaults_nt => defaults_nt,
:args => args,
:whereparams => modeldef[:whereparams]
:modeldef => modeldef,
)

return model_info
Expand Down Expand Up @@ -318,48 +334,63 @@ hasmissing(T::Type) = false
Builds the output expression.
"""
function build_output(model_info)
# Arguments with default values
## Build the anonymous evaluator from the user-provided model definition

# Remove the name and use `function (....)` syntax
modeldef = model_info[:modeldef]
delete!(modeldef, :name)
modeldef[:head] = :function

# Define the input arguments (positional + keyword arguments), without default values
origargs = map(vcat(get(modeldef, :args, Any[]), get(modeldef, :kwargs, Any[]))) do arg
Meta.isexpr(arg, :kw) && length(arg.args) >= 1 ? arg.args[1] : arg
end

# Add our own arguments
newargs = Any[:(_rng::$(Random.AbstractRNG)),
:(_model::$(DynamicPPL.Model)),
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
:(_sampler::$(DynamicPPL.AbstractSampler)),
:(_context::$(DynamicPPL.AbstractContext))]
combinedargs = vcat(newargs, origargs)

# Delete keyword arguments and update positional arguments
delete!(modeldef, :kwargs)
modeldef[:args] = combinedargs

# Replace function body
modeldef[:body] = model_info[:main_body]

## Extract other relevant information

# All arguments with default values (if existent)
args = model_info[:args]
# Argument symbols without default values
arg_syms = model_info[:arg_syms]
# Arguments namedtuple
# Named tuple of all arguments
args_nt = model_info[:args_nt]
# Default values of the arguments
# Arguments namedtuple

# Named tuple of the default values of the arguments
defaults_nt = model_info[:defaults_nt]
# Where parameters
whereparams = model_info[:whereparams]
# Model generator name
model_gen = model_info[:name]
# Main body of the model
main_body = model_info[:main_body]

unwrap_data_expr = Expr(:block)
for var in arg_syms
push!(unwrap_data_expr.args,
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end

@gensym(evaluator, generator)
generator_kw_form = isempty(args) ? () : (:($generator(;$(args...)) = $generator($(arg_syms...))),)
model_gen_constructor = :($(DynamicPPL.ModelGen){$(Tuple(arg_syms))}($generator, $defaults_nt))
# Model name
model = model_info[:name]

return quote
function $evaluator(
_rng::$(Random.AbstractRNG),
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.AbstractVarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$main_body
end
# Define model definition with only keyword arguments
if isempty(args)
model_kwform = ()
else
# All arguments without default values (i.e., only symbols)
arg_syms = model_info[:arg_syms]

$generator($(args...)) = $(DynamicPPL.Model)($evaluator, $args_nt, $model_gen_constructor)
$(generator_kw_form...)
model_kwform = (:($model(; $(args...)) = $model($(arg_syms...))),)
end

$(Base).@__doc__ $model_gen = $model_gen_constructor
@gensym(evaluator)
return quote
$(Base).@__doc__ function $model($(args...))
$evaluator = $(ExprTools.combinedef(modeldef))
return $(DynamicPPL.Model)($evaluator, $args_nt, $defaults_nt)
end
$(model_kwform...)
end
end

Expand Down
Loading