Skip to content

Commit b8d033a

Browse files
committed
Fixed keytype argument
1 parent 8591375 commit b8d033a

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

src/pointwise_logdensities.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,19 @@ function _pointwise_tilde_assume(
196196
end
197197

198198
"""
199-
pointwise_logdensities(model::Model, chain::Chains, keytype = String)
199+
pointwise_logdensities(model::Model, chain::Chains[, keytype::Type, context::AbstractContext])
200200
201201
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
202202
with keys corresponding to symbols of the variables, and values being matrices
203203
of shape `(num_chains, num_samples)`.
204204
205-
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
206-
Currently, only `String` and `VarName` are supported.
205+
# Arguments
206+
- `model`: the `Model` to run.
207+
- `chain`: the `Chains` to run the model on.
208+
- `keytype`: the type of the keys used in the returned `OrderedDict` are.
209+
Currently, only `String` and `VarName` are supported.
210+
- `context`: the context to use when running the model. Default: `DefaultContext`.
211+
The [`leafcontext`](@ref) is used to decide which variables to include.
207212
208213
# Notes
209214
Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ`
@@ -294,7 +299,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
294299
295300
"""
296301
function pointwise_logdensities(
297-
model::Model, chain, context::AbstractContext=DefaultContext(), keytype::Type{T}=String
302+
model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext()
298303
) where {T}
299304
# Get the data by executing the model once
300305
vi = VarInfo(model)
@@ -329,7 +334,7 @@ function pointwise_logdensities(
329334
end
330335

331336
"""
332-
pointwise_loglikelihoods(model, chain[, context])
337+
pointwise_loglikelihoods(model, chain[, keytype, context])
333338
334339
Compute the pointwise log-likelihoods of the model given the chain.
335340
@@ -341,14 +346,14 @@ See also: [`pointwise_logdensities`](@ref).
341346
function pointwise_loglikelihoods(
342347
model::Model,
343348
chain,
344-
context::AbstractContext=LikelihoodContext(),
345349
keytype::Type{T}=String,
350+
context::AbstractContext=LikelihoodContext(),
346351
) where {T}
347352
if !(leafcontext(context) isa LikelihoodContext)
348353
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
349354
end
350355

351-
return pointwise_logdensities(model, chain, context, keytype)
356+
return pointwise_logdensities(model, chain, T, context)
352357
end
353358

354359
function pointwise_loglikelihoods(
@@ -362,7 +367,7 @@ function pointwise_loglikelihoods(
362367
end
363368

364369
"""
365-
pointwise_prior_logdensities(model, chain[, context])
370+
pointwise_prior_logdensities(model, chain[, keytype, context])
366371
367372
Compute the pointwise log-prior-densities of the model given the chain.
368373
@@ -372,13 +377,13 @@ including the prior terms.
372377
See also: [`pointwise_logdensities`](@ref).
373378
"""
374379
function pointwise_prior_logdensities(
375-
model::Model, chain, context::AbstractContext=PriorContext(), keytype::Type{T}=String
380+
model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext()
376381
) where {T}
377382
if !(leafcontext(context) isa PriorContext)
378383
throw(ArgumentError("Leaf context should be a PriorContext"))
379384
end
380385

381-
return pointwise_logdensities(model, chain, context, keytype)
386+
return pointwise_logdensities(model, chain, T, context)
382387
end
383388

384389
function pointwise_prior_logdensities(

0 commit comments

Comments
 (0)