Skip to content

Commit ae297c3

Browse files
committed
Implement two-argument version
1 parent bfbf727 commit ae297c3

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/interface.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@ chainsstack(c) = c
2020
chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c)
2121

2222
"""
23-
getadtype(sampler::AbstractSampler)
23+
getadtype(s::AbstractSampler)
24+
getadtype(m::AbstractModel, s::AbstractSampler)
2425
25-
If the sampler specifies an automatic differentiation (AD) backend to use, this
26-
function should return the corresponding `ADTypes.AbstractADType`.
26+
Specify the `ADTypes.AbstractADType` to be used when sampling from model `m` using sampler `s`.
27+
28+
If the model is not relevant, then the implementation of AbstractSampler can
29+
directly overload the single-argument method `getadtype(s::AbstractSampler)`.
2730
2831
By default, this returns `nothing`.
2932
"""
3033
getadtype(::AbstractSampler) = nothing
34+
getadtype(::AbstractModel, spl::AbstractSampler) = getadtype(spl)
3135

3236
"""
3337
bundle_samples(samples, model, sampler, state, chain_type[; kwargs...])

0 commit comments

Comments
 (0)