Skip to content

Commit ec02e50

Browse files
committed
Fix literally everything else that I broke
1 parent c8cee86 commit ec02e50

34 files changed

+346
-523
lines changed

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8181
end
8282

8383
adbackend = to_backend(adbackend)
84-
context = DynamicPPL.DefaultContext()
8584

8685
if islinked
8786
vi = DynamicPPL.link(vi, model)
8887
end
8988

90-
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
89+
f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
9190
# The parameters at which we evaluate f.
9291
θ = vi[:]
9392

docs/src/api.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ getargnames
3636
getmissings
3737
```
3838

39+
The context of a model can be set using [`contextualize`](@ref):
40+
41+
```@docs
42+
contextualize
43+
```
44+
3945
## Evaluation
4046

4147
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
@@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves
438444

439445
### Evaluation Contexts
440446

441-
Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref).
447+
Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref).
442448

443449
```@docs
444450
AbstractPPL.evaluate!!
445451
```
446452

447-
The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function.
453+
This method mutates the `varinfo` used for execution.
454+
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
455+
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
456+
457+
```@docs
458+
DynamicPPL.sample!!
459+
```
460+
461+
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
448462
Contexts are subtypes of `AbstractPPL.AbstractContext`.
449463

450464
```@docs

ext/DynamicPPLJETExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL
44
using JET: JET
55

66
function DynamicPPL.Experimental.is_suitable_varinfo(
7-
model::DynamicPPL.Model,
8-
context::DynamicPPL.AbstractContext,
9-
varinfo::DynamicPPL.AbstractVarInfo;
10-
only_ddpl::Bool=true,
7+
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
118
)
129
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13-
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14-
model, varinfo, context
15-
)
10+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1611
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1712
# This way we don't just fall back to untyped if the user's code is the issue.
1813
result = if only_ddpl
@@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
2419
end
2520

2621
function DynamicPPL.Experimental._determine_varinfo_jet(
27-
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
22+
model::DynamicPPL.Model; only_ddpl::Bool=true
2823
)
24+
# Use SamplingContext to test type stability.
25+
sampling_model = DynamicPPL.contextualize(
26+
model, DynamicPPL.SamplingContext(model.context)
27+
)
28+
2929
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(model, context)
30+
varinfo = DynamicPPL.typed_varinfo(sampling_model)
3131

3232
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3333
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
model, context, varinfo; only_ddpl
34+
sampling_model, varinfo; only_ddpl
3535
)
3636

