Skip to content

Commit 9fbe09a

Browse files
authored
Merge pull request #63 from TuringLang/model
Less surprising model return values
2 parents 1c071df + 18a0353 commit 9fbe09a

File tree

16 files changed

+57
-52
lines changed

16 files changed

+57
-52
lines changed

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ export AbstractVarInfo,
6969
getargnames,
7070
getdefaults,
7171
getgenerator,
72-
runmodel!,
7372
# Samplers
7473
Sampler,
7574
SampleFromPrior,

src/compiler.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ function generate_tilde(left, right, model_info)
278278
ctx = model_info[:main_body_names][:ctx]
279279
sampler = model_info[:main_body_names][:sampler]
280280

281-
@gensym tmpright
281+
@gensym tmpright tmpleft
282282
top = [:($tmpright = $right),
283283
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
284284
|| throw(ArgumentError($DISTMSG)))]
@@ -290,8 +290,8 @@ function generate_tilde(left, right, model_info)
290290
assumption = [
291291
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
292292
$vi)),
293-
:($left = $out[1]),
294-
:($(DynamicPPL.acclogp!)($vi, $out[2]))
293+
:($(DynamicPPL.acclogp!)($vi, $out[2])),
294+
:($left = $out[1])
295295
]
296296

297297
# It can only be an observation if the LHS is an argument of the model
@@ -303,11 +303,13 @@ function generate_tilde(left, right, model_info)
303303
if $isassumption
304304
$(assumption...)
305305
else
306+
$tmpleft = $left
306307
$(DynamicPPL.acclogp!)(
307308
$vi,
308-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vn,
309+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vn,
309310
$inds, $vi)
310311
)
312+
$tmpleft
311313
end
312314
end
313315
end
@@ -321,10 +323,12 @@ function generate_tilde(left, right, model_info)
321323
# If the LHS is a literal, it is always an observation
322324
return quote
323325
$(top...)
326+
$tmpleft = $left
324327
$(DynamicPPL.acclogp!)(
325328
$vi,
326-
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
329+
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
327330
)
331+
$tmpleft
328332
end
329333
end
330334

@@ -341,7 +345,7 @@ function generate_dot_tilde(left, right, model_info)
341345
ctx = model_info[:main_body_names][:ctx]
342346
sampler = model_info[:main_body_names][:sampler]
343347

344-
@gensym tmpright
348+
@gensym tmpright tmpleft
345349
top = [:($tmpright = $right),
346350
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
347351
|| throw(ArgumentError($DISTMSG)))]
@@ -353,8 +357,8 @@ function generate_dot_tilde(left, right, model_info)
353357
assumption = [
354358
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
355359
$vn, $inds, $vi)),
356-
:($left .= $out[1]),
357-
:($(DynamicPPL.acclogp!)($vi, $out[2]))
360+
:($(DynamicPPL.acclogp!)($vi, $out[2])),
361+
:($left .= $out[1])
358362
]
359363

360364
# It can only be an observation if the LHS is an argument of the model
@@ -366,11 +370,13 @@ function generate_dot_tilde(left, right, model_info)
366370
if $isassumption
367371
$(assumption...)
368372
else
373+
$tmpleft = $left
369374
$(DynamicPPL.acclogp!)(
370375
$vi,
371-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left,
376+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft,
372377
$vn, $inds, $vi)
373378
)
379+
$tmpleft
374380
end
375381
end
376382
end
@@ -384,10 +390,12 @@ function generate_dot_tilde(left, right, model_info)
384390
# If the LHS is a literal, it is always an observation
385391
return quote
386392
$(top...)
393+
$tmpleft = $left
387394
$(DynamicPPL.acclogp!)(
388395
$vi,
389-
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
396+
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
390397
)
398+
$tmpleft
391399
end
392400
end
393401

@@ -443,7 +451,6 @@ function build_output(model_info)
443451
$ctx::$(DynamicPPL.AbstractContext),
444452
)
445453
$unwrap_data_expr
446-
$(DynamicPPL.resetlogp!)($vi)
447454
$main_body
448455
end
449456

