Skip to content

Commit fe368b9

Browse files
authored
Merge pull request #43 from phipsgabler/context_independence
Make default model evaluation independent of Turing
2 parents c59f519 + ee9b693 commit fe368b9

26 files changed

+911
-783
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
2121
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
2222
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
2323
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
24+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
25+
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
2426
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2527
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2628
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -42,4 +44,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4244
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4345

4446
[targets]
45-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
47+
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]

src/DynamicPPL.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
module DynamicPPL
22

33
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
4-
using Distributions: UnivariateDistribution,
5-
MultivariateDistribution,
6-
MatrixDistribution,
7-
Distribution
8-
using Bijectors: link, invlink
4+
using Distributions
5+
using Bijectors
96
using MacroTools
107

118
import Base: string,
@@ -76,21 +73,33 @@ export VarName,
7673
LikelihoodContext,
7774
PriorContext,
7875
MiniBatchContext,
76+
assume,
77+
dot_assume,
78+
observer,
79+
dot_observe,
80+
tilde,
81+
dot_tilde,
82+
# Pseudo distributions
83+
NamedDist,
84+
NoDist,
7985
# Prob macros
8086
@prob_str,
8187
@logprob_str
8288

89+
const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))
90+
8391
# Used here and overloaded in Turing
8492
function getspace end
85-
function tilde end
86-
function dot_tilde end
8793

8894
include("utils.jl")
8995
include("selector.jl")
9096
include("model.jl")
9197
include("sampler.jl")
98+
include("varname.jl")
99+
include("distribution_wrappers.jl")
92100
include("contexts.jl")
93101
include("varinfo.jl")
102+
include("context_implementations.jl")
94103
include("compiler.jl")
95104
include("prob_macro.jl")
96105

src/compiler.jl

Lines changed: 8 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -33,116 +33,7 @@ function _error_msg()
3333
return "This macro is only for use in the `@model` macro and not for external use."
3434
end
3535

36-
"""
37-
@varname(var)
38-
39-
A macro that returns an instance of `VarName` given the symbol or expression of a Julia variable, e.g. `@varname x[1,2][1+5][45][3]` returns `VarName{:x}("[1,2][6][45][3]")`.
40-
"""
41-
macro varname(expr::Union{Expr, Symbol})
42-
expr |> varname |> esc
43-
end
44-
function varname(expr)
45-
ex = deepcopy(expr)
46-
(ex isa Symbol) && return quote
47-
DynamicPPL.VarName{$(QuoteNode(ex))}("")
48-
end
49-
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
50-
inds = :(())
51-
while ex.head == :ref
52-
if length(ex.args) >= 2
53-
strs = map(x -> :($x === (:) ? "Colon()" : string($x)), ex.args[2:end])
54-
pushfirst!(inds.args, :("[" * join($(Expr(:vect, strs...)), ",") * "]"))
55-
end
56-
ex = ex.args[1]
57-
isa(ex, Symbol) && return quote
58-
DynamicPPL.VarName{$(QuoteNode(ex))}(foldl(*, $inds, init = ""))
59-
end
60-
end
61-
throw("VarName: Mis-formed variable name $(expr)!")
62-
end
63-
64-
macro vsym(expr::Union{Expr, Symbol})
65-
expr |> vsym
66-
end
67-
68-
"""
69-
vsym(expr::Union{Expr, Symbol})
70-
71-
Returns the variable symbol given the input variable expression `expr`. For example, if the input `expr = :(x[1])`, the output is `:x`.
72-
"""
73-
function vsym(expr::Union{Expr, Symbol})
74-
ex = deepcopy(expr)
75-
(ex isa Symbol) && return QuoteNode(ex)
76-
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
77-
while ex.head == :ref
78-
ex = ex.args[1]
79-
isa(ex, Symbol) && return QuoteNode(ex)
80-
end
81-
throw("VarName: Mis-formed variable name $(expr)!")
82-
end
83-
84-
"""
85-
@vinds(expr)
8636

87-
Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1,:][2]` returns
88-
`((1, Colon()), (2,))`.
89-
"""
90-
macro vinds(expr::Union{Expr, Symbol})
91-
expr |> vinds |> esc
92-
end
93-
function vinds(expr::Union{Expr, Symbol})
94-
ex = deepcopy(expr)
95-
inds = Expr(:tuple)
96-
(ex isa Symbol) && return inds
97-
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
98-
while ex.head == :ref
99-
pushfirst!(inds.args, Expr(:tuple, ex.args[2:end]...))
100-
ex = ex.args[1]
101-
isa(ex, Symbol) && return inds
102-
end
103-
throw("VarName: Mis-formed variable name $(expr)!")
104-
end
105-
106-
"""
107-
split_var_str(var_str, inds_as = Vector)
108-
109-
This function splits a variable string, e.g. `"x[1:3,1:2][3,2]"` to the variable's symbol `"x"` and the indexing `"[1:3,1:2][3,2]"`. If `inds_as = String`, the indices are returned as a string, e.g. `"[1:3,1:2][3,2]"`. If `inds_as = Vector`, the indices are returned as a vector of vectors of strings, e.g. `[["1:3", "1:2"], ["3", "2"]]`.
110-
"""
111-
function split_var_str(var_str, inds_as = Vector)
112-
ind = findfirst(c -> c == '[', var_str)
113-
if inds_as === String
114-
if ind === nothing
115-
return var_str, ""
116-
else
117-
return var_str[1:ind-1], var_str[ind:end]
118-
end
119-
end
120-
@assert inds_as === Vector
121-
inds = Vector{String}[]
122-
if ind === nothing
123-
return var_str, inds
124-
end
125-
sym = var_str[1:ind-1]
126-
ind = length(sym)
127-
while ind < length(var_str)
128-
ind += 1
129-
@assert var_str[ind] == '['
130-
push!(inds, String[])
131-
while var_str[ind] != ']'
132-
ind += 1
133-
if var_str[ind] == '['
134-
ind2 = findnext(c -> c == ']', var_str, ind)
135-
push!(inds[end], strip(var_str[ind:ind2]))
136-
ind = ind2+1
137-
else
138-
ind2 = findnext(c -> c == ',' || c == ']', var_str, ind)
139-
push!(inds[end], strip(var_str[ind:ind2-1]))
140-
ind = ind2
141-
end
142-
end
143-
end
144-
return sym, inds
145-
end
14637

