1
+ using AbstractMCMC
1
2
using LogDensityProblems, Distributions, LinearAlgebra, Random
2
3
using OrderedCollections
3
4
5
+ # #
6
+
7
+ # TODO : introduce some kind of parameter format, for instance, a flattened vector
8
+ # then define some kind of function to transform the flattened vector into model's representation
9
+
4
10
struct Gibbs <: AbstractMCMC.AbstractSampler
5
11
sampler_map:: OrderedDict
6
12
end
7
13
8
14
struct GibbsState
9
- values :: NamedTuple
15
+ vi :: NamedTuple
10
16
states:: OrderedDict
11
17
end
12
18
@@ -15,31 +21,91 @@ struct GibbsTransition
15
21
end
16
22
17
23
function AbstractMCMC. step (
18
- rng:: AbstractRNG , model, sampler:: Gibbs , args... ; initial_params:: NamedTuple , kwargs...
24
+ rng:: AbstractRNG ,
25
+ logdensity_model:: AbstractMCMC.LogDensityModel ,
26
+ spl:: Gibbs ,
27
+ args... ;
28
+ initial_params:: NamedTuple ,
29
+ kwargs... ,
19
30
)
20
31
states = OrderedDict ()
21
- for group in keys (sampler. sampler_map)
22
- sampler = sampler. sampler_map[group]
23
- cond_val = NamedTuple {group} ([initial_params[g] for g in group]. .. )
24
- trans, state = AbstractMCMC. step (
25
- rng, condition (model, cond_val), sampler, args... ; kwargs...
32
+ for group in keys (spl. sampler_map)
33
+ sub_spl = spl. sampler_map[group]
34
+
35
+ vars_to_be_conditioned_on = setdiff (keys (initial_params), group)
36
+ cond_val = NamedTuple {Tuple(vars_to_be_conditioned_on)} (
37
+ Tuple ([initial_params[g] for g in vars_to_be_conditioned_on])
38
+ )
39
+ params_val = NamedTuple {Tuple(group)} (Tuple ([initial_params[g] for g in group]))
40
+ sub_state = last (
41
+ AbstractMCMC. step (
42
+ rng,
43
+ AbstractMCMC. LogDensityModel (
44
+ condition (logdensity_model. logdensity, cond_val)
45
+ ),
46
+ sub_spl,
47
+ args... ;
48
+ initial_params= flatten (params_val),
49
+ kwargs... ,
50
+ ),
26
51
)
27
- states[group] = state
52
+ states[group] = sub_state
28
53
end
29
54
return GibbsTransition (initial_params), GibbsState (initial_params, states)
30
55
end
31
56
32
57
function AbstractMCMC. step (
33
- rng:: AbstractRNG , model, sampler:: Gibbs , state:: GibbsState , args... ; kwargs...
58
+ rng:: AbstractRNG ,
59
+ logdensity_model:: AbstractMCMC.LogDensityModel ,
60
+ spl:: Gibbs ,
61
+ state:: GibbsState ,
62
+ args... ;
63
+ kwargs... ,
34
64
)
35
- for group in collect (keys (sampler. sampler_map))
36
- sampler = sampler. sampler_map[group]
37
- state = state. states[group]
38
- trans, state = AbstractMCMC. step (
39
- rng, condition (model, state. values[group]), sampler, state, args... ; kwargs...
65
+ vi = state. vi
66
+ for group in keys (spl. sampler_map)
67
+ for (group, sub_state) in state. states
68
+ vi = merge (vi, unflatten (getparams (sub_state), group))
69
+ end
70
+ sub_spl = spl. sampler_map[group]
71
+ sub_state = state. states[group]
72
+ group_complement = setdiff (keys (vi), group)
73
+ cond_val = NamedTuple {Tuple(group_complement)} (
74
+ Tuple ([vi[g] for g in group_complement])
75
+ )
76
+ sub_state = last (
77
+ AbstractMCMC. step (
78
+ rng,
79
+ AbstractMCMC. LogDensityModel (
80
+ condition (logdensity_model. logdensity, cond_val)
81
+ ),
82
+ sub_spl,
83
+ sub_state,
84
+ args... ;
85
+ kwargs... ,
86
+ ),
40
87
)
41
- # TODO : what values to condition on here? stored where?
42
- state. states[group] = state
88
+ state. states[group] = sub_state
43
89
end
44
- return nothing
90
+ for sub_state in values (state. states)
91
+ vi = merge (vi, getparams (sub_state))
92
+ end
93
+ return GibbsTransition (vi), GibbsState (vi, state. states)
45
94
end
95
+
96
+ # # tests
97
+
98
+ gmm = GMM ((; x= x))
99
+
100
+ samples = sample (
101
+ gmm,
102
+ Gibbs (
103
+ OrderedDict (
104
+ (:z ,) => PriorMH (product_distribution ([Categorical ([0.3 , 0.7 ]) for _ in 1 : 60 ])),
105
+ (:w ,) => PriorMH (Dirichlet (2 , 1.0 )),
106
+ (:μ , :w ) => RWMH (1 ),
107
+ ),
108
+ ),
109
+ 10000 ;
110
+ initial_params= (z= rand (Categorical ([0.3 , 0.7 ]), 60 ), μ= [0.0 , 1.0 ], w= [0.3 , 0.7 ]),
111
+ )
0 commit comments