Skip to content

Commit 2da3e03

Browse files
committed
Remove runmodel!
1 parent 3d2c1eb commit 2da3e03

File tree

15 files changed

+38
-40
lines changed

15 files changed

+38
-40
lines changed

src/DynamicPPL.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ export VarName,
6868
getargnames,
6969
getdefaults,
7070
getgenerator,
71-
runmodel!,
7271
# Samplers
7372
Sampler,
7473
SampleFromPrior,

src/model.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,25 @@ function Model{missings}(
109109
return Model{missings}(model.f, args, modelgen)
110110
end
111111

112+
"""
113+
(model::Model)([spl = SampleFromPrior(), ctx = DefaultContext()])
112114
115+
Sample from `model` using the sampler `spl`.
116+
"""
113117
function (model::Model)(
114-
vi::AbstractVarInfo=VarInfo(),
115118
spl::AbstractSampler=SampleFromPrior(),
116119
ctx::AbstractContext=DefaultContext()
117120
)
118-
return model.f(model, vi, spl, ctx)
121+
return model(VarInfo(), spl, ctx)
119122
end
120123

121-
122124
"""
123-
runmodel!(model::Model, vi::AbstractVarInfo[, spl::AbstractSampler, ctx::AbstractContext])
125+
(model::Model)(vi::AbstractVarInfo[, spl = SampleFromPrior(), ctx = DefaultContext()])
124126
125127
Sample from `model` using the sampler `spl` storing the sample and log joint probability in `vi`.
126128
Resets the `vi` and increases `spl`s `state.eval_num`.
127129
"""
128-
function runmodel!(
129-
model::Model,
130+
function (model::Model)(
130131
vi::AbstractVarInfo,
131132
spl::AbstractSampler=SampleFromPrior(),
132133
ctx::AbstractContext=DefaultContext()
@@ -135,8 +136,7 @@ function runmodel!(
135136
if has_eval_num(spl)
136137
spl.state.eval_num += 1
137138
end
138-
model(vi, spl, ctx)
139-
return vi
139+
return model.f(model, vi, spl, ctx)
140140
end
141141

142142

src/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ const TypedVarInfo = VarInfo{<:NamedTuple}
107107

108108
function VarInfo(model::Model, ctx = DefaultContext())
109109
vi = VarInfo()
110-
runmodel!(model, vi, SampleFromPrior(), ctx)
110+
model(vi, SampleFromPrior(), ctx)
111111
return TypedVarInfo(vi)
112112
end
113113

test/Turing/Turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Markdown, Libtask, MacroTools
1616
using Tracker: Tracker
1717

1818
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
19-
import DynamicPPL: getspace, runmodel!
19+
import DynamicPPL: getspace
2020

2121
const PROGRESS = Ref(true)
2222
function turnprogress(switch::Bool)

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ function AbstractMCMC.sample_init!(
6060
gradient_logp(x, spl.state.vi, model, spl)
6161
end
6262

63-
runmodel!(model, spl.state.vi, SampleFromUniform())
63+
model(spl.state.vi, SampleFromUniform())
6464

6565
if spl.selector.tag == :default
6666
link!(spl.state.vi, spl)
67-
runmodel!(model, spl.state.vi, spl)
67+
model(spl.state.vi, spl)
6868
end
6969

7070
# Set the parameters to a starting value.
@@ -145,4 +145,4 @@ function AbstractMCMC.psample(
145145
end
146146
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
147147
chain_type=chain_type, progress=false, kwargs...)
148-
end
148+
end

test/Turing/contrib/inference/sghmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ function step(
8585
Turing.DEBUG && @debug "X-> R..."
8686
if spl.selector.tag != :default
8787
link!(vi, spl)
88-
runmodel!(model, vi, spl)
88+
model(vi, spl)
8989
end
9090

9191
Turing.DEBUG && @debug "recording old variables..."
@@ -198,7 +198,7 @@ function step(
198198
Turing.DEBUG && @debug "X-> R..."
199199
if spl.selector.tag != :default
200200
link!(vi, spl)
201-
runmodel!(model, vi, spl)
201+
model(vi, spl)
202202
end
203203

204204
Turing.DEBUG && @debug "recording old variables..."

test/Turing/core/Core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Distributions, LinearAlgebra
66
using ..Utilities, Reexport
77
using Tracker: Tracker
88
using ..Turing: Turing
9-
using DynamicPPL: Model, runmodel!,
9+
using DynamicPPL: Model,
1010
AbstractSampler, Sampler, SampleFromPrior
1111
using LinearAlgebra: copytri!
1212
using Bijectors: PDMatDistribution

test/Turing/core/ad.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ function gradient_logp(
9292
logp_old = getlogp(vi)
9393
function f(θ)
9494
new_vi = VarInfo(vi, sampler, θ)
95-
logp = getlogp(runmodel!(model, new_vi, sampler))
95+
model(new_vi, sampler)
96+
logp = getlogp(new_vi)
9697
setlogp!(vi, ForwardDiff.value(logp))
9798
return logp
9899
end
@@ -119,7 +120,8 @@ function gradient_logp(
119120
# Specify objective function.
120121
function f(θ)
121122
new_vi = VarInfo(vi, sampler, θ)
122-
return getlogp(runmodel!(model, new_vi, sampler))
123+
model(new_vi, sampler)
124+
return getlogp(new_vi)
123125
end
124126

125127
# Compute forward and reverse passes.

test/Turing/core/compat/zygote.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ function gradient_logp(
1616
# Specify objective function.
1717
function f(θ)
1818
new_vi = VarInfo(vi, sampler, θ)
19-
return getlogp(runmodel!(model, new_vi, sampler))
19+
model(new_vi, sampler)
20+
return getlogp(new_vi)
2021
end
2122

2223
# Compute forward and reverse passes.

test/Turing/inference/Inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ..Core, ..Utilities
44
using DynamicPPL: Metadata, _tail, VarInfo, TypedVarInfo,
55
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
66
settrans!, _getvns, getdist, split_var_str, CACHERESET, AbstractSampler,
7-
Model, runmodel!, Sampler, SampleFromPrior, SampleFromUniform,
7+
Model, Sampler, SampleFromPrior, SampleFromUniform,
88
Selector, AbstractSamplerState, DefaultContext, PriorContext,
99
LikelihoodContext, MiniBatchContext, set_flag!, unset_flag!, NamedDist, NoDist
1010
using Distributions, Libtask, Bijectors

0 commit comments

Comments
 (0)