Skip to content

Commit 15ac97b

Browse files
committed
WIP other changes
1 parent c8db5a4 commit 15ac97b

20 files changed

+190
-348
lines changed

ext/DynamicPPLJETExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL
44
using JET: JET
55

66
function DynamicPPL.Experimental.is_suitable_varinfo(
7-
model::DynamicPPL.Model,
8-
context::DynamicPPL.AbstractContext,
9-
varinfo::DynamicPPL.AbstractVarInfo;
10-
only_ddpl::Bool=true,
7+
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
118
)
129
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13-
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14-
model, varinfo, context
15-
)
10+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1611
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1712
# This way we don't just fall back to untyped if the user's code is the issue.
1813
result = if only_ddpl
@@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
2419
end
2520

2621
function DynamicPPL.Experimental._determine_varinfo_jet(
27-
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
22+
model::DynamicPPL.Model; only_ddpl::Bool=true
2823
)
24+
# Use SamplingContext to test type stability.
25+
sampling_model = DynamicPPL.contextualize(
26+
model, DynamicPPL.SamplingContext(model.context)
27+
)
28+
2929
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(model, context)
30+
varinfo = DynamicPPL.typed_varinfo(sampling_model)
3131

3232
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3333
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
model, context, varinfo; only_ddpl
34+
model, varinfo; only_ddpl
3535
)
3636

3737
if !issuccess
@@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4646
else
4747
# Warn the user that we can't use the type stable one.
4848
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(model, context)
49+
DynamicPPL.untyped_varinfo(model)
5050
end
5151
end
5252

src/compiler.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
1+
const INTERNALNAMES = (:__model__, :__varinfo__)
22

33
"""
44
need_concretize(expr)
@@ -63,9 +63,9 @@ used in its place.
6363
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
6464
return quote
6565
if $(DynamicPPL.contextual_isassumption)(
66-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
66+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
6767
)
68-
# Considered an assumption by `__context__` which means either:
68+
# Considered an assumption by `__model__.context` which means either:
6969
# 1. We hit the default implementation, e.g. using `DefaultContext`,
7070
# which in turn means that we haven't considered if it's one of
7171
# the model arguments, hence we need to check this.
@@ -116,7 +116,7 @@ end
116116
isfixed(expr, vn) = false
117117
function isfixed(::Union{Symbol,Expr}, vn)
118118
return :($(DynamicPPL.contextual_isfixed)(
119-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
119+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
120120
))
121121
end
122122

@@ -417,7 +417,7 @@ function generate_assign(left, right)
417417
return quote
418418
$right_val = $right
419419
if $(DynamicPPL.is_extracting_values)(__varinfo__)
420-
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
420+
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
421421
__varinfo__ = $(map_accumulator!!)(
422422
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
423423
)
@@ -431,7 +431,11 @@ function generate_tilde_literal(left, right)
431431
@gensym value
432432
return quote
433433
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
434-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
434+
__model__.context,
435+
$(DynamicPPL.check_tilde_rhs)($right),
436+
$left,
437+
nothing,
438+
__varinfo__,
435439
)
436440
$value
437441
end
@@ -456,20 +460,20 @@ function generate_tilde(left, right)
456460
$isassumption = $(DynamicPPL.isassumption(left, vn))
457461
if $(DynamicPPL.isfixed(left, vn))
458462
$left = $(DynamicPPL.getfixed_nested)(
459-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
463+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
460464
)
461465
elseif $isassumption
462466
$(generate_tilde_assume(left, dist, vn))
463467
else
464468
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465469
if !$(DynamicPPL.inargnames)($vn, __model__)
466470
$left = $(DynamicPPL.getconditioned_nested)(
467-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
471+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468472
)
469473
end
470474

471475
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
472-
__context__,
476+
__model__.context,
473477
$(DynamicPPL.check_tilde_rhs)($dist),
474478
$(maybe_view(left)),
475479
$vn,
@@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn)
494498

495499
return quote
496500
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
497-
__context__,
501+
__model__.context,
498502
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
499503
__varinfo__,
500504
)
@@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode)
652656

653657
# Add the internal arguments to the user-specified arguments (positional + keywords).
654658
evaluatordef[:args] = vcat(
655-
[
656-
:(__model__::$(DynamicPPL.Model)),
657-
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
658-
:(__context__::$(DynamicPPL.AbstractContext)),
659-
],
659+
[:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))],
660660
args,
661661
)
662662

src/debug_utils.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,13 @@ function has_static_constraints(
535535
end
536536

537537
"""
538-
gen_evaluator_call_with_types(model[, varinfo, context])
538+
gen_evaluator_call_with_types(model[, varinfo])
539539
540540
Generate the evaluator call and the types of the arguments.
541541
542542
# Arguments
543543
- `model::Model`: The model whose evaluator is of interest.
544544
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
545-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
546545
547546
# Returns
548547
A 2-tuple with the following elements:
@@ -551,11 +550,9 @@ A 2-tuple with the following elements:
551550
- `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator.
552551
"""
553552
function gen_evaluator_call_with_types(
554-
model::Model,
555-
varinfo::AbstractVarInfo=VarInfo(model),
556-
context::AbstractContext=DefaultContext(),
553+
model::Model, varinfo::AbstractVarInfo=VarInfo(model)
557554
)
558-
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
555+
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo)
559556
return if isempty(kwargs)
560557
(model.f, Base.typesof(args...))
561558
else
@@ -564,7 +561,7 @@ function gen_evaluator_call_with_types(
564561
end
565562

566563
"""
567-
model_warntype(model[, varinfo, context]; optimize=true)
564+
model_warntype(model[, varinfo]; optimize=true)
568565
569566
Check the type stability of the model's evaluator, warning about any potential issues.
570567
@@ -573,23 +570,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal
573570
# Arguments
574571
- `model::Model`: The model to check.
575572
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
576-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
577573
578574
# Keyword Arguments
579575
- `optimize::Bool`: Whether to generate optimized code. Default: `false`.
580576
"""
581577
function model_warntype(
582-
model::Model,
583-
varinfo::AbstractVarInfo=VarInfo(model),
584-
context::AbstractContext=DefaultContext();
585-
optimize::Bool=false,
578+
model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=false
586579
)
587-
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
580+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo)
588581
return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize)
589582
end
590583

591584
"""
592-
model_typed(model[, varinfo, context]; optimize=true)
585+
model_typed(model[, varinfo]; optimize=true)
593586
594587
Return the type inference for the model's evaluator.
595588
@@ -598,18 +591,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar
598591
# Arguments
599592
- `model::Model`: The model to check.
600593
- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`.
601-
- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref).
602594
603595
# Keyword Arguments
604596
- `optimize::Bool`: Whether to generate optimized code. Default: `true`.
605597
"""
606598
function model_typed(
607-
model::Model,
608-
varinfo::AbstractVarInfo=VarInfo(model),
609-
context::AbstractContext=DefaultContext();
610-
optimize::Bool=true,
599+
model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=true
611600
)
612-
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context)
601+
ftype, argtypes = gen_evaluator_call_with_types(model, varinfo)
613602
return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize))
614603
end
615604

src/experimental.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ using DynamicPPL: DynamicPPL
44

55
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
66
"""
7-
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
7+
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
88
9-
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
9+
Check if the `model` supports evaluation using the provided `varinfo`.
1010
1111
!!! warning
1212
Loading JET.jl is required before calling this function.
1313
1414
# Arguments
1515
- `model`: The model to verify the support for.
16-
- `context`: The context to use for the model evaluation.
1716
- `varinfo`: The varinfo to verify the support for.
1817
1918
# Keyword Arguments
@@ -29,7 +28,7 @@ function is_suitable_varinfo end
2928
function _determine_varinfo_jet end
3029

