Skip to content

Commit e3ded5c

Browse files
committed
fixes
1 parent 0100d6d commit e3ded5c

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export AbstractVarInfo,
102102
# LogDensityFunction
103103
LogDensityFunction,
104104
# Contexts
105+
contextualize,
105106
SamplingContext,
106107
DefaultContext,
107108
PrefixContext,

src/experimental.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
8383
true
8484
```
8585
"""
86-
function determine_suitable_varinfo(model::DynamicPPL.Model, only_ddpl::Bool=true)
86+
function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true)
8787
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
8888
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
8989
_determine_varinfo_jet(model; only_ddpl)

src/model.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k
8585
return Model(f, args, NamedTuple(kwargs), context)
8686
end
8787

88+
"""
89+
contextualize(model::Model, context::AbstractContext)
90+
91+
Return a new `Model` with the same evaluation function and other arguments, but
92+
with its underlying context set to `context`.
93+
"""
8894
function contextualize(model::Model, context::AbstractContext)
8995
return Model(model.f, model.args, model.defaults, context)
9096
end

test/ad.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,8 @@ using DynamicPPL: LogDensityFunction
110110
# Compiling the ReverseDiff tape used to fail here
111111
spl = Sampler(MyEmptyAlg())
112112
vi = VarInfo(model)
113-
ldf = LogDensityFunction(
114-
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
115-
)
113+
sampling_model = contextualize(model, SamplingConext(model.context))
114+
ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true))
116115
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
117116
end
118117

0 commit comments

Comments
 (0)