14738
# Check if the right-hand side is a distribution.
14839
function assert_dist(dist; msg)
@@ -404,21 +295,21 @@ function replace_tilde!(model_info)
404295
ex = model_info[:main_body]
405296
ex = MacroTools.postwalk(ex) do x
406297
if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
407-
dot_tilde(L, R, model_info)
298+
generate_dot_tilde(L, R, model_info)
408299
else
409300
x
410301
end
411302
end
412303
$(VERSION >= v"1.1" ? "ex = MacroTools.postwalk(ex) do x
413304
if @capture(x, L_ .~ R_)
414-
dot_tilde(L, R, model_info)
305+
generate_dot_tilde(L, R, model_info)
415306
else
416307
x
417308
end
418309
end" : "")
419310
ex = MacroTools.postwalk(ex) do x
420311
if @capture(x, L_ ~ R_)
421-
tilde(L, R, model_info)
312+
generate_tilde(L, R, model_info)
422313
else
423314
x
424315
end
@@ -429,12 +320,12 @@ end
429320
""" |> Meta.parse |> eval
430321

431322
"""
432-
tilde(left, right, model_info)
323+
generate_tilde(left, right, model_info)
433324
434325
The `tilde` function generates `observe` expression for data variables and `assume`
435326
expressions for parameter variables, updating `model_info` in the process.
436327
"""
437-
function tilde(left, right, model_info)
328+
function generate_tilde(left, right, model_info)
438329
arg_syms = Val((model_info[:arg_syms]...,))
439330
model = model_info[:main_body_names][:model]
440331
vi = model_info[:main_body_names][:vi]
@@ -478,11 +369,11 @@ function tilde(left, right, model_info)
478369
end
479370

480371
"""
481-
dot_tilde(left, right, model_info)
372+
generate_dot_tilde(left, right, model_info)
482373
483374
This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
484375
"""
485-
function dot_tilde(left, right, model_info)
376+
function generate_dot_tilde(left, right, model_info)
486377
arg_syms = Val((model_info[:arg_syms]...,))
487378
model = model_info[:main_body_names][:model]
488379
vi = model_info[:main_body_names][:vi]
@@ -636,4 +527,4 @@ end
636527
Get the specialized version of type `T` for sampler `spl`. For example,
637528
if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`.
638529
"""
639-
function get_matching_type end
530+
function get_matching_type end

0 commit comments

Comments
 (0)