Skip to content

Commit 1c071df

Browse files
authored
Merge pull request #50 from phipsgabler/phg/varinfo-indices
Refactor VarName
2 parents 55b09f1 + 2c3244e commit 1c071df

File tree

11 files changed

+242
-265
lines changed

11 files changed

+242
-265
lines changed

src/DynamicPPL.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ using Bijectors
66
using MacroTools
77
import ZygoteRules
88

9-
import Base: string,
10-
Symbol,
9+
import Base: Symbol,
1110
==,
1211
hash,
13-
in,
1412
getindex,
1513
setindex!,
1614
push!,
@@ -23,8 +21,7 @@ import Base: string,
2321
haskey
2422

2523
# VarInfo
26-
export VarName,
27-
AbstractVarInfo,
24+
export AbstractVarInfo,
2825
VarInfo,
2926
UntypedVarInfo,
3027
getlogp,
@@ -45,6 +42,10 @@ export VarName,
4542
link!,
4643
invlink!,
4744
tonamedtuple,
45+
#VarName
46+
VarName,
47+
inspace,
48+
subsumes,
4849
# Compiler
4950
ModelGen,
5051
@model,

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ end
7070
To generate a `Model`, call `model_generator(x_value)`.
7171
"""
7272
macro model(input_expr)
73-
build_model_info(input_expr) |> replace_tilde! |> replace_vi! |>
73+
Base.replace_ref_end!(input_expr) |> build_model_info |> replace_tilde! |> replace_vi! |>
7474
replace_logpdf! |> replace_sampler! |> build_output
7575
end
7676

src/context_implementations.jl

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,7 @@ function _tilde(sampler, right, vn::VarName, vi)
5252
return assume(sampler, right, vn, vi)
5353
end
5454
function _tilde(sampler, right::NamedDist, vn::VarName, vi)
55-
name = right.name
56-
if name isa String
57-
sym_str, inds = split_var_str(name, String)
58-
sym = Symbol(sym_str)
59-
vn = VarName{sym}(inds)
60-
elseif name isa Symbol
61-
vn = VarName{name}("")
62-
elseif name isa VarName
63-
vn = name
64-
else
65-
throw("Unsupported variable name. Please use either a string, symbol or VarName.")
66-
end
67-
return _tilde(sampler, right.dist, vn, vi)
55+
return _tilde(sampler, right.dist, right.name, vi)
6856
end
6957

7058
# observe
@@ -214,30 +202,19 @@ end
214202

215203

216204
function get_vns_and_dist(dist::NamedDist, var, vn::VarName)
217-
name = dist.name
218-
if name isa String
219-
sym_str, inds = split_var_str(name, String)
220-
sym = Symbol(sym_str)
221-
vn = VarName{sym}(inds)
222-
elseif name isa Symbol
223-
vn = VarName{name}("")
224-
elseif name isa VarName
225-
vn = name
226-
else
227-
throw("Unsupported variable name. Please use either a string, symbol or VarName.")
228-
end
229-
return get_vns_and_dist(dist.dist, var, vn)
205+
return get_vns_and_dist(dist.dist, var, dist.name)
230206
end
231207
function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName)
232-
getvn = i -> VarName(vn, vn.indexing * "[Colon(),$i]")
208+
getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i)))
233209
return getvn.(1:size(var, 2)), dist
210+
234211
end
235212
function get_vns_and_dist(
236213
dist::Union{Distribution, AbstractArray{<:Distribution}},
237214
var::AbstractArray,
238215
vn::VarName
239216
)
240-
getvn = ind -> VarName(vn, vn.indexing * "[" * join(Tuple(ind), ",") * "]")
217+
getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind)))
241218
return getvn.(CartesianIndices(var)), dist
242219
end
243220

src/distribution_wrappers.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ A named distribution that carries the name of the random variable with it.
1111
struct NamedDist{
1212
variate,
1313
support,
14-
Td <: Distribution{variate, support},
15-
Tn
14+
Td <: Distribution{variate, support},
15+
Tv <: VarName
1616
} <: Distribution{variate, support}
1717
dist::Td
18-
name::Tn
18+
name::Tv
1919
end
2020

21+
NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName(name))
22+
2123

2224
struct NoDist{
2325
variate,

src/varinfo.jl

Lines changed: 7 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ end
125125
offset = :(0)
126126
for f in names
127127
mdf = :(metadata.$f)
128-
if f in space || length(space) == 0
128+
if inspace(f, space) || length(space) == 0
129129
len = :(length($mdf.vals))
130130
push!(exprs, :($f = Metadata($mdf.idcs,
131131
$mdf.vns,
@@ -330,13 +330,6 @@ setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val)
330330
return expr
331331
end
332332

333-
"""
334-
getsym(vn::VarName)
335-
336-
Return the symbol of the Julia variable used to generate `vn`.
337-
"""
338-
getsym(vn::VarName{sym}) where sym = sym
339-
340333
"""
341334
getgid(vi::VarInfo, vn::VarName)
342335
@@ -407,7 +400,7 @@ end
407400
# If the varname is in the sampler space
408401
# or the sample space is empty (all variables)
409402
# then return the indices for that variable.
410-
if f in space || length(space) == 0
403+
if inspace(f, space) || length(space) == 0
411404
push!(exprs, :($f = findinds(metadata.$f, s, Val($space))))
412405
end
413406
end
@@ -418,7 +411,7 @@ end
418411
# Get all the idcs of the vns in `space` and that belong to the selector `s`
419412
return filter((i) ->
420413
(s in f_meta.gids[i] || isempty(f_meta.gids[i])) &&
421-
(isempty(space) || in(f_meta.vns[i], space)), 1:length(f_meta.gids))
414+
(isempty(space) || inspace(f_meta.vns[i], space)), 1:length(f_meta.gids))
422415
end
423416
@inline function findinds(f_meta)
424417
# Get all the idcs of the vns
@@ -488,69 +481,6 @@ end
488481
#### APIs for typed and untyped VarInfo
489482
####
490483

