Skip to content

Commit e60873a

Browse files
committed
Fix _evaluate!! correctly to handle submodels
1 parent b1a7650 commit e60873a

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/fastldf.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,23 @@ struct FastLDF{
133133
end
134134
end
135135

136+
function _evaluate!!(
137+
model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo
138+
) where {F,A,D,M,TA,TD}
139+
args = map(maybe_deepcopy, model.args)
140+
return model.f(model, varinfo, args...; model.defaults...)
141+
end
142+
maybe_deepcopy(@nospecialize(x)) = x
143+
function maybe_deepcopy(x::AbstractArray{T}) where {T}
144+
if T >: Missing
145+
# avoid overwriting missing elements of model arguments when
146+
# evaluating the model.
147+
deepcopy(x)
148+
else
149+
x
150+
end
151+
end
152+
136153
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
137154
_model::M
138155
_getlogdensity::F
@@ -147,21 +164,10 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
147164
accs = AccumulatorTuple((
148165
LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator()
149166
))
150-
# _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs))
151-
args = map(maybe_deepcopy, model.args)
152-
_, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...)
167+
_, vi = _evaluate!!(model, OnlyAccsVarInfo(accs))
153168
return f._getlogdensity(vi)
154169
end
155170

156-
maybe_deepcopy(@nospecialize(x)) = x
157-
function maybe_deepcopy(x::AbstractArray{T}) where {T}
158-
if T >: Missing
159-
deepcopy(x)
160-
else
161-
x
162-
end
163-
end
164-
165171
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
166172
return FastLogDensityAt(
167173
fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges

0 commit comments

Comments
 (0)