@@ -76,16 +76,19 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
 
         params = utils.vardict({
             'factors': {
-                'mu': torch.zeros(self.num_blocks, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'mu': torch.zeros(self.num_subjects, self.embedding_dim),
+                'sigma': torch.ones(self.num_subjects, self.embedding_dim) * \
+                         tfa_models.SOURCE_LOG_WIDTH_STD_DEV,
             },
             'subject': {
                 'mu': torch.zeros(self.num_subjects, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'sigma': torch.ones(self.num_subjects, self.embedding_dim) * \
+                         tfa_models.SOURCE_WEIGHT_STD_DEV,
             },
             'task': {
                 'mu': torch.zeros(self.num_tasks, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'sigma': torch.ones(self.num_tasks, self.embedding_dim) * \
+                         tfa_models.SOURCE_WEIGHT_STD_DEV,
             },
             'template': {
                 'factor_centers': {
@@ -95,7 +98,8 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
                 'factor_log_widths': {
                     'mu': hyper_means['factor_log_widths'] * \
                           torch.ones(self._num_factors),
-                    'sigma': torch.sqrt(torch.rand(self._num_factors))
+                    'sigma': torch.ones(self._num_factors) *
+                             tfa_models.SOURCE_LOG_WIDTH_STD_DEV,
                 }
             },
             'block': {
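
These two hunks replace the random torch.sqrt(torch.rand(...)) sigma initializations with fixed constants, and fix the shape bugs where the subject and task sigmas were sized by num_blocks instead of num_subjects and num_tasks. A minimal sketch of why a constant, unconstrained sigma is safe here, assuming the softplus mapping the guide's forward pass (below) now applies; SOURCE_WEIGHT_STD_DEV here is a hypothetical stand-in value, not the repo's constant:

    import torch
    import torch.nn.functional as F

    SOURCE_WEIGHT_STD_DEV = 2.0  # stand-in for tfa_models.SOURCE_WEIGHT_STD_DEV
    num_subjects, embedding_dim = 4, 2

    # Stored as a raw, unconstrained parameter...
    raw_sigma = torch.ones(num_subjects, embedding_dim) * SOURCE_WEIGHT_STD_DEV
    # ...and mapped through softplus before use as a Normal scale,
    # so it stays strictly positive however it is updated during training.
    sigma = F.softplus(raw_sigma)
    assert (sigma > 0).all()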
@@ -136,32 +140,28 @@ def __init__(self, num_factors, block_subjects, block_tasks, num_blocks=1,
                                                   embedding_dim)
         self.factors_embedding = nn.Sequential(
             nn.Linear(self._embedding_dim, self._num_factors),
-            nn.Tanhshrink(),
-            nn.Linear(self._num_factors, self._num_factors * 8),
+            nn.Softsign(),
         )
+        self.centers_embedding = nn.Linear(self._num_factors,
+                                           self._num_factors * 3)
+        self.log_widths_embedding = nn.Linear(self._num_factors,
+                                              self._num_factors)
         self.weights_embedding = nn.Sequential(
             nn.Linear(self._embedding_dim * 2, self._num_factors),
-            nn.Tanhshrink(),
+            nn.Softsign(),
             nn.Linear(self._num_factors, self._num_factors * 2),
         )
         self.softplus = nn.Softplus()
 
         self.epsilon = nn.Parameter(torch.Tensor([tfa_models.VOXEL_NOISE]))
 
         if hyper_means is not None:
-            self.weights_embedding[-1].bias = nn.Parameter(torch.cat(
-                (hyper_means['weights'].mean(0),
-                 torch.sqrt(torch.rand(self._num_factors))),
-                dim=0
-            ))
-            self.factors_embedding[-1].bias = nn.Parameter(torch.cat(
-                (hyper_means['factor_centers'],
-                 torch.ones(self._num_factors, 3),
-                 torch.ones(self._num_factors, 1) *
-                 hyper_means['factor_log_widths'],
-                 torch.sqrt(torch.rand(self._num_factors, 1))),
-                dim=1,
-            ).view(self._num_factors * 8))
+            self.centers_embedding.bias = nn.Parameter(
+                hyper_means['factor_centers'].view(self._num_factors * 3)
+            )
+            self.log_widths_embedding.bias = nn.Parameter(
+                torch.ones(self._num_factors) * hyper_means['factor_log_widths']
+            )
 
     def forward(self, trace, times=None, blocks=None,
                 num_particles=tfa_models.NUM_PARTICLES):
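
This hunk splits the single 8-column factors head into a shared trunk feeding two dedicated linear heads, which also lets the hyper_means biases be set per head instead of packed into one concatenated vector. For orientation, a self-contained sketch of that decomposition, with illustrative sizes and names rather than the repo's exact modules:

    import torch
    import torch.nn as nn

    embedding_dim, K = 2, 5                 # illustrative sizes
    trunk = nn.Sequential(nn.Linear(embedding_dim, K), nn.Softsign())
    centers_head = nn.Linear(K, K * 3)      # cf. centers_embedding
    log_widths_head = nn.Linear(K, K)       # cf. log_widths_embedding

    z = torch.randn(10, embedding_dim)      # a batch of factor embeddings
    h = trunk(z)
    centers = centers_head(h).view(-1, K, 3)    # (10, K, 3)
    log_widths = log_widths_head(h)             # (10, K)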
@@ -187,48 +187,53 @@ def forward(self, trace, times=None, blocks=None,
             if ('z^F_%d' % subject) not in trace:
                 factors_embed = trace.normal(
                     params['factors']['mu'][:, subject, :],
-                    params['factors']['sigma'][:, subject, :],
+                    self.softplus(params['factors']['sigma'][:, subject, :]),
                     name='z^F_%d' % subject
                 )
             if ('z^P_%d' % subject) not in trace:
                 subject_embed = trace.normal(
                     params['subject']['mu'][:, subject, :],
-                    params['subject']['sigma'][:, subject, :],
+                    self.softplus(params['subject']['sigma'][:, subject, :]),
                     name='z^P_%d' % subject
                 )
             if ('z^S_%d' % task) not in trace:
-                task_embed = trace.normal(params['task']['mu'][:, task],
-                                          params['task']['sigma'][:, task],
-                                          name='z^S_%d' % task)
+                task_embed = trace.normal(
+                    params['task']['mu'][:, task],
+                    self.softplus(params['task']['sigma'][:, task]),
+                    name='z^S_%d' % task
+                )
 
             factor_params = self.factors_embedding(factors_embed)
-            factor_params = factor_params.view(-1, self._num_factors, 8)
+            centers_predictions = self.centers_embedding(factor_params).view(
+                -1, self._num_factors, 3
+            )
+            log_widths_predictions = self.log_widths_embedding(factor_params).\
+                view(-1, self._num_factors)
             weights_embed = torch.cat((subject_embed, task_embed), dim=-1)
-            weight_params = self.weights_embedding(weights_embed).view(
+            weight_predictions = self.weights_embedding(weights_embed).view(
                 -1, self._num_factors, 2
             )
 
-            trace.normal(weight_params[:, :, 0], self.epsilon[0],
-                         name='mu^W_%d' % b)
-            trace.normal(self.softplus(weight_params[:, :, 1]), self.epsilon[0],
-                         name='sigma^W_%d' % b)
+            weights_mu = trace.normal(weight_predictions[:, :, 0],
+                                      self.epsilon[0], name='mu^W_%d' % b)
+            weights_sigma = trace.normal(weight_predictions[:, :, 1],
+                                         self.epsilon[0], name='sigma^W_%d' % b)
+            weights_params = params['block']['weights']
             weights[i] = trace.normal(
-                params['block']['weights']['mu'][:, b, ts[0]:ts[1], :],
-                params['block']['weights']['sigma'][:, b, ts[0]:ts[1], :],
+                weights_params['mu'][:, b, ts[0]:ts[1], :] +
+                weights_mu.unsqueeze(1),
+                self.softplus(weights_params['sigma'][:, b, ts[0]:ts[1], :] +
+                              weights_sigma.unsqueeze(1)),
                 name='Weights%dt%d-%d' % (b, ts[0], ts[1])
             )
             factor_centers[i] = trace.normal(
-                factor_params[:, :, 0:3],
-                self.softplus(factor_params[:, :, 3:6]),
+                centers_predictions,
+                self.epsilon[0],
                 name='FactorCenters%d' % b
             )
             factor_log_widths[i] = trace.normal(
-                factor_params[:, :, 6].contiguous().view(
-                    -1, self._num_factors
-                ),
-                self.softplus(factor_params[:, :, 7].contiguous().view(
-                    -1, self._num_factors
-                )), name='FactorLogWidths%d' % b
+                log_widths_predictions,
+                self.epsilon[0], name='FactorLogWidths%d' % b
             )
 
         return weights, factor_centers, factor_log_widths
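
The reworked weight sampling treats the per-block time-series parameters as residuals around an embedding-predicted per-factor mean, and passes the summed raw sigmas through softplus so the scale stays positive. A plain-PyTorch sketch of that parameterization, with assumed shapes and epsilon standing in for tfa_models.VOXEL_NOISE:

    import torch
    import torch.nn.functional as F

    P, T, K = 3, 7, 5                   # particles, time points, factors
    block_mu = torch.zeros(P, T, K)     # cf. weights_params['mu'][:, b, ts[0]:ts[1], :]
    block_sigma = torch.zeros(P, T, K)  # raw (unconstrained) sigma slice
    weights_mu = torch.randn(P, K)      # cf. the mu^W_%d sample
    weights_sigma = torch.randn(P, K)   # cf. the sigma^W_%d sample

    mean = block_mu + weights_mu.unsqueeze(1)                     # broadcast over T
    scale = F.softplus(block_sigma + weights_sigma.unsqueeze(1))  # strictly positive
    weights = torch.distributions.Normal(mean, scale).rsample()   # (P, T, K)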
@@ -297,5 +302,4 @@ def forward(self, trace, times=None, guide=probtorch.Trace(),
         }
 
         return self.htfa_model(trace, times, guide, blocks=blocks,
-                               observations=observations,
-                               weights_params=weight_params)
+                               observations=observations)