@@ -32,7 +32,7 @@ using Random: Random
3232"""
3333 DynamicPPL.LogDensityFunction(
3434 model::Model,
35- getlogdensity::Function =getlogjoint_internal,
35+ getlogdensity::Any =getlogjoint_internal,
3636 varinfo::AbstractVarInfo=VarInfo(model);
3737 adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
3838 )
@@ -47,7 +47,9 @@ using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gra
4747`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD
4848backend type, then `logdensity_and_gradient` is also implemented.
4949
50- There are several options for `getlogdensity` that are 'supported' out of the box:
50+ `getlogdensity` should be a callable which takes a single argument: a `VarInfo`, and returns
51+ a `Real` corresponding to the log density of interest. There are several functions in
52+ DynamicPPL that are 'supported' out of the box:
5153
5254- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term
5355 for any variables that have been linked in the provided VarInfo.
@@ -145,7 +147,7 @@ struct LogDensityFunction{
145147 Tlink,
146148 M<: Model ,
147149 AD<: Union{ADTypes.AbstractADType,Nothing} ,
148- F<: Function ,
150+ F,
149151 N<: NamedTuple ,
150152 ADP<: Union{Nothing,DI.GradientPrep} ,
151153 # type of the vector passed to logdensity functions
@@ -161,7 +163,7 @@ struct LogDensityFunction{
161163
162164 function LogDensityFunction(
163165 model:: Model ,
164- getlogdensity:: Function = getlogjoint_internal,
166+ getlogdensity:: Any = getlogjoint_internal,
165167 varinfo:: AbstractVarInfo = VarInfo(model);
166168 adtype:: Union{ADTypes.AbstractADType,Nothing} = nothing ,
167169 )
@@ -219,12 +221,12 @@ end
219221# LogDensityProblems.jl interface #
220222# ##################################
221223"""
222- ldf_accs(getlogdensity::Function )
224+ ldf_accs(getlogdensity::Any )
223225
224226Determine which accumulators are needed for fast evaluation with the given
225- `getlogdensity` function .
227+ `getlogdensity` callable .
226228"""
227- ldf_accs(:: Function ) = default_accumulators()
229+ ldf_accs(:: Any ) = default_accumulators()
228230ldf_accs(:: typeof (getlogjoint_internal)) = default_accumulators()
229231function ldf_accs(:: typeof (getlogjoint))
230232 return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
235237ldf_accs(:: typeof (getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
236238ldf_accs(:: typeof (getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
237239
238- struct LogDensityAt{Tlink,M<: Model ,F<: Function ,N<: NamedTuple }
240+ struct LogDensityAt{Tlink,M<: Model ,F,N<: NamedTuple }
239241 model:: M
240242 getlogdensity:: F
241243 iden_varname_ranges:: N
0 commit comments