491-
# VarName
492-
493-
"""
494-
VarName(sym, indexing)
495-
VarName{sym}(indexing::String)
496-
497-
Construct a new instance of `VarName{sym}`
498-
"""
499-
VarName(sym, indexing) = VarName{sym}(indexing)
500-
501-
"""
502-
VarName(vn::VarName, indexing::String)
503-
504-
Return a copy of `vn` with a new index `indexing`.
505-
"""
506-
function VarName(vn::VarName, indexing::String)
507-
return VarName{getsym(vn)}(indexing)
508-
end
509-
510-
"""
511-
uid(vn::VarName)
512-
513-
Return a unique tuple identifier for `vn`.
514-
"""
515-
uid(vn::VarName) = (getsym(vn), vn.indexing)
516-
517-
hash(vn::VarName) = hash(uid(vn))
518-
519-
==(x::VarName, y::VarName) = hash(uid(x)) == hash(uid(y))
520-
521-
function string(vn::VarName)
522-
return "$(getsym(vn))$(vn.indexing)"
523-
end
524-
function string(vns::Vector{<:VarName})
525-
return replace(string(map(string, vns)), "String" => "")
526-
end
527-
528-
"""
529-
Symbol(vn::VarName)
530-
531-
Return a `Symbol` represenation of the variable identifier `VarName`.
532-
"""
533-
Symbol(vn::VarName) = Symbol(string(vn)) # simplified symbol
534-
535-
"""
536-
in(vn::VarName, space::Set)
537-
538-
Check whether `vn`'s symbol is in `space`.
539-
"""
540-
in(::VarName, ::Tuple{}) = true
541-
in(vn::VarName, space::Tuple)::Bool = getsym(vn) in space || _in(string(vn), space)
542-
543-
_in(::String, ::Tuple{}) = false
544-
_in(vn_str::String, space::Tuple)::Bool = _in(vn_str, Base.tail(space))
545-
function _in(vn_str::String, space::Tuple{Expr,Vararg})::Bool
546-
# Collect expressions from space
547-
expr = first(space)
548-
# Filter `(` and `)` out and get a string representation of `exprs`
549-
expr_str = replace(string(expr), r"\(|\)" => "")
550-
# Check if `vn_str` is in `expr_strs`
551-
valid = occursin(expr_str, vn_str)
552-
return valid || _in(vn_str, Base.tail(space))
553-
end
554484

555485
# VarInfo
556486

@@ -602,8 +532,7 @@ function TypedVarInfo(vi::UntypedVarInfo)
602532
sym_vals = foldl(vcat, _vals)
603533

604534
push!(new_metas, Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals,
605-
sym_dists, sym_gids, sym_orders, sym_flags)
606-
)
535+
sym_dists, sym_gids, sym_orders, sym_flags))
607536
end
608537
logp = getlogp(vi)
609538
num_produce = get_num_produce(vi)
@@ -764,7 +693,7 @@ end
764693
@generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
765694
expr = Expr(:block)
766695
for f in names
767-
if f in space || length(space) == 0
696+
if inspace(f, space) || length(space) == 0
768697
push!(expr.args, quote
769698
f_vns = vi.metadata.$f.vns
770699
if ~istrans(vi, f_vns[1])
@@ -810,7 +739,7 @@ end
810739
@generated function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space}
811740
expr = Expr(:block)
812741
for f in names
813-
if f in space || length(space) == 0
742+
if inspace(f, space) || length(space) == 0
814743
push!(expr.args, quote
815744
f_vns = vi.metadata.$f.vns
816745
if istrans(vi, f_vns[1])
@@ -1173,7 +1102,7 @@ Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selec
11731102
and `vn`'s symbol is in the space of `spl`.
11741103
"""
11751104
function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
1176-
if vn in getspace(spl)
1105+
if inspace(vn, getspace(spl))
11771106
setgid!(vi, spl.selector, vn)
11781107
end
11791108
end

0 commit comments

Comments
 (0)