1
+ # TODO : Make `UniformSampling` and `Prior` algs + just use `Sampler`
2
+ # That would let us use all defaults for Sampler, combine it with other samplers etc.
1
3
"""
2
4
Robust initialization method for model parameters in Hamiltonian samplers.
3
5
"""
@@ -17,55 +19,123 @@ function init(rng, dist, ::SampleFromUniform, n::Int)
17
19
return istransformable (dist) ? inittrans (rng, dist, n) : rand (rng, dist, n)
18
20
end
19
21
20
- """
21
- has_eval_num(spl::AbstractSampler)
22
-
23
- Check whether `spl` has a field called `eval_num` in its state variables or not.
24
- """
25
- has_eval_num (spl:: SampleFromUniform ) = false
26
- has_eval_num (spl:: SampleFromPrior ) = false
27
- has_eval_num (spl:: AbstractSampler ) = :eval_num in fieldnames (typeof (spl. state))
28
-
29
- """
30
- An abstract type that mutable sampler state structs inherit from.
31
- """
32
- abstract type AbstractSamplerState end
33
-
34
22
"""
35
23
Sampler{T}
36
24
37
- Generic interface for implementing inference algorithms.
38
- An implementation of an algorithm should include the following:
39
-
40
- 1. A type specifying the algorithm and its parameters, derived from InferenceAlgorithm
41
- 2. A method of `sample` function that produces results of inference, which is where actual inference happens.
25
+ Generic sampler type for inference algorithms of type `T` in DynamicPPL.
42
26
43
- DynamicPPL translates models to chunks that call the modelling functions at specified points.
44
- The dispatch is based on the value of a `sampler` variable.
45
- To include a new inference algorithm implements the requirements mentioned above in a separate file,
46
- then include that file at the end of this one.
27
+ `Sampler` should implement the AbstractMCMC interface, and in particular
28
+ [`AbstractMCMC.step`](@ref). A default implementation of the initial sampling step is
29
+ provided that supports resuming sampling from a previous state and setting initial
30
+ parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref)
31
+ for loading previous states and actually performing the initial sampling step,
32
+ respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref)
33
+ that specifies how the initial parameter values are sampled if they are not provided.
34
+ By default, values are sampled from the prior.
47
35
"""
48
- mutable struct Sampler{T, S<: AbstractSamplerState } <: AbstractSampler
49
- alg :: T
50
- info :: Dict{Symbol, Any} # sampler infomation
51
- selector :: Selector
52
- state :: S
36
+ struct Sampler{T} <: AbstractSampler
37
+ alg:: T
38
+ selector:: Selector # Can we remove it?
39
+ # TODO : add space such that we can integrate existing external samplers in DynamicPPL
53
40
end
54
41
Sampler (alg) = Sampler (alg, Selector ())
55
42
Sampler (alg, model:: Model ) = Sampler (alg, model, Selector ())
56
- Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, model, s)
43
+ Sampler (alg, model:: Model , s:: Selector ) = Sampler (alg, s)
57
44
58
45
# AbstractMCMC interface for SampleFromUniform and SampleFromPrior
59
-
60
- function AbstractMCMC. step! (
46
+ function AbstractMCMC. step (
61
47
rng:: Random.AbstractRNG ,
62
48
model:: Model ,
63
49
sampler:: Union{SampleFromUniform,SampleFromPrior} ,
64
- :: Integer ,
65
- transition;
50
+ state = nothing ;
66
51
kwargs...
67
52
)
68
53
vi = VarInfo ()
69
- model (vi, sampler)
70
- return vi
54
+ model (rng, vi, sampler)
55
+ return vi, nothing
56
+ end
57
+
58
+ # initial step: general interface for resuming and
59
+ function AbstractMCMC. step (
60
+ rng:: Random.AbstractRNG ,
61
+ model:: Model ,
62
+ spl:: Sampler ;
63
+ resume_from = nothing ,
64
+ kwargs...
65
+ )
66
+ if resume_from != = nothing
67
+ state = loadstate (resume_from)
68
+ return AbstractMCMC. step (rng, model, spl, state; kwargs... )
69
+ end
70
+
71
+ # Sample initial values.
72
+ _spl = initialsampler (spl)
73
+ vi = VarInfo (rng, model, _spl)
74
+
75
+ # Update the parameters if provided.
76
+ if haskey (kwargs, :init_params )
77
+ initialize_parameters! (vi, kwargs[:init_params ], spl)
78
+
79
+ # Update joint log probability.
80
+ model (rng, vi, _spl)
81
+ end
82
+
83
+ return initialstep (rng, model, spl, vi; kwargs... )
84
+ end
85
+
86
+ """
87
+ loadstate(data)
88
+
89
+ Load sampler state from `data`.
90
+ """
91
+ function loadstate end
92
+
93
+ """
94
+ initialsampler(sampler::Sampler)
95
+
96
+ Return the sampler that is used for generating the initial parameters when sampling with
97
+ `sampler`.
98
+
99
+ By default, it returns an instance of [`SampleFromPrior`](@ref).
100
+ """
101
+ initialsampler (spl:: Sampler ) = SampleFromPrior ()
102
+
103
+ function initialize_parameters! (vi:: AbstractVarInfo , init_params, spl:: Sampler )
104
+ @debug " Using passed-in initial variable values" init_params
105
+
106
+ # Flatten parameters.
107
+ init_theta = mapreduce (vcat, init_params) do x
108
+ vec ([x;])
109
+ end
110
+
111
+ # Get all values.
112
+ linked = islinked (vi, spl)
113
+ linked && invlink! (vi, spl)
114
+ theta = vi[spl]
115
+ length (theta) == length (init_theta_flat) ||
116
+ error (" Provided initial value doesn't match the dimension of the model" )
117
+
118
+ # Update values that are provided.
119
+ for i in 1 : length (init_theta)
120
+ x = init_theta[i]
121
+ if x != = missing
122
+ theta[i] = x
123
+ end
124
+ end
125
+
126
+ # Update in `vi`.
127
+ vi[spl] = theta
128
+ linked && link! (vi, spl)
129
+
130
+ return
71
131
end
132
+
133
+ """
134
+ initialstep(rng, model, sampler, varinfo; kwargs...)
135
+
136
+ Perform the initial sampling step of the `sampler` for the `model`.
137
+
138
+ The `varinfo` contains the initial samples, which can be provided by the user or
139
+ sampled randomly.
140
+ """
141
+ function initialstep end
0 commit comments