3737
if !issuccess
@@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4646
else
4747
# Warn the user that we can't use the type stable one.
4848
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(model, context)
49+
DynamicPPL.untyped_varinfo(sampling_model)
5050
end
5151
end
5252

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function DynamicPPL.predict(
115115
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116116
predictive_samples = map(iters) do (sample_idx, chain_idx)
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
model(rng, varinfo, DynamicPPL.SampleFromPrior())
118+
varinfo = last(DynamicPPL.sample!!(rng, model, varinfo))
119119

120120
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export AbstractVarInfo,
102102
# LogDensityFunction
103103
LogDensityFunction,
104104
# Contexts
105+
contextualize,
105106
SamplingContext,
106107
DefaultContext,
107108
PrefixContext,

src/compiler.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
1+
const INTERNALNAMES = (:__model__, :__varinfo__)
22

33
"""
44
need_concretize(expr)
@@ -63,9 +63,9 @@ used in its place.
6363
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
6464
return quote
6565
if $(DynamicPPL.contextual_isassumption)(
66-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
66+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
6767
)
68-
# Considered an assumption by `__context__` which means either:
68+
# Considered an assumption by `__model__.context` which means either:
6969
# 1. We hit the default implementation, e.g. using `DefaultContext`,
7070
# which in turn means that we haven't considered if it's one of
7171
# the model arguments, hence we need to check this.
@@ -116,7 +116,7 @@ end
116116
isfixed(expr, vn) = false
117117
function isfixed(::Union{Symbol,Expr}, vn)
118118
return :($(DynamicPPL.contextual_isfixed)(
119-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
119+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
120120
))
121121
end
122122

@@ -417,7 +417,7 @@ function generate_assign(left, right)
417417
return quote
418418
$right_val = $right
419419
if $(DynamicPPL.is_extracting_values)(__varinfo__)
420-
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
420+
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
421421
__varinfo__ = $(map_accumulator!!)(
422422
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
423423
)
@@ -431,7 +431,11 @@ function generate_tilde_literal(left, right)
431431
@gensym value
432432
return quote
433433
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
434-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
434+
__model__.context,
435+
$(DynamicPPL.check_tilde_rhs)($right),
436+
$left,
437+
nothing,
438+
__varinfo__,
435439
)
436440
$value
437441
end
@@ -456,20 +460,20 @@ function generate_tilde(left, right)
456460
$isassumption = $(DynamicPPL.isassumption(left, vn))
457461
if $(DynamicPPL.isfixed(left, vn))
458462
$left = $(DynamicPPL.getfixed_nested)(
459-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
463+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
460464
)
461465
elseif $isassumption
462466
$(generate_tilde_assume(left, dist, vn))
463467
else
464468
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465469
if !$(DynamicPPL.inargnames)($vn, __model__)
466470
$left = $(DynamicPPL.getconditioned_nested)(
467-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
471+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468472
)
469473
end
470474

471475
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
472-
__context__,
476+
__model__.context,
473477
$(DynamicPPL.check_tilde_rhs)($dist),
474478
$(maybe_view(left)),
475479
$vn,
@@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn)
494498

495499
return quote
496500
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
497-
__context__,
501+
__model__.context,
498502
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
499503
__varinfo__,
500504
)
@@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode)
652656

653657
# Add the internal arguments to the user-specified arguments (positional + keywords).
654658
evaluatordef[:args] = vcat(
655-
[
656-
:(__model__::$(DynamicPPL.Model)),
657-
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
658-
:(__context__::$(DynamicPPL.AbstractContext)),
659-
],
659+
[:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))],
660660
args,
661661
)
662662

src/debug_utils.jl

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,7 @@ A context used for checking validity of a model.
131131
# Fields
132132
$(FIELDS)
133133
"""
134-
struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext
135-
"model that is being run"
136-
model::M
134+
struct DebugContext{C<:AbstractContext} <: AbstractContext
137135
"context used for running the model"
138136
context::C
139137
"mapping from varnames to the number of times they have been seen"
@@ -149,7 +147,6 @@ struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext
149147
end
150148

151149
function DebugContext(
152-
model::Model,
153150
context::AbstractContext=DefaultContext();
154151
varnames_seen=OrderedDict{VarName,Int}(),
155152
statements=Vector{Stmt}(),
@@ -158,7 +155,6 @@ function DebugContext(
158155
record_varinfo=false,
159156
)
160157
return DebugContext(
161-
model,
162158
context,
163159
varnames_seen,
164160
statements,
@@ -344,7 +340,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int})
344340
end
345341

346342
# A check we run on the model before evaluating it.
347-
function check_model_pre_evaluation(context::DebugContext, model::Model)
343+
function check_model_pre_evaluation(model::Model)
348344
issuccess = true
349345
# If something is in the model arguments, then it should NOT be in `condition`,
350346
# nor should there be any symbol present in `condition` that has the same symbol.
@@ -361,8 +357,8 @@ function check_model_pre_evaluation(context::DebugContext, model::Model)
361357
return issuccess
362358
end
363359

364-
function check_model_post_evaluation(context::DebugContext, model::Model)
365-
return check_varnames_seen(context.varnames_seen)
360+
function check_model_post_evaluation(model::Model)
361+
return check_varnames_seen(model.context.varnames_seen)
366362
end
367363

368364
"""
@@ -438,25 +434,23 @@ function check_model_and_trace(
438434
rng::Random.AbstractRNG,
439435
model::Model;
440436
varinfo=VarInfo(),
441-
context=SamplingContext(rng),
442437
error_on_failure=false,
443438
kwargs...,
444439
)
445440
# Execute the model with the debug context.
446441
debug_context = DebugContext(
447-
model, context; error_on_failure=error_on_failure, kwargs...
442+
SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs...
448443
)
444+
debug_model = DynamicPPL.contextualize(model, debug_context)
449445

450446
# Perform checks before evaluating the model.
451-
issuccess = check_model_pre_evaluation(debug_context, model)
447+
issuccess = check_model_pre_evaluation(debug_model)
452448

453449
# Force single-threaded execution.
454-
retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!(
455-
model, varinfo, debug_context
456-
)
450+
DynamicPPL.evaluate_threadunsafe!!(debug_model, varinfo)
457451

458452
# Perform checks after evaluating the model.
459-
issuccess &= check_model_post_evaluation(debug_context, model)
453+
issuccess &= check_model_post_evaluation(debug_model)
460454

461455
if !issuccess && error_on_failure
462456
error("model check failed")
@@ -535,14 +529,13 @@ function has_static_constraints(
535529
end
536530

537531
"""
538-
gen_evaluator_call_with_types(model[, varinfo, context])
532+
gen_evaluator_call_with_types(model[, varinfo])
539533
540534
Generate the evaluator call and the types of the arguments.
541535
542536
# Arguments
543537
- `model::Model`: The model whose evaluator is of interest.
544538
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
545-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
546539
547540
# Returns
548541
A 2-tuple with the following elements:
@@ -551,11 +544,9 @@ A 2-tuple with the following elements:
551544
- `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator.
552545
"""
553546
function gen_evaluator_call_with_types(
554-
model::Model,
555-
varinfo::AbstractVarInfo=VarInfo(model),
556-
context::AbstractContext=DefaultContext(),
547+
model::Model, varinfo::AbstractVarInfo=VarInfo(model)
557548
)
558-
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
549+
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo)
559550
return if isempty(kwargs)
560551
(model.f, Base.typesof(args...))
561552
else
@@ -564,7 +555,7 @@ function gen_evaluator_call_with_types(
564555
end
565556

566557
"""
567-
model_warntype(model[, varinfo, context]; optimize=true)
558+
model_warntype(model[, varinfo]; optimize=true)
568559
569560
Check the type stability of the model's evaluator, warning about any potential issues.
570561
@@ -573,23 +564,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal
573564
# Arguments
574565
- `model::Model`: The model to check.
575566
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
576-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
577567
578568
# Keyword Arguments
579569
- `optimize::Bool`: Whether to generate optimized code. Default: `false`.
580570
"""
581571
function model_warntype(
582-
model::Model,
583-
varinfo::AbstractVarInfo=VarInfo(model),
584-
context::AbstractContext=DefaultContext();
585-
optimize::Bool=false,
572+
model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=false
586573
)
587-
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
574+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo)
588575
return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize)
589576
end
590577

591578
"""
592-
model_typed(model[, varinfo, context]; optimize=true)
579+
model_typed(model[, varinfo]; optimize=true)
593580
594581
Return the type inference for the model's evaluator.
595582
@@ -598,18 +585,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar
598585
# Arguments
599586
- `model::Model`: The model to check.
600587
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
601-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
602588
603589
# Keyword Arguments
604590
- `optimize::Bool`: Whether to generate optimized code. Default: `true`.
605591
"""
606592
function model_typed(
607-
model::Model,
608-
varinfo::AbstractVarInfo=VarInfo(model),
609-
context::AbstractContext=DefaultContext();
610-
optimize::Bool=true,
593+
model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=true
611594
)
612-
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
595+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo)
613596
return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize))
614597
end
615598

0 commit comments

Comments
 (0)