Skip to content

Commit 3602c56

Browse files
committed
Introducing AbstractPPL dependency (#214)
This has three very rudimentary consequences: - `VarName` and its helpers are moved over to AbstractPPL completely - The new abstract base type for models is `AbstractPPL.AbstractProbabilisticProgram <: AbstractMCMC.AbstractModel` - `AbstractVarInfo <: AbstractPPL.AbstractModelTrace` More abstractions (and hopefully concrete generalizations, too) are about to come.
1 parent 1ab2e4e commit 3602c56

File tree

7 files changed

+22
-262
lines changed

7 files changed

+22
-262
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.7"
3+
version = "0.10.8"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
7+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
78
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -13,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314

1415
[compat]
1516
AbstractMCMC = "2"
17+
AbstractPPL = "0.1.2"
1618
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
1719
ChainRulesCore = "0.9.7"
1820
Distributions = "0.23.8, 0.24"

src/DynamicPPL.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DynamicPPL
22

3-
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
3+
using AbstractMCMC: AbstractSampler, AbstractChains
4+
using AbstractPPL
45
using Distributions
56
using Bijectors
67

@@ -49,13 +50,13 @@ export AbstractVarInfo,
4950
link!,
5051
invlink!,
5152
tonamedtuple,
52-
#VarName
53+
# VarName (reexport from AbstractPPL)
5354
VarName,
5455
inspace,
5556
subsumes,
57+
@varname,
5658
# Compiler
5759
@model,
58-
@varname,
5960
# Utilities
6061
vectorize,
6162
reconstruct,
@@ -104,7 +105,7 @@ export loglikelihood
104105
function getspace end
105106

106107
# Necessary forward declarations
107-
abstract type AbstractVarInfo end
108+
abstract type AbstractVarInfo <: AbstractModelTrace end
108109
abstract type AbstractContext end
109110

110111

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
3232
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
3333
```
3434
"""
35-
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel
35+
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractProbabilisticProgram
3636
name::Symbol
3737
f::F
3838
args::NamedTuple{argnames,Targs}

src/varname.jl

Lines changed: 9 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -1,240 +1,23 @@
11
"""
2-
VarName(sym[, indexing=()])
2+
inargnames(varname::VarName, model::Model)
33
4-
A variable identifier for a symbol `sym` and indices `indexing` in the format
5-
returned by [`@vinds`](@ref).
4+
Statically check whether the variable of name `varname` is an argument of the `model`.
65
7-
The Julia variable in the model corresponding to `sym` can refer to a single value or to a
8-
hierarchical array structure of univariate, multivariate or matrix variables. The field `indexing`
9-
stores the indices requires to access the random variable from the Julia variable indicated by `sym`
10-
as a tuple of tuples. Each element of the tuple thereby contains the indices of one indexing
11-
operation.
12-
13-
`VarName`s can be manually constructed using the `VarName(sym, indexing)` constructor, or from an
14-
indexing expression through the [`@varname`](@ref) convenience macro.
15-
16-
# Examples
17-
18-
```jldoctest
19-
julia> vn = VarName(:x, ((Colon(), 1), (2,)))
20-
x[Colon(),1][2]
21-
22-
julia> vn.indexing
23-
((Colon(), 1), (2,))
24-
25-
julia> VarName(DynamicPPL.@vsym(x[:, 1][1+1]), DynamicPPL.@vinds(x[:, 1][1+1]))
26-
x[Colon(),1][2]
27-
```
28-
"""
29-
struct VarName{sym, T<:Tuple}
30-
indexing::T
31-
end
32-
33-
VarName(sym::Symbol, indexing::Tuple = ()) = VarName{sym, typeof(indexing)}(indexing)
34-
35-
"""
36-
VarName(vn::VarName[, indexing=()])
37-
38-
Return a copy of `vn` with a new index `indexing`.
39-
"""
40-
function VarName(vn::VarName, indexing::Tuple = ())
41-
return VarName{getsym(vn), typeof(indexing)}(indexing)
42-
end
43-
44-
45-
"""
46-
getsym(vn::VarName)
47-
48-
Return the symbol of the Julia variable used to generate `vn`.
49-
"""
50-
getsym(vn::VarName{sym}) where sym = sym
51-
52-
53-
"""
54-
getindexing(vn::VarName)
55-
56-
Return the indexing tuple of the Julia variable used to generate `vn`.
57-
"""
58-
getindexing(vn::VarName) = vn.indexing
59-
60-
61-
Base.hash(vn::VarName, h::UInt) = hash((getsym(vn), getindexing(vn)), h)
62-
Base.:(==)(x::VarName, y::VarName) = getsym(x) == getsym(y) && getindexing(x) == getindexing(y)
63-
64-
function Base.show(io::IO, vn::VarName)
65-
print(io, getsym(vn))
66-
for indices in getindexing(vn)
67-
print(io, "[")
68-
join(io, indices, ",")
69-
print(io, "]")
70-
end
71-
end
72-
73-
74-
"""
75-
Symbol(vn::VarName)
76-
77-
Return a `Symbol` represenation of the variable identifier `VarName`.
78-
"""
79-
Base.Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol
80-
81-
82-
"""
83-
inspace(vn::Union{VarName, Symbol}, space::Tuple)
84-
85-
Check whether `vn`'s variable symbol is in `space`.
86-
"""
87-
inspace(vn, space::Tuple{}) = true # empty space is treated as universal set
88-
inspace(vn, space::Tuple) = vn in space
89-
inspace(vn::VarName, space::Tuple{}) = true
90-
inspace(vn::VarName, space::Tuple) = any(_in(vn, s) for s in space)
91-
92-
_in(vn::VarName, s::Symbol) = getsym(vn) == s
93-
_in(vn::VarName, s::VarName) = subsumes(s, vn)
94-
95-
96-
"""
97-
subsumes(u::VarName, v::VarName)
98-
99-
Check whether the variable name `v` describes a sub-range of the variable `u`. Supported
100-
indexing:
101-
102-
- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc.
103-
- Array of scalar: `x[[1, 2], 3]` subsumes `x[1, 3]`, `x[1:3]` subsumes `x[2][1]`, etc.
104-
(basically everything that fulfills `issubset`).
105-
- Slices: `x[2, :]` subsumes `x[2, 10][1]`, etc.
106-
107-
Currently _not_ supported are:
108-
109-
- Boolean indexing, literal `CartesianIndex` (these could be added, though)
110-
- Linear indexing of multidimensional arrays: `x[4]` does not subsume `x[2, 2]` for `x` a matrix
111-
- Trailing ones: `x[2, 1]` does not subsume `x[2]` for `x` a vector
112-
"""
113-
function subsumes(u::VarName, v::VarName)
114-
return getsym(u) == getsym(v) && subsumes(u.indexing, v.indexing)
115-
end
116-
117-
subsumes(::Tuple{}, ::Tuple{}) = true # x subsumes x
118-
subsumes(::Tuple{}, ::Tuple) = true # x subsumes x[1]
119-
subsumes(::Tuple, ::Tuple{}) = false # x[1] does not subsume x
120-
function subsumes(t::Tuple, u::Tuple) # does x[i]... subsume x[j]...?
121-
return _issubindex(first(t), first(u)) && subsumes(Base.tail(t), Base.tail(u))
122-
end
123-
124-
const AnyIndex = Union{Int, AbstractVector{Int}, Colon}
125-
_issubindex_(::Tuple{Vararg{AnyIndex}}, ::Tuple{Vararg{AnyIndex}}) = false
126-
function _issubindex(t::NTuple{N, AnyIndex}, u::NTuple{N, AnyIndex}) where {N}
127-
return all(_issubrange(j, i) for (i, j) in zip(t, u))
128-
end
129-
130-
const ConcreteIndex = Union{Int, AbstractVector{Int}} # this include all kinds of ranges
131-
"""Determine whether indices `i` are contained in `j`, treating `:` as universal set."""
132-
_issubrange(i::ConcreteIndex, j::ConcreteIndex) = issubset(i, j)
133-
_issubrange(i::Union{ConcreteIndex, Colon}, j::Colon) = true
134-
_issubrange(i::Colon, j::ConcreteIndex) = true
135-
136-
137-
138-
"""
139-
@varname(expr)
140-
141-
A macro that returns an instance of [`VarName`](@ref) given a symbol or indexing expression `expr`.
142-
143-
The `sym` value is taken from the actual variable name, and the index values are put appropriately
144-
into the constructor (and resolved at runtime).
145-
146-
# Examples
147-
148-
```jldoctest
149-
julia> @varname(x).indexing
150-
()
151-
152-
julia> @varname(x[1]).indexing
153-
((1,),)
154-
155-
julia> @varname(x[:, 1]).indexing
156-
((Colon(), 1),)
157-
158-
julia> @varname(x[:, 1][2]).indexing
159-
((Colon(), 1), (2,))
160-
161-
julia> @varname(x[1,2][1+5][45][3]).indexing
162-
((1, 2), (6,), (45,), (3,))
163-
```
164-
165-
!!! compat "Julia 1.5"
166-
Using `begin` in an indexing expression to refer to the first index requires at least
167-
Julia 1.5.
6+
Possibly existing indices of `varname` are neglected.
1687
"""
169-
macro varname(expr::Union{Expr, Symbol})
170-
return esc(varname(expr))
171-
end
172-
173-
varname(expr::Symbol) = VarName(expr)
174-
function varname(expr::Expr)
175-
if Meta.isexpr(expr, :ref)
176-
sym, inds = vsym(expr), vinds(expr)
177-
return :($(DynamicPPL.VarName)($(QuoteNode(sym)), $inds))
178-
else
179-
throw("VarName: Mis-formed variable name $(expr)!")
180-
end
181-
end
182-
183-
184-
"""
185-
@vsym(expr)
186-
187-
A macro that returns the variable symbol given the input variable expression `expr`.
188-
For example, `@vsym x[1]` returns `:x`.
189-
"""
190-
macro vsym(expr::Union{Expr, Symbol})
191-
return QuoteNode(vsym(expr))
8+
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
9+
return s in argnames
19210
end
19311

194-
vsym(expr::Symbol) = expr
195-
function vsym(expr::Expr)
196-
if Meta.isexpr(expr, :ref)
197-
return vsym(expr.args[1])
198-
else
199-
throw("VarName: Mis-formed variable name $(expr)!")
200-
end
201-
end
20212

20313
"""
204-
@vinds(expr)
14+
inmissings(varname::VarName, model::Model)
20515
206-
Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1, :][2]` returns
207-
`((1, Colon()), (2,))`.
16+
Statically check whether the variable of name `varname` is a statically declared unobserved variable
17+
of the `model`.
20818
209-
!!! compat "Julia 1.5"
210-
Using `begin` in an indexing expression to refer to the first index requires at least
211-
Julia 1.5.
19+
Possibly existing indices of `varname` are neglected.
21220
"""
213-
macro vinds(expr::Union{Expr, Symbol})
214-
return esc(vinds(expr))
215-
end
216-
217-
vinds(expr::Symbol) = Expr(:tuple)
218-
function vinds(expr::Expr)
219-
if Meta.isexpr(expr, :ref)
220-
ex = copy(expr)
221-
@static if VERSION < v"1.5.0-DEV.666"
222-
Base.replace_ref_end!(ex)
223-
else
224-
Base.replace_ref_begin_end!(ex)
225-
end
226-
last = Expr(:tuple, ex.args[2:end]...)
227-
init = vinds(ex.args[1]).args
228-
return Expr(:tuple, init..., last)
229-
else
230-
throw("VarName: Mis-formed variable name $(expr)!")
231-
end
232-
end
233-
234-
@generated function inargnames(::VarName{s}, ::Model{_F, argnames}) where {s, argnames, _F}
235-
return s in argnames
236-
end
237-
23821
@generated function inmissings(::VarName{s}, ::Model{_F, _a, _T, missings}) where {s, missings, _F, _a, _T}
23922
return s in missings
24023
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
3+
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
34
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
45
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
56
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -17,6 +18,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
1920
AbstractMCMC = "2.1"
21+
AbstractPPL = "0.1.2"
2022
Bijectors = "0.8.2"
2123
Distributions = "0.24"
2224
DistributionsAD = "0.6.3"

test/compiler.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -222,35 +222,6 @@ end
222222
varinfo = VarInfo(model)
223223
@test getlogp(varinfo) == lp
224224
end
225-
@testset "var name splitting" begin
226-
var_expr = :(x)
227-
@test vsym(var_expr) == :x
228-
@test vinds(var_expr) == :(())
229-
230-
var_expr = :(x[1,1][2,3])
231-
@test vsym(var_expr) == :x
232-
@test vinds(var_expr) == :(((1, 1), (2, 3)))
233-
234-
var_expr = :(x[:,1][2,:])
235-
@test vsym(var_expr) == :x
236-
@test vinds(var_expr) == :(((:, 1), (2, :)))
237-
238-
var_expr = :(x[2:3,1][2,1:2])
239-
@test vsym(var_expr) == :x
240-
@test vinds(var_expr) == :(((2:3, 1), (2, 1:2)))
241-
242-
var_expr = :(x[2:3,2:3][[1,2],[1,2]])
243-
@test vsym(var_expr) == :x
244-
@test vinds(var_expr) == :(((2:3, 2:3), ([1, 2], [1, 2])))
245-
246-
var_expr = :(x[end])
247-
@test vsym(var_expr) == :x
248-
@test vinds(var_expr) == :((($lastindex(x),),))
249-
250-
var_expr = :(x[1, end])
251-
@test vsym(var_expr) == :x
252-
@test vinds(var_expr) == :(((1, $lastindex(x, 2)),))
253-
end
254225
@testset "user-defined variable name" begin
255226
@model f1() = x ~ NamedDist(Normal(), :y)
256227
@model f2() = x ~ NamedDist(Normal(), @varname(y[2][:,1]))

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DynamicPPL
22
using AbstractMCMC
3+
using AbstractPPL
34
using Bijectors
45
using Distributions
56
using DistributionsAD
@@ -16,7 +17,7 @@ using Random
1617
using Serialization
1718
using Test
1819

19-
using DynamicPPL: vsym, vinds, getargs_dottilde, getargs_tilde, Selector
20+
using DynamicPPL: getargs_dottilde, getargs_tilde, Selector
2021

2122
const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL)))
2223
const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing")

0 commit comments

Comments
 (0)