@@ -44,42 +44,16 @@ function condition(gmm::GMM, conditioned_values::NamedTuple)
44
44
return ConditionedGMM (gmm. data, conditioned_values)
45
45
end
46
46
47
- function _logdensity (gmm:: Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} , params)
48
- return log_joint (;
49
- μ= gmm. conditioned_values. μ, w= gmm. conditioned_values. w, z= params. z, x= gmm. data. x
50
- )
51
- end
52
-
53
- function _logdensity (gmm:: ConditionedGMM{(:z,)} , params)
54
- return log_joint (; μ= params. μ, w= params. w, z= gmm. conditioned_values. z, x= gmm. data. x)
55
- end
56
-
57
- function LogDensityProblems. logdensity (
58
- gmm:: Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} ,
59
- params_vec:: AbstractVector ,
60
- )
61
- @assert length (params_vec) == 60
62
- return _logdensity (gmm, (; z= params_vec))
63
- end
64
- function LogDensityProblems. logdensity (
65
- gmm:: ConditionedGMM{(:z,)} , params_vec:: AbstractVector
66
- )
67
- @assert length (params_vec) == 4 " length(params_vec) = $(length (params_vec)) "
68
- return _logdensity (gmm, (; μ= params_vec[1 : 2 ], w= params_vec[3 : 4 ]))
69
- end
70
-
71
- function LogDensityProblems. dimension (gmm:: GMM )
72
- return 4 + size (gmm. data. x, 1 )
73
- end
74
-
75
- function LogDensityProblems. dimension (
76
- gmm:: Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}
77
- )
78
- return 4
79
- end
80
-
81
- function LogDensityProblems. dimension (gmm:: ConditionedGMM{(:z,)} )
82
- return size (gmm. data. x, 1 )
47
+ function LogDensityProblems. logdensity (gmm:: ConditionedGMM{names} , params:: AbstractVector ) where {names}
48
+ if Set (names) == Set ([:μ , :w ]) # conditioned on μ, w, so params are z
49
+ return log_joint (; μ= gmm. conditioned_values. μ, w= gmm. conditioned_values. w, z= params, x= gmm. data. x)
50
+ elseif Set (names) == Set ([:z , :w ]) # conditioned on z, w, so params are μ
51
+ return log_joint (; μ= params, w= gmm. conditioned_values. w, z= gmm. conditioned_values. z, x= gmm. data. x)
52
+ elseif Set (names) == Set ([:z , :μ ]) # conditioned on z, μ, so params are w
53
+ return log_joint (; μ= gmm. conditioned_values. μ, w= params, z= gmm. conditioned_values. z, x= gmm. data. x)
54
+ else
55
+ error (" Unsupported conditioning configuration." )
56
+ end
83
57
end
84
58
85
59
function LogDensityProblems. capabilities (:: GMM )
@@ -91,41 +65,22 @@ function LogDensityProblems.capabilities(::ConditionedGMM)
91
65
end
92
66
93
67
function flatten (nt:: NamedTuple )
94
- if Set (keys (nt)) == Set ([:μ , :w ])
95
- return vcat (nt. μ, nt. w)
96
- elseif Set (keys (nt)) == Set ([:z ])
97
- return nt. z
98
- else
99
- error ()
100
- end
68
+ return only (values (nt))
101
69
end
102
70
103
71
function unflatten (vec:: AbstractVector , group:: Tuple )
104
- if Set (group) == Set ([:μ , :w ])
105
- return (; μ= vec[1 : 2 ], w= vec[3 : 4 ])
106
- elseif Set (group) == Set ([:z ])
107
- return (; z= vec)
108
- else
109
- error ()
110
- end
72
+ return NamedTuple ((only (group) => vec,))
111
73
end
112
74
113
- # sampler's states to internal representation
114
- # ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?)
115
-
116
- # the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed
117
75
function recompute_logprob!! (gmm:: ConditionedGMM , vals, state)
118
- return setlogp! (state, _logdensity (gmm, vals))
76
+ return setlogp!! (state, LogDensityProblems . logdensity (gmm, vals))
119
77
end
120
78
121
79
# # test using Turing
122
80
123
81
# data generation
124
82
125
- using Distributions
126
83
using FillArrays
127
- using LinearAlgebra
128
- using Random
129
84
130
85
w = [0.5 , 0.5 ]
131
86
μ = [- 3.5 , 0.5 ]
0 commit comments