3130
"""
32-
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
31+
determine_suitable_varinfo(model; only_ddpl::Bool=true)
3332
3433
Return a suitable varinfo for the given `model`.
3534
@@ -41,7 +40,6 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
4140
4241
# Arguments
4342
- `model`: The model for which to determine the varinfo.
44-
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
4543
4644
# Keyword Arguments
4745
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
@@ -85,14 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
8583
true
8684
```
8785
"""
88-
function determine_suitable_varinfo(
89-
model::DynamicPPL.Model,
90-
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
91-
only_ddpl::Bool=true,
92-
)
86+
function determine_suitable_varinfo(model::DynamicPPL.Model, only_ddpl::Bool=true)
9387
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
9488
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
95-
_determine_varinfo_jet(model, context; only_ddpl)
89+
_determine_varinfo_jet(model; only_ddpl)
9690
else
9791
# Warn the user.
9892
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."

src/extract_priors.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ function extract_priors(rng::Random.AbstractRNG, model::Model)
116116
# workaround for the fact that `order` is still hardcoded in VarInfo, and hence you
117117
# can't push new variables without knowing the num_produce. Remove this when possible.
118118
varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator()))
119-
varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng)))
119+
new_model = contextualize(model, SamplingContext(rng, model.context))
120+
varinfo = last(evaluate!!(new_model, varinfo))
120121
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
121122
end
122123

@@ -135,6 +136,6 @@ function extract_priors(model::Model, varinfo::AbstractVarInfo)
135136
varinfo = setaccs!!(
136137
deepcopy(varinfo), (PriorDistributionAccumulator(), NumProduceAccumulator())
137138
)
138-
varinfo = last(evaluate!!(model, varinfo, DefaultContext()))
139+
varinfo = last(evaluate!!(model, varinfo))
139140
return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors
140141
end

0 commit comments

Comments
 (0)