Skip to content

Commit cc2e84d

Browse files
committed
Simplify varinfo constructors
1 parent 632ab09 commit cc2e84d

File tree

2 files changed

+4
-30
lines changed

2 files changed

+4
-30
lines changed

src/sampler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function AbstractMCMC.step(
8787
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
8888
)
8989
# Sample initial values.
90-
vi = typed_varinfo(rng, model, initialsampler(spl), DefaultContext())
90+
vi = VarInfo(rng, model, initialsampler(spl), DefaultContext())
9191

9292
# Update the parameters if provided.
9393
if initial_params !== nothing

src/varinfo.jl

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -163,42 +163,16 @@ function has_varnamedvector(vi::VarInfo)
163163
(vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata)))
164164
end
165165

166-
"""
167-
untyped_varinfo([rng, ]model[, sampler, context])
168-
169-
Return an untyped `VarInfo` instance for the model `model`.
170-
"""
171-
function untyped_varinfo(
172-
rng::Random.AbstractRNG,
173-
model::Model,
174-
sampler::AbstractSampler=SampleFromPrior(),
175-
context::AbstractContext=DefaultContext(),
176-
metadata::Union{Metadata,VarNamedVector}=Metadata(),
177-
)
178-
varinfo = VarInfo(metadata)
179-
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
180-
end
181-
function untyped_varinfo(
182-
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
183-
)
184-
return untyped_varinfo(Random.default_rng(), model, args...)
185-
end
186-
187-
"""
188-
typed_varinfo([rng, ]model[, sampler, context])
189-
190-
Return a typed `VarInfo` instance for the model `model`.
191-
"""
192-
typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...))
193-
194166
function VarInfo(
195167
rng::Random.AbstractRNG,
196168
model::Model,
197169
sampler::AbstractSampler=SampleFromPrior(),
198170
context::AbstractContext=DefaultContext(),
199171
metadata::Union{Metadata,VarNamedVector}=Metadata(),
200172
)
201-
return typed_varinfo(rng, model, sampler, context, metadata)
173+
varinfo = VarInfo(metadata)
174+
untyped_varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
175+
return TypedVarInfo(untyped_varinfo)
202176
end
203177
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)
204178

0 commit comments

Comments
 (0)