Skip to content

Commit c40a193

Browse files
committed
Add the option of no fallback for ParamsInit
1 parent be8a1b3 commit c40a193

File tree

2 files changed

+83
-36
lines changed

2 files changed

+83
-36
lines changed

src/contexts/init.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,25 +68,35 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform
6868
end
6969

7070
"""
71-
ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit())
72-
ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
71+
ParamsInit(
72+
params::Union{AbstractDict{<:VarName},NamedTuple},
73+
default::Union{AbstractInitStrategy,Nothing}=PriorInit()
74+
)
7375
7476
Obtain new values by extracting them from the given dictionary or NamedTuple.
77+
7578
The parameter `default` specifies how new values are to be obtained if they
76-
cannot be found in `params`, or they are specified as `missing`. The default
77-
for `default` is `PriorInit()`.
79+
cannot be found in `params`, or they are specified as `missing`. `default`
80+
can either be an initialisation strategy itself, in which case it will be
81+
used to obtain new values, or it can be `nothing`, in which case an error
82+
will be thrown. The default for `default` is `PriorInit()`.
7883
7984
!!! note
80-
These values must be provided in the space of the untransformed distribution.
85+
The values in `params` must be provided in the space of the untransformed
86+
distribution.
8187
"""
82-
struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy
88+
struct ParamsInit{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
8389
params::P
8490
default::S
85-
function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy)
91+
function ParamsInit(
92+
params::AbstractDict{<:VarName}, default::Union{AbstractInitStrategy,Nothing}
93+
)
8694
return new{typeof(params),typeof(default)}(params, default)
8795
end
8896
ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit())
89-
function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit())
97+
function ParamsInit(
98+
params::NamedTuple, default::Union{AbstractInitStrategy,Nothing}=PriorInit()
99+
)
90100
return ParamsInit(to_varname_dict(params), default)
91101
end
92102
end
@@ -98,13 +108,16 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::Param
98108
return if hasvalue(p.params, vn, dist)
99109
x = getvalue(p.params, vn, dist)
100110
if x === missing
111+
p.default === nothing &&
112+
error("A `missing` value was provided for the variable `$(vn)`.")
101113
init(rng, vn, dist, p.default)
102114
else
103115
# TODO(penelopeysm): Since x is user-supplied, maybe we could also
104116
# check here that the type / size of x matches the dist?
105117
x
106118
end
107119
else
120+
p.default === nothing && error("No value was provided for the variable `$(vn)`.")
108121
init(rng, vn, dist, p.default)
109122
end
110123
end

test/contexts.jl

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -435,12 +435,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
435435

436436
@testset "InitContext" begin
437437
empty_varinfos = [
438-
VarInfo(),
439-
DynamicPPL.typed_varinfo(VarInfo()),
440-
VarInfo(DynamicPPL.VarNamedVector()),
441-
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
442-
SimpleVarInfo(),
443-
SimpleVarInfo(Dict{VarName,Any}()),
438+
("untyped+metadata", VarInfo()),
439+
("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())),
440+
("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())),
441+
(
442+
"typed+VNV",
443+
DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())),
444+
),
445+
("SVI+NamedTuple", SimpleVarInfo()),
446+
("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())),
444447
]
445448

446449
@model function test_init_model()
@@ -455,7 +458,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
455458
# Check that init!! can generate values that weren't there
456459
# previously.
457460
model = test_init_model()
458-
for empty_vi in empty_varinfos
461+
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
459462
this_vi = deepcopy(empty_vi)
460463
_, vi = DynamicPPL.init!!(model, this_vi, strategy)
461464
@test Set(keys(vi)) == Set([@varname(x), @varname(y)])
@@ -475,7 +478,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
475478
@testset "replacing old values: $(typeof(strategy))" begin
476479
# Check that init!! can overwrite values that were already there.
477480
model = test_init_model()
478-
for empty_vi in empty_varinfos
481+
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
479482
# start by generating some rubbish values
480483
vi = deepcopy(empty_vi)
481484
old_x, old_y = 100000.00, [300000.00, 500000.00]
@@ -494,7 +497,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
494497
function test_rng_respected(strategy::AbstractInitStrategy)
495498
@testset "check that RNG is respected: $(typeof(strategy))" begin
496499
model = test_init_model()
497-
for empty_vi in empty_varinfos
500+
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
498501
_, vi1 = DynamicPPL.init!!(
499502
Xoshiro(468), model, deepcopy(empty_vi), strategy
500503
)
@@ -613,29 +616,60 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
613616
end
614617

615618
@testset "given only partial parameters" begin
616-
# In this case, we expect `ParamsInit` to use the value of x, and
617-
# generate a new value for y.
618619
my_x = 1.0
619620
params_nt = (; x=my_x)
620621
params_dict = Dict(@varname(x) => my_x)
621622
model = test_init_model()
622-
for empty_vi in empty_varinfos
623-
_, vi = DynamicPPL.init!!(
624-
Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt)
625-
)
626-
@test vi[@varname(x)] == my_x
627-
nt_y = vi[@varname(y)]
628-
@test nt_y isa AbstractVector{<:Real}
629-
@test length(nt_y) == 2
630-
_, vi = DynamicPPL.init!!(
631-
Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict)
632-
)
633-
@test vi[@varname(x)] == my_x
634-
dict_y = vi[@varname(y)]
635-
@test dict_y isa AbstractVector{<:Real}
636-
@test length(dict_y) == 2
637-
# the values should be different since we used different seeds
638-
@test dict_y != nt_y
623+
@testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos
624+
@testset "with PriorInit fallback" begin
625+
_, vi = DynamicPPL.init!!(
626+
Xoshiro(468),
627+
model,
628+
deepcopy(empty_vi),
629+
ParamsInit(params_nt, PriorInit()),
630+
)
631+
@test vi[@varname(x)] == my_x
632+
nt_y = vi[@varname(y)]
633+
@test nt_y isa AbstractVector{<:Real}
634+
@test length(nt_y) == 2
635+
_, vi = DynamicPPL.init!!(
636+
Xoshiro(469),
637+
model,
638+
deepcopy(empty_vi),
639+
ParamsInit(params_dict, PriorInit()),
640+
)
641+
@test vi[@varname(x)] == my_x
642+
dict_y = vi[@varname(y)]
643+
@test dict_y isa AbstractVector{<:Real}
644+
@test length(dict_y) == 2
645+
# the values should be different since we used different seeds
646+
@test dict_y != nt_y
647+
end
648+
649+
@testset "with no fallback" begin
650+
# These just don't have an entry for `y`.
651+
@test_throws ErrorException DynamicPPL.init!!(
652+
model, deepcopy(empty_vi), ParamsInit(params_nt, nothing)
653+
)
654+
@test_throws ErrorException DynamicPPL.init!!(
655+
model, deepcopy(empty_vi), ParamsInit(params_dict, nothing)
656+
)
657+
# We also explicitly test the case where `y = missing`.
658+
params_nt_missing = (; x=my_x, y=missing)
659+
params_dict_missing = Dict(
660+
@varname(x) => my_x, @varname(y) => missing
661+
)
662+
@test_throws ErrorException DynamicPPL.init!!(
663+
model,
664+
deepcopy(empty_vi),
665+
ParamsInit(params_nt_missing, nothing),
666+
)
667+
@test_throws ErrorException DynamicPPL.init!!(
668+
model,
669+
deepcopy(empty_vi),
670+
ParamsInit(params_dict_missing, nothing),
671+
)
672+
end
639673
end
640674
end
641675
end

0 commit comments

Comments
 (0)