@@ -98,9 +98,11 @@ def __init__(self, query, mask, num_factors=tfa_models.NUM_FACTORS,
9898 'factor_log_widths' : widths ,
9999 }
100100
101- self .decoder = dtfa_models .DeepTFADecoder (self .num_factors , hyper_means ,
101+ self .decoder = dtfa_models .DeepTFADecoder (self .num_factors ,
102+ self .voxel_locations ,
102103 embedding_dim ,
103- time_series = model_time_series )
104+ time_series = model_time_series ,
105+ volume = True )
104106 self .generative = dtfa_models .DeepTFAModel (
105107 self .voxel_locations , block_subjects , block_tasks ,
106108 self .num_factors , self .num_blocks , self .num_times , embedding_dim
@@ -144,11 +146,12 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
144146 decoder = self .decoder
145147 variational = self .variational
146148 generative = self .generative
149+ voxel_locations = self .voxel_locations
147150 if tfa .CUDA and use_cuda :
148151 decoder .cuda ()
149152 variational .cuda ()
150153 generative .cuda ()
151- cuda_locations = self . voxel_locations .cuda ()
154+ voxel_locations = voxel_locations .cuda ()
152155 if not isinstance (learning_rate , dict ):
153156 learning_rate = {
154157 'q' : learning_rate ,
@@ -180,6 +183,7 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
180183 optimizer , factor = 0.5 , min_lr = 1e-5 , patience = patience ,
181184 verbose = True
182185 )
186+ decoder .train ()
183187 variational .train ()
184188 generative .train ()
185189
@@ -203,9 +207,6 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
203207 activations = [{'Y' : data [:, b , :]} for b in block_batch ]
204208 block_batch = [training_blocks [b ][0 ] for b in block_batch ]
205209 if tfa .CUDA and use_cuda :
206- for b in block_batch :
207- generative .likelihoods [b ].voxel_locations = \
208- cuda_locations
209210 for acts in activations :
210211 acts ['Y' ] = acts ['Y' ].cuda ()
211212 trs = (batch * batch_size , None )
@@ -217,7 +218,9 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
217218 num_particles = num_particles )
218219 p = probtorch .Trace ()
219220 generative (decoder , p , times = trs , guide = q ,
220- observations = activations , blocks = block_batch )
221+ observations = activations , blocks = block_batch ,
222+ locations = voxel_locations ,
223+ num_particles = num_particles )
221224
222225 def block_rv_weight (node , prior = True ):
223226 result = 1.0
@@ -239,9 +242,6 @@ def block_rv_weight(node, prior=True):
239242
240243 if tfa .CUDA and use_cuda :
241244 del activations
242- for b in block_batch :
243- generative .likelihoods [b ].voxel_locations = \
244- self .voxel_locations
245245 torch .cuda .empty_cache ()
246246 if tfa .CUDA and use_cuda :
247247 epoch_free_energies [batch ] = epoch_free_energies [batch ].cpu ().data .numpy ()
@@ -300,13 +300,14 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
300300 decoder = self .decoder
301301 variational = self .variational
302302 generative = self .generative
303+ voxel_locations = self .voxel_locations
303304 if tfa .CUDA and use_cuda :
304305 decoder .cuda ()
305306 variational .cuda ()
306307 generative .cuda ()
307- cuda_locations = self . voxel_locations .cuda ().detach ()
308- log_likelihoods = log_likelihoods .to (cuda_locations )
309- prior_kls = prior_kls .to (cuda_locations )
308+ voxel_locations = voxel_locations .cuda ().detach ()
309+ log_likelihoods = log_likelihoods .to (voxel_locations )
310+ prior_kls = prior_kls .to (voxel_locations )
310311
311312 for k in range (sample_size // num_particles ):
312313 for (batch , data ) in enumerate (activations_loader ):
@@ -316,9 +317,6 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
316317 activations = [{'Y' : data [:, b , :]} for b in block_batch ]
317318 block_batch = [testing_blocks [b ][0 ] for b in block_batch ]
318319 if tfa .CUDA and use_cuda :
319- for b in block_batch :
320- generative .likelihoods [b ].voxel_locations = \
321- cuda_locations
322320 for acts in activations :
323321 acts ['Y' ] = acts ['Y' ].cuda ()
324322 trs = (batch * batch_size , None )
@@ -329,7 +327,9 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
329327 num_particles = num_particles )
330328 p = probtorch .Trace ()
331329 generative (decoder , p , times = trs , guide = q ,
332- observations = activations , blocks = block_batch )
330+ observations = activations , blocks = block_batch ,
331+ locations = voxel_locations ,
332+ num_particles = num_particles )
333333
334334 _ , ll , prior_kl = tfa .hierarchical_free_energy (
335335 q , p , num_particles = num_particles
@@ -342,9 +342,6 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
342342
343343 if tfa .CUDA and use_cuda :
344344 del activations
345- for b in block_batch :
346- generative .likelihoods [b ].voxel_locations = \
347- self .voxel_locations
348345 torch .cuda .empty_cache ()
349346
350347 if tfa .CUDA and use_cuda :
@@ -365,10 +362,11 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
365362 prior_kl .mean (dim = 0 ).item ()],
366363 [iwae_free_energy , iwae_log_likelihood , iwae_prior_kl ]]
367364
368- def results (self , block = None , subject = None , task = None , hist_weights = False ):
365+ def results (self , block = None , subject = None , task = None , hist_weights = False ,
366+ generative = False ):
369367 hyperparams = self .variational .hyperparams .state_vardict ()
370368 for k , v in hyperparams .items ():
371- hyperparams [k ] = v .expand ( 1 , * v . shape )
369+ hyperparams [k ] = v .unsqueeze ( 0 )
372370
373371 guide = probtorch .Trace ()
374372 if block is not None :
@@ -389,49 +387,55 @@ def results(self, block=None, subject=None, task=None, hist_weights=False):
389387 guide .variable (
390388 torch .distributions .Normal ,
391389 hyperparams ['subject' ]['mu' ][:, subject ],
392- softplus (hyperparams ['subject' ]['sigma ' ][:, subject ]),
390+ torch . exp (hyperparams ['subject' ]['log_sigma ' ][:, subject ]),
393391 value = hyperparams ['subject' ]['mu' ][:, subject ],
394392 name = 'z^P_{%d,%d}' % (subject , b ),
395393 )
396394 factor_centers_params = hyperparams ['factor_centers' ]
397395 guide .variable (
398396 torch .distributions .Normal ,
399397 factor_centers_params ['mu' ][:, subject ],
400- softplus (factor_centers_params ['sigma ' ][:, subject ]),
398+ torch . exp (factor_centers_params ['log_sigma ' ][:, subject ]),
401399 value = factor_centers_params ['mu' ][:, subject ],
402400 name = 'FactorCenters%d' % b ,
403401 )
404402 factor_log_widths_params = hyperparams ['factor_log_widths' ]
405403 guide .variable (
406404 torch .distributions .Normal ,
407405 factor_log_widths_params ['mu' ][:, subject ],
408- softplus (factor_log_widths_params ['sigma ' ][:, subject ]),
406+ torch . exp (factor_log_widths_params ['log_sigma ' ][:, subject ]),
409407 value = factor_log_widths_params ['mu' ][:, subject ],
410408 name = 'FactorLogWidths%d' % b ,
411409 )
412410 if task is not None :
413411 guide .variable (
414412 torch .distributions .Normal ,
415413 hyperparams ['task' ]['mu' ][:, task ],
416- softplus (hyperparams ['task' ]['sigma ' ][:, task ]),
414+ torch . exp (hyperparams ['task' ]['log_sigma ' ][:, task ]),
417415 value = hyperparams ['task' ]['mu' ][:, task ],
418416 name = 'z^S_{%d,%d}' % (task , b ),
419417 )
420- if self ._time_series :
418+ if self ._time_series and not generative :
421419 for k , v in hyperparams ['weights' ].items ():
422420 hyperparams ['weights' ][k ] = v [:, :, times [0 ]:times [1 ]]
423421 weights_params = hyperparams ['weights' ]
424422 guide .variable (
425423 torch .distributions .Normal ,
426424 weights_params ['mu' ][:, b ],
427- softplus (weights_params ['sigma ' ][:, b ]),
425+ torch . exp (weights_params ['log_sigma ' ][:, b ]),
428426 value = weights_params ['mu' ][:, b ],
429427 name = 'Weights%d_%d-%d' % (b , times [0 ], times [1 ])
430428 )
431429
430+
431+ if generative :
432+ for k , v in hyperparams .items ():
433+ hyperparams [k ] = v .squeeze (0 )
434+
432435 weights , factor_centers , factor_log_widths = \
433436 self .decoder (probtorch .Trace (), blocks , block_subjects , block_tasks ,
434- hyperparams , times , guide = guide , num_particles = 1 )
437+ hyperparams , times , guide = guide , num_particles = 1 ,
438+ generative = generative )
435439
436440 if block is not None :
437441 weights = weights [0 ]
@@ -453,14 +457,17 @@ def results(self, block=None, subject=None, task=None, hist_weights=False):
453457 'factor_centers' : factor_centers .data ,
454458 'factor_log_widths' : factor_log_widths .data ,
455459 }
460+ if generative :
461+ for k , v in hyperparams .items ():
462+ hyperparams [k ] = v .unsqueeze (0 )
456463 if subject is not None :
457464 result ['z^P_%d' % subject ] = hyperparams ['subject' ]['mu' ][:, subject ]
458465 if task is not None :
459466 result ['z^S_%d' % task ] = hyperparams ['task' ]['mu' ][:, task ]
460467 return result
461468
462469 def reconstruction (self , block = None , subject = None , task = None , t = 0 ):
463- results = self .results (block , subject , task )
470+ results = self .results (block , subject , task , generative = t is None )
464471 reconstruction = results ['weights' ] @ results ['factors' ]
465472
466473 image = utils .cmu2nii (reconstruction .numpy (),
@@ -823,7 +830,7 @@ def heatmap_subject_embedding(self, heatmaps=[], filename='', show=True,
823830 filename = self .common_name () + '_subject_heatmap.pdf'
824831 hyperparams = self .variational .hyperparams .state_vardict ()
825832 z_p_mu = hyperparams ['subject' ]['mu' ].data
826- z_p_sigma = softplus (hyperparams ['subject' ]['sigma ' ].data )
833+ z_p_sigma = torch . exp (hyperparams ['subject' ]['log_sigma ' ].data )
827834 subjects = self .subjects ()
828835
829836 minus_lims = torch .min (z_p_mu - z_p_sigma * 2 , dim = 0 )[0 ].tolist ()
@@ -886,7 +893,7 @@ def scatter_subject_embedding(self, labeler=None, filename='', show=True,
886893 filename = self .common_name () + '_subject_embedding.pdf'
887894 hyperparams = self .variational .hyperparams .state_vardict ()
888895 z_p_mu = hyperparams ['subject' ]['mu' ].data
889- z_p_sigma = softplus (hyperparams ['subject' ]['sigma ' ].data )
896+ z_p_sigma = torch . exp (hyperparams ['subject' ]['log_sigma ' ].data )
890897 subjects = self .subjects ()
891898
892899 minus_lims = torch .min (z_p_mu - z_p_sigma * 2 , dim = 0 )[0 ].tolist ()
@@ -935,7 +942,7 @@ def scatter_task_embedding(self, labeler=None, filename='', show=True,
935942 filename = self .common_name () + '_task_embedding.pdf'
936943 hyperparams = self .variational .hyperparams .state_vardict ()
937944 z_s_mu = hyperparams ['task' ]['mu' ].data
938- z_s_sigma = softplus (hyperparams ['task' ]['sigma ' ].data )
945+ z_s_sigma = torch . exp (hyperparams ['task' ]['log_sigma ' ].data )
939946 tasks = self .tasks ()
940947
941948 minus_lims = torch .min (z_s_mu - z_s_sigma * 2 , dim = 0 )[0 ].tolist ()
0 commit comments