src/model.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,34 +109,34 @@ function Model{missings}(
109109
return Model{missings}(model.f, args, modelgen)
110110
end
111111

112+
"""
113+
(model::Model)([spl = SampleFromPrior(), ctx = DefaultContext()])
112114
115+
Sample from `model` using the sampler `spl`.
116+
"""
113117
function (model::Model)(
114-
vi::AbstractVarInfo=VarInfo(),
115118
spl::AbstractSampler=SampleFromPrior(),
116119
ctx::AbstractContext=DefaultContext()
117120
)
118-
return model.f(model, vi, spl, ctx)
121+
return model(VarInfo(), spl, ctx)
119122
end
120123

121-
122124
"""
123-
runmodel!(model::Model, vi::AbstractVarInfo[, spl::AbstractSampler, ctx::AbstractContext])
125+
(model::Model)(vi::AbstractVarInfo[, spl = SampleFromPrior(), ctx = DefaultContext()])
124126
125127
Sample from `model` using the sampler `spl` storing the sample and log joint probability in `vi`.
126128
Resets the `vi` and increases `spl`s `state.eval_num`.
127129
"""
128-
function runmodel!(
129-
model::Model,
130+
function (model::Model)(
130131
vi::AbstractVarInfo,
131132
spl::AbstractSampler=SampleFromPrior(),
132133
ctx::AbstractContext=DefaultContext()
133134
)
134-
setlogp!(vi, 0)
135+
resetlogp!(vi)
135136
if has_eval_num(spl)
136137
spl.state.eval_num += 1
137138
end
138-
model(vi, spl, ctx)
139-
return vi
139+
return model.f(model, vi, spl, ctx)
140140
end
141141

142142

src/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ const TypedVarInfo = VarInfo{<:NamedTuple}
107107

108108
function VarInfo(model::Model, ctx = DefaultContext())
109109
vi = VarInfo()
110-
runmodel!(model, vi, SampleFromPrior(), ctx)
110+
model(vi, SampleFromPrior(), ctx)
111111
return TypedVarInfo(vi)
112112
end
113113

test/Turing/Turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Markdown, Libtask, MacroTools
1616
using Tracker: Tracker
1717

1818
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
19-
import DynamicPPL: getspace, runmodel!
19+
import DynamicPPL: getspace
2020

2121
const PROGRESS = Ref(true)
2222
function turnprogress(switch::Bool)

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ function AbstractMCMC.sample_init!(
6060
gradient_logp(x, spl.state.vi, model, spl)
6161
end
6262

63-
runmodel!(model, spl.state.vi, SampleFromUniform())
63+
model(spl.state.vi, SampleFromUniform())
6464

6565
if spl.selector.tag == :default
6666
link!(spl.state.vi, spl)
67-
runmodel!(model, spl.state.vi, spl)
67+
model(spl.state.vi, spl)
6868
end
6969

7070
# Set the parameters to a starting value.
@@ -145,4 +145,4 @@ function AbstractMCMC.psample(
145145
end
146146
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
147147
chain_type=chain_type, progress=false, kwargs...)
148-
end
148+
end

test/Turing/contrib/inference/sghmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function step(
8585
Turing.DEBUG && @debug "X-> R..."
8686
if spl.selector.tag != :default
8787
link!(vi, spl)
88-
runmodel!(model, vi, spl)
88+
model(vi, spl)
8989
end
9090

9191
Turing.DEBUG && @debug "recording old variables..."
@@ -198,7 +198,7 @@ function step(
198198
Turing.DEBUG && @debug "X-> R..."
199199
if spl.selector.tag != :default
200200
link!(vi, spl)
201-
runmodel!(model, vi, spl)
201+
model(vi, spl)
202202
end
203203

204204
Turing.DEBUG && @debug "recording old variables..."

test/Turing/core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Distributions, LinearAlgebra
66
using ..Utilities, Reexport
77
using Tracker: Tracker
88
using ..Turing: Turing
9-
using DynamicPPL: Model, runmodel!,
9+
using DynamicPPL: Model,
1010
AbstractSampler, Sampler, SampleFromPrior
1111
using LinearAlgebra: copytri!
1212
using Bijectors: PDMatDistribution

test/Turing/core/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ function gradient_logp(
9292
logp_old = getlogp(vi)
9393
function f(θ)
9494
new_vi = VarInfo(vi, sampler, θ)
95-
logp = getlogp(runmodel!(model, new_vi, sampler))
95+
model(new_vi, sampler)
96+
logp = getlogp(new_vi)
9697
setlogp!(vi, ForwardDiff.value(logp))
9798
return logp
9899
end
@@ -119,7 +120,8 @@ function gradient_logp(
119120
# Specify objective function.
120121
function f(θ)
121122
new_vi = VarInfo(vi, sampler, θ)
122-
return getlogp(runmodel!(model, new_vi, sampler))
123+
model(new_vi, sampler)
124+
return getlogp(new_vi)
123125
end
124126

125127
# Compute forward and reverse passes.

test/Turing/core/compat/zygote.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function gradient_logp(
1616
# Specify objective function.
1717
function f(θ)
1818
new_vi = VarInfo(vi, sampler, θ)
19-
return getlogp(runmodel!(model, new_vi, sampler))
19+
model(new_vi, sampler)
20+
return getlogp(new_vi)
2021
end
2122

2223
# Compute forward and reverse passes.

0 commit comments

Comments
 (0)