Skip to content

Commit 1f2f160

Browse files
committed
Rename internal variables (#225)
This PR contains only a cosmetic change and makes the internal variable names consistent with the convention used by other packages, such as Julia base (`__module__` and `__source__`) and Zygote (`__context__`). It is not strictly necessary to deprecate the current variable names since they are not exported but it seemed reasonable since probably at least `_varinfo` is known and used. Co-authored-by: David Widmann <[email protected]>
1 parent d2678d5 commit 1f2f160

File tree

6 files changed

+52
-37
lines changed

6 files changed

+52
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.13"
3+
version = "0.10.14"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
22
"Distributions."
33

4-
const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
4+
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
5+
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
56

67
"""
78
isassumption(expr)
@@ -24,7 +25,7 @@ function isassumption(expr::Union{Symbol, Expr})
2425
let $vn = $(varname(expr))
2526
# This branch should compile nicely in all cases except for partial missing data
2627
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
27-
if !$(DynamicPPL.inargnames)($vn, _model) || $(DynamicPPL.inmissings)($vn, _model)
28+
if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__)
2829
true
2930
else
3031
# Evaluate the LHS
@@ -167,10 +168,20 @@ generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, war
167168

168169
generate_mainbody!(mod, found, x, warn) = x
169170
function generate_mainbody!(mod, found, sym::Symbol, warn)
171+
if sym in DEPRECATED_INTERNALNAMES
172+
newsym = Symbol(:_, sym, :__)
173+
Base.depwarn(
174+
"internal variable `$sym` is deprecated, use `$newsym` instead.",
175+
:generate_mainbody!,
176+
)
177+
return generate_mainbody!(mod, found, newsym, warn)
178+
end
179+
170180
if warn && sym in INTERNALNAMES && sym found
171-
@warn "you are using the internal variable `$(sym)`"
181+
@warn "you are using the internal variable `$sym`"
172182
push!(found, sym)
173183
end
184+
174185
return sym
175186
end
176187
function generate_mainbody!(mod, found, expr::Expr, warn)
@@ -228,18 +239,20 @@ function generate_tilde(left, right)
228239
$isassumption = $(DynamicPPL.isassumption(left))
229240
if $isassumption
230241
$left = $(DynamicPPL.tilde_assume)(
231-
_rng, _context, _sampler, $tmpright, $vn, $inds, _varinfo)
242+
__rng__, __context__, __sampler__, $tmpright, $vn, $inds, __varinfo__
243+
)
232244
else
233245
$(DynamicPPL.tilde_observe)(
234-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
246+
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
247+
)
235248
end
236249
end
237250
end
238251

239252
# If the LHS is a literal, it is always an observation
240253
return quote
241254
$(top...)
242-
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
255+
$(DynamicPPL.tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
243256
end
244257
end
245258

@@ -263,18 +276,20 @@ function generate_dot_tilde(left, right)
263276
$isassumption = $(DynamicPPL.isassumption(left)) || $left === missing
264277
if $isassumption
265278
$left .= $(DynamicPPL.dot_tilde_assume)(
266-
_rng, _context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
279+
__rng__, __context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
280+
)
267281
else
268282
$(DynamicPPL.dot_tilde_observe)(
269-
_context, _sampler, $tmpright, $left, $vn, $inds, _varinfo)
283+
__context__, __sampler__, $tmpright, $left, $vn, $inds, __varinfo__
284+
)
270285
end
271286
end
272287
end
273288

274289
# If the LHS is a literal, it is always an observation
275290
return quote
276291
$(top...)
277-
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
292+
$(DynamicPPL.dot_tilde_observe)(__context__, __sampler__, $tmpright, $left, __varinfo__)
278293
end
279294
end
280295

@@ -298,11 +313,11 @@ function build_output(modelinfo, linenumbernode)
298313
# Add the internal arguments to the user-specified arguments (positional + keywords).
299314
evaluatordef[:args] = vcat(
300315
[
301-
:(_rng::$(Random.AbstractRNG)),
302-
:(_model::$(DynamicPPL.Model)),
303-
:(_varinfo::$(DynamicPPL.AbstractVarInfo)),
304-
:(_sampler::$(DynamicPPL.AbstractSampler)),
305-
:(_context::$(DynamicPPL.AbstractContext)),
316+
:(__rng__::$(Random.AbstractRNG)),
317+
:(__model__::$(DynamicPPL.Model)),
318+
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
319+
:(__sampler__::$(DynamicPPL.AbstractSampler)),
320+
:(__context__::$(DynamicPPL.AbstractContext)),
306321
],
307322
modelinfo[:allargs_exprs],
308323
)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability.
99
"""
1010
macro addlogprob!(ex)
1111
return quote
12-
acclogp!($(esc(:(_varinfo))), $(esc(ex)))
12+
acclogp!($(esc(:(__varinfo__))), $(esc(ex)))
1313
end
1414
end
1515

test/compiler.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,12 @@ end
172172
# Test use of internal names
173173
@model function testmodel_missing3(x)
174174
x[1] ~ Bernoulli(0.5)
175-
global varinfo_ = _varinfo
176-
global sampler_ = _sampler
177-
global model_ = _model
178-
global context_ = _context
179-
global rng_ = _rng
180-
global lp = getlogp(_varinfo)
175+
global varinfo_ = __varinfo__
176+
global sampler_ = __sampler__
177+
global model_ = __model__
178+
global context_ = __context__
179+
global rng_ = __rng__
180+
global lp = getlogp(__varinfo__)
181181
return x
182182
end
183183
model = testmodel_missing3([1.0])
@@ -192,12 +192,12 @@ end
192192
# disable warnings
193193
@model function testmodel_missing4(x)
194194
x[1] ~ Bernoulli(0.5)
195-
global varinfo_ = _varinfo
196-
global sampler_ = _sampler
197-
global model_ = _model
198-
global context_ = _context
199-
global rng_ = _rng
200-
global lp = getlogp(_varinfo)
195+
global varinfo_ = __varinfo__
196+
global sampler_ = __sampler__
197+
global model_ = __model__
198+
global context_ = __context__
199+
global rng_ = __rng__
200+
global lp = getlogp(__varinfo__)
201201
return x
202202
end false
203203
lpold = lp
@@ -236,7 +236,7 @@ end
236236
function makemodel(p)
237237
@model function testmodel(x)
238238
x[1] ~ Bernoulli(p)
239-
global lp = getlogp(_varinfo)
239+
global lp = getlogp(__varinfo__)
240240
return x
241241
end
242242
return testmodel
@@ -295,11 +295,11 @@ end
295295

296296
@testset "macros within model" begin
297297
# Macro expansion
298-
@model function demo()
298+
@model function demo1()
299299
@mymodel1(y ~ Uniform())
300300
end
301301

302-
@test haskey(VarInfo(demo()), @varname(x))
302+
@test haskey(VarInfo(demo1()), @varname(x))
303303

304304
# Interpolation
305305
# Will fail if:
@@ -308,9 +308,9 @@ end
308308
# 2. `@mymodel` is expanded before entire `@model` has been
309309
# expanded => errors since `MyModelStruct` is not a distribution,
310310
# and hence `tilde_observe` errors.
311-
@model function demo()
311+
@model function demo2()
312312
$(@mymodel2(y ~ Uniform()))
313313
end
314-
@test demo()() == 42
314+
@test demo2()() == 42
315315
end
316316
end

test/threadsafe.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
x = rand(10_000)
4040

4141
@model function wthreads(x)
42-
global vi_ = _varinfo
42+
global vi_ = __varinfo__
4343
x[1] ~ Normal(0, 1)
4444
Threads.@threads for i in 2:length(x)
4545
x[i] ~ Normal(x[i-1], 1)
@@ -70,7 +70,7 @@
7070
SampleFromPrior(), DefaultContext())
7171

7272
@model function wothreads(x)
73-
global vi_ = _varinfo
73+
global vi_ = __varinfo__
7474
x[1] ~ Normal(0, 1)
7575
for i in 2:length(x)
7676
x[i] ~ Normal(x[i-1], 1)

test/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
@testset "utils.jl" begin
22
@testset "addlogprob!" begin
33
@model function testmodel()
4-
global lp_before = getlogp(_varinfo)
4+
global lp_before = getlogp(__varinfo__)
55
@addlogprob!(42)
6-
global lp_after = getlogp(_varinfo)
6+
global lp_after = getlogp(__varinfo__)
77
end
88

99
model = testmodel()

0 commit comments

Comments
 (0)