55
66import collections
77import numpy as np
8+ import scipy .spatial
89import torch
910from torch .autograd import Variable
1011import probtorch
@@ -53,15 +54,12 @@ def __init__(self, hyper_means, num_times, num_blocks,
5354 'sigma' : torch .sqrt (torch .rand (self ._num_blocks , self ._num_factors )),
5455 },
5556 'weights' : {
56- 'mu' : torch .randn (self ._num_blocks , self ._num_times ,
57- self ._num_factors ),
57+ 'mu' : hyper_means ['weights' ].mean (0 ).unsqueeze (0 ).expand (
58+ self ._num_blocks , self ._num_times , self ._num_factors
59+ ),
5860 'sigma' : torch .ones (self ._num_blocks , self ._num_times ,
5961 self ._num_factors ),
6062 },
61- 'voxel_noise' : {
62- 'mu' : torch .ones (self ._num_blocks ),
63- 'sigma' : torch .sqrt (torch .rand (self ._num_blocks ))
64- }
6563 })
6664
6765 super (self .__class__ , self ).__init__ (params , guide = True )
@@ -95,15 +93,6 @@ def forward(self, trace, params, times=None, blocks=None,
9593 # We only expand the parameters for which we're actually going to sample
9694 # values in this very method, and thus want to expand to get multiple
9795 # particles.
98- voxel_noise_params = params ['block' ]['voxel_noise' ]
99- if num_particles and num_particles > 0 :
100- voxel_noise_params = utils .unsqueeze_and_expand_vardict (
101- params ['block' ]['voxel_noise' ], 0 , num_particles , True
102- )
103- voxel_noise = trace .normal (voxel_noise_params ['mu' ],
104- voxel_noise_params ['sigma' ],
105- name = 'voxel_noise' )
106-
10796 if blocks is None :
10897 blocks = list (range (self ._num_blocks ))
10998
@@ -125,16 +114,17 @@ def forward(self, trace, params, times=None, blocks=None,
125114 factor_centers += [fc ]
126115 factor_log_widths += [flw ]
127116
128- return weights , factor_centers , factor_log_widths , voxel_noise
117+ return weights , factor_centers , factor_log_widths
129118
130119class HTFAGuide (nn .Module ):
131120 """Variational guide for hierarchical topographic factor analysis"""
132121 def __init__ (self , query , num_factors = tfa_models .NUM_FACTORS ):
133122 super (self .__class__ , self ).__init__ ()
134123 self ._num_blocks = len (query )
135- self ._num_times = niidb .query_min_time (query )
124+ self ._num_times = niidb .query_max_time (query )
136125
137- b = np .random .choice (self ._num_blocks , 1 )[0 ]
126+ b = max (range (self ._num_blocks ), key = lambda b : query [b ].end_time -
127+ query [b ].start_time )
138128 query [b ].load ()
139129 centers , widths , weights = utils .initial_hypermeans (
140130 query [b ].activations .numpy ().T , query [b ].locations .numpy (),
@@ -161,7 +151,7 @@ def forward(self, trace, times=None, blocks=None,
161151
162152class HTFAGenerativeHyperParams (tfa_models .HyperParams ):
163153 def __init__ (self , brain_center , brain_center_std_dev , num_blocks ,
164- num_factors = tfa_models .NUM_FACTORS ):
154+ num_factors = tfa_models .NUM_FACTORS , volume = None ):
165155 self ._num_factors = num_factors
166156 self ._num_blocks = num_blocks
167157
@@ -177,20 +167,22 @@ def __init__(self, brain_center, brain_center_std_dev, num_blocks,
177167 params ['template' ]['factor_centers' ]['sigma' ] = \
178168 brain_center_std_dev .expand (self ._num_factors , 3 )
179169
170+ coefficient = 1.0
171+ if volume is not None :
172+ coefficient = np .log (np .cbrt (volume / self ._num_factors ))
180173 params ['template' ]['factor_log_widths' ]['mu' ] = \
181- torch .ones (self ._num_factors )
174+ coefficient * torch .ones (self ._num_factors )
182175 params ['template' ]['factor_log_widths' ]['sigma' ] = \
183176 tfa_models .SOURCE_LOG_WIDTH_STD_DEV * torch .ones (self ._num_factors )
184177
185178 params ['block' ] = {
186179 'factor_center_noise' : torch .ones (self ._num_blocks ),
187180 'factor_log_width_noise' : torch .ones (self ._num_blocks ),
188181 'weights' : {
189- 'mu' : torch .rand (self ._num_blocks , self ._num_factors ),
182+ 'mu' : torch .zeros (self ._num_blocks , self ._num_factors ),
190183 'sigma' : tfa_models .SOURCE_WEIGHT_STD_DEV * \
191184 torch .ones (self ._num_blocks , self ._num_factors )
192185 },
193- 'voxel_noise' : utils .gaussian_populator (self ._num_blocks )
194186 }
195187 super (self .__class__ , self ).__init__ (params , guide = False )
196188
@@ -216,19 +208,16 @@ def __init__(self, num_blocks, num_times):
216208 for b in range (self ._num_blocks )]
217209
218210 def forward (self , trace , params , template , times = None , blocks = None ,
219- guide = probtorch .Trace ()):
220- voxel_noise = trace .normal (params ['block' ]['voxel_noise' ]['mu' ],
221- params ['block' ]['voxel_noise' ]['sigma' ],
222- value = guide ['voxel_noise' ],
223- name = 'voxel_noise' )
224-
211+ guide = probtorch .Trace (), weights_params = None ):
225212 if blocks is None :
226213 blocks = list (range (self ._num_blocks ))
214+ if times is None :
215+ times = (0 , self ._num_times )
227216
228217 weights = []
229218 factor_centers = []
230219 factor_log_widths = []
231- for b in blocks :
220+ for ( i , b ) in enumerate ( blocks ) :
232221 sparams = utils .vardict ({
233222 'factor_centers' : {
234223 'mu' : template ['factor_centers' ],
@@ -243,56 +232,58 @@ def forward(self, trace, params, template, times=None, blocks=None,
243232 'sigma' : params ['block' ]['weights' ]['sigma' ][b ],
244233 }
245234 })
235+ if weights_params is not None :
236+ sparams ['weights' ] = weights_params [i ]
246237 w , fc , flw = self ._tfa_priors [b ](trace , sparams , times = times ,
247238 guide = guide )
248239 weights += [w ]
249240 factor_centers += [fc ]
250241 factor_log_widths += [flw ]
251242
252- return weights , factor_centers , factor_log_widths , voxel_noise
243+ return weights , factor_centers , factor_log_widths
253244
254245class HTFAModel (nn .Module ):
255246 """Generative model for hierarchical topographic factor analysis"""
256- def __init__ (self , query , num_blocks , num_times ,
257- num_factors = tfa_models .NUM_FACTORS ):
247+ def __init__ (self , locations , num_blocks , num_times ,
248+ num_factors = tfa_models .NUM_FACTORS , volume = None ):
258249 super (self .__class__ , self ).__init__ ()
259250
260251 self ._num_factors = num_factors
261252 self ._num_blocks = num_blocks
262253 self ._num_times = num_times
263254
264- b = np .random .choice (self ._num_blocks , 1 )[0 ]
265- query [b ].load ()
266- center , center_sigma = utils .brain_centroid (query [b ].locations )
255+ center , center_sigma = utils .brain_centroid (locations )
256+ hull = scipy .spatial .ConvexHull (locations )
257+ if volume is not None :
258+ volume = hull .volume
267259
268260 self ._hyperparams = HTFAGenerativeHyperParams (center , center_sigma ,
269261 self ._num_blocks ,
270- self ._num_factors )
262+ self ._num_factors ,
263+ volume = volume )
271264 self ._template_prior = HTFAGenerativeTemplatePrior ()
272265 self ._subject_prior = HTFAGenerativeSubjectPrior (
273266 self ._num_blocks , self ._num_times
274267 )
275- for block in query :
276- block .load ()
277268 self .likelihoods = [tfa_models .TFAGenerativeLikelihood (
278- query [ b ]. locations , self ._num_times [b ], tfa_models .VOXEL_NOISE ,
269+ locations , self ._num_times [b ], tfa_models .VOXEL_NOISE ,
279270 block = b , register_locations = False
280271 ) for b in range (self ._num_blocks )]
281272 for b , block_likelihood in enumerate (self .likelihoods ):
282273 self .add_module ('likelihood' + str (b ), block_likelihood )
283274
284275 def forward (self , trace , times = None , guide = probtorch .Trace (), blocks = None ,
285- observations = []):
276+ observations = [], weights_params = None ):
286277 if blocks is None :
287278 blocks = list (range (self ._num_blocks ))
288279 params = self ._hyperparams .state_vardict ()
289280
290281 template = self ._template_prior (trace , params , guide = guide )
291- weights , centers , log_widths , voxel_noise = self ._subject_prior (
292- trace , params , template , times = times , blocks = blocks , guide = guide
282+ weights , centers , log_widths = self ._subject_prior (
283+ trace , params , template , times = times , blocks = blocks , guide = guide ,
284+ weights_params = weights_params
293285 )
294286
295287 return [self .likelihoods [b ](trace , weights [i ], centers [i ], log_widths [i ],
296- times = times , observations = observations [i ],
297- voxel_noise = voxel_noise )
288+ times = times , observations = observations [i ])
298289 for (i , b ) in enumerate (blocks )]
0 commit comments