
Commit f36c4ac

Merge pull request #41 from neu-spiral/develop

NeurIPS 2020 supplementary material

2 parents: fdd0732 + e1e5e06

27 files changed (+42859, -51507 lines)

htfa_torch/dtfa.py (40 additions, 33 deletions)
@@ -98,9 +98,11 @@ def __init__(self, query, mask, num_factors=tfa_models.NUM_FACTORS,
             'factor_log_widths': widths,
         }

-        self.decoder = dtfa_models.DeepTFADecoder(self.num_factors, hyper_means,
+        self.decoder = dtfa_models.DeepTFADecoder(self.num_factors,
+                                                  self.voxel_locations,
                                                   embedding_dim,
-                                                  time_series=model_time_series)
+                                                  time_series=model_time_series,
+                                                  volume=True)
         self.generative = dtfa_models.DeepTFAModel(
             self.voxel_locations, block_subjects, block_tasks,
             self.num_factors, self.num_blocks, self.num_times, embedding_dim
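Note on this hunk: the decoder is now constructed from the voxel coordinate tensor directly (replacing `hyper_means`) and gains a `volume=True` flag. Below is a minimal sketch of the new constructor shape; the argument list follows the diff, but the parameter name `locations` and all internals are illustrative assumptions, not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class DecoderSketch(nn.Module):
    """Illustrative stand-in for DeepTFADecoder's new signature."""
    def __init__(self, num_factors, locations, embedding_dim=2,
                 time_series=True, volume=False):
        super().__init__()
        # Register the (num_voxels, 3) coordinate tensor as a buffer so it
        # moves with the module on .cuda()/.cpu() calls.
        self.register_buffer('locations', locations)
        self.num_factors = num_factors
        self.time_series = time_series
        self.volume = volume

decoder = DecoderSketch(num_factors=10, locations=torch.randn(1000, 3))
```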
@@ -144,11 +146,12 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
         decoder = self.decoder
         variational = self.variational
         generative = self.generative
+        voxel_locations = self.voxel_locations
         if tfa.CUDA and use_cuda:
             decoder.cuda()
             variational.cuda()
             generative.cuda()
-            cuda_locations = self.voxel_locations.cuda()
+            voxel_locations = voxel_locations.cuda()
         if not isinstance(learning_rate, dict):
             learning_rate = {
                 'q': learning_rate,
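Replacing the CUDA-only `cuda_locations` local with a single `voxel_locations` name makes the rest of `train` device-agnostic: one variable is bound up front and rebound to its GPU copy only when CUDA is in use. A minimal sketch of the pattern, with placeholder tensor contents:

```python
import torch

voxel_locations = torch.randn(1000, 3)  # stand-in for real voxel coordinates
if torch.cuda.is_available():
    voxel_locations = voxel_locations.cuda()
# Downstream code uses `voxel_locations` the same way on either device.
```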
@@ -180,6 +183,7 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
             optimizer, factor=0.5, min_lr=1e-5, patience=patience,
             verbose=True
         )
+        decoder.train()
         variational.train()
         generative.train()

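Adding `decoder.train()` puts the decoder in training mode alongside the variational and generative modules. `train()`/`eval()` flip mode flags recursively, which matters for layers whose behavior differs between phases, for example:

```python
import torch.nn as nn

decoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
decoder.train()  # dropout is active while optimizing
decoder.eval()   # dropout becomes the identity at evaluation time
```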
@@ -203,9 +207,6 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
                 activations = [{'Y': data[:, b, :]} for b in block_batch]
                 block_batch = [training_blocks[b][0] for b in block_batch]
                 if tfa.CUDA and use_cuda:
-                    for b in block_batch:
-                        generative.likelihoods[b].voxel_locations =\
-                            cuda_locations
                     for acts in activations:
                         acts['Y'] = acts['Y'].cuda()
                 trs = (batch * batch_size, None)
@@ -217,7 +218,9 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
                                 num_particles=num_particles)
                 p = probtorch.Trace()
                 generative(decoder, p, times=trs, guide=q,
-                           observations=activations, blocks=block_batch)
+                           observations=activations, blocks=block_batch,
+                           locations=voxel_locations,
+                           num_particles=num_particles)

                 def block_rv_weight(node, prior=True):
                     result = 1.0
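Together with the deleted loop in the previous hunk, this changes the data flow: instead of patching `generative.likelihoods[b].voxel_locations` for every block before the call (and restoring it afterwards), the locations tensor is threaded through as an explicit `locations=` keyword. A toy contrast with hypothetical names:

```python
import torch

def generative_stub(observations, locations, num_particles=1):
    # Receives locations as an argument; there is no module attribute
    # to patch before the call or restore afterwards.
    return observations @ locations

obs = torch.randn(4, 1000)
locs = torch.randn(1000, 3)
out = generative_stub(obs, locations=locs)  # shape (4, 3)
```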
@@ -239,9 +242,6 @@ def block_rv_weight(node, prior=True):

                 if tfa.CUDA and use_cuda:
                     del activations
-                    for b in block_batch:
-                        generative.likelihoods[b].voxel_locations =\
-                            self.voxel_locations
                     torch.cuda.empty_cache()
             if tfa.CUDA and use_cuda:
                 epoch_free_energies[batch] = epoch_free_energies[batch].cpu().data.numpy()
@@ -300,13 +300,14 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
         decoder = self.decoder
         variational = self.variational
         generative = self.generative
+        voxel_locations = self.voxel_locations
         if tfa.CUDA and use_cuda:
             decoder.cuda()
             variational.cuda()
             generative.cuda()
-            cuda_locations = self.voxel_locations.cuda().detach()
-            log_likelihoods = log_likelihoods.to(cuda_locations)
-            prior_kls = prior_kls.to(cuda_locations)
+            voxel_locations = voxel_locations.cuda().detach()
+            log_likelihoods = log_likelihoods.to(voxel_locations)
+            prior_kls = prior_kls.to(voxel_locations)

         for k in range(sample_size // num_particles):
             for (batch, data) in enumerate(activations_loader):
@@ -316,9 +317,6 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
                 activations = [{'Y': data[:, b, :]} for b in block_batch]
                 block_batch = [testing_blocks[b][0] for b in block_batch]
                 if tfa.CUDA and use_cuda:
-                    for b in block_batch:
-                        generative.likelihoods[b].voxel_locations =\
-                            cuda_locations
                     for acts in activations:
                         acts['Y'] = acts['Y'].cuda()
                 trs = (batch * batch_size, None)
@@ -329,7 +327,9 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
                                 num_particles=num_particles)
                 p = probtorch.Trace()
                 generative(decoder, p, times=trs, guide=q,
-                           observations=activations, blocks=block_batch)
+                           observations=activations, blocks=block_batch,
+                           locations=voxel_locations,
+                           num_particles=num_particles)

                 _, ll, prior_kl = tfa.hierarchical_free_energy(
                     q, p, num_particles=num_particles
@@ -342,9 +342,6 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,

                 if tfa.CUDA and use_cuda:
                     del activations
-                    for b in block_batch:
-                        generative.likelihoods[b].voxel_locations =\
-                            self.voxel_locations
                     torch.cuda.empty_cache()

         if tfa.CUDA and use_cuda:
@@ -365,10 +362,11 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
                  prior_kl.mean(dim=0).item()],
                 [iwae_free_energy, iwae_log_likelihood, iwae_prior_kl]]

-    def results(self, block=None, subject=None, task=None, hist_weights=False):
+    def results(self, block=None, subject=None, task=None, hist_weights=False,
+                generative=False):
         hyperparams = self.variational.hyperparams.state_vardict()
         for k, v in hyperparams.items():
-            hyperparams[k] = v.expand(1, *v.shape)
+            hyperparams[k] = v.unsqueeze(0)

         guide = probtorch.Trace()
         if block is not None:
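`v.unsqueeze(0)` produces the same `(1, ...)` particle-dimension view that `v.expand(1, *v.shape)` did, while stating the intent (add one leading axis) directly:

```python
import torch

v = torch.randn(3, 2)
a = v.expand(1, *v.shape)  # shape (1, 3, 2)
b = v.unsqueeze(0)         # shape (1, 3, 2)
assert torch.equal(a, b)
```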
@@ -389,49 +387,55 @@ def results(self, block=None, subject=None, task=None, hist_weights=False):
             guide.variable(
                 torch.distributions.Normal,
                 hyperparams['subject']['mu'][:, subject],
-                softplus(hyperparams['subject']['sigma'][:, subject]),
+                torch.exp(hyperparams['subject']['log_sigma'][:, subject]),
                 value=hyperparams['subject']['mu'][:, subject],
                 name='z^P_{%d,%d}' % (subject, b),
             )
             factor_centers_params = hyperparams['factor_centers']
             guide.variable(
                 torch.distributions.Normal,
                 factor_centers_params['mu'][:, subject],
-                softplus(factor_centers_params['sigma'][:, subject]),
+                torch.exp(factor_centers_params['log_sigma'][:, subject]),
                 value=factor_centers_params['mu'][:, subject],
                 name='FactorCenters%d' % b,
             )
             factor_log_widths_params = hyperparams['factor_log_widths']
             guide.variable(
                 torch.distributions.Normal,
                 factor_log_widths_params['mu'][:, subject],
-                softplus(factor_log_widths_params['sigma'][:, subject]),
+                torch.exp(factor_log_widths_params['log_sigma'][:, subject]),
                 value=factor_log_widths_params['mu'][:, subject],
                 name='FactorLogWidths%d' % b,
             )
         if task is not None:
             guide.variable(
                 torch.distributions.Normal,
                 hyperparams['task']['mu'][:, task],
-                softplus(hyperparams['task']['sigma'][:, task]),
+                torch.exp(hyperparams['task']['log_sigma'][:, task]),
                 value=hyperparams['task']['mu'][:, task],
                 name='z^S_{%d,%d}' % (task, b),
             )
-        if self._time_series:
+        if self._time_series and not generative:
             for k, v in hyperparams['weights'].items():
                 hyperparams['weights'][k] = v[:, :, times[0]:times[1]]
             weights_params = hyperparams['weights']
             guide.variable(
                 torch.distributions.Normal,
                 weights_params['mu'][:, b],
-                softplus(weights_params['sigma'][:, b]),
+                torch.exp(weights_params['log_sigma'][:, b]),
                 value=weights_params['mu'][:, b],
                 name='Weights%d_%d-%d' % (b, times[0], times[1])
             )

+
+        if generative:
+            for k, v in hyperparams.items():
+                hyperparams[k] = v.squeeze(0)
+
         weights, factor_centers, factor_log_widths =\
             self.decoder(probtorch.Trace(), blocks, block_subjects, block_tasks,
-                         hyperparams, times, guide=guide, num_particles=1)
+                         hyperparams, times, guide=guide, num_particles=1,
+                         generative=generative)

         if block is not None:
             weights = weights[0]
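The recurring substitution in this hunk swaps the positivity transform for the variational scales: parameters are now stored as `log_sigma` and mapped through `exp`, rather than stored as `sigma` and mapped through `softplus`. Both keep the Normal scale strictly positive; under `exp` the raw parameter is exactly log sigma, so a raw value of 0 means sigma = 1:

```python
import torch
import torch.nn.functional as F

raw = torch.tensor([-2.0, 0.0, 2.0])
old_sigma = F.softplus(raw)  # log(1 + exp(raw)): positive, but raw != log(sigma)
new_sigma = torch.exp(raw)   # raw is interpreted as log(sigma) directly
```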
@@ -453,14 +457,17 @@ def results(self, block=None, subject=None, task=None, hist_weights=False):
             'factor_centers': factor_centers.data,
             'factor_log_widths': factor_log_widths.data,
         }
+        if generative:
+            for k, v in hyperparams.items():
+                hyperparams[k] = v.unsqueeze(0)
         if subject is not None:
             result['z^P_%d' % subject] = hyperparams['subject']['mu'][:, subject]
         if task is not None:
             result['z^S_%d' % task] = hyperparams['task']['mu'][:, task]
         return result

     def reconstruction(self, block=None, subject=None, task=None, t=0):
-        results = self.results(block, subject, task)
+        results = self.results(block, subject, task, generative=t is None)
         reconstruction = results['weights'] @ results['factors']

         image = utils.cmu2nii(reconstruction.numpy(),
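`reconstruction` now forwards `generative=t is None`, so reconstructing without a specific timepoint appears to use the generative path in `results`. The identity test, rather than truthiness, is what keeps the default `t=0` on the variational path:

```python
for t in (None, 0, 5):
    print(t, t is None)  # None -> True; 0 and 5 -> False
# `generative=not t` would wrongly be True for t=0.
```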
@@ -823,7 +830,7 @@ def heatmap_subject_embedding(self, heatmaps=[], filename='', show=True,
             filename = self.common_name() + '_subject_heatmap.pdf'
         hyperparams = self.variational.hyperparams.state_vardict()
         z_p_mu = hyperparams['subject']['mu'].data
-        z_p_sigma = softplus(hyperparams['subject']['sigma'].data)
+        z_p_sigma = torch.exp(hyperparams['subject']['log_sigma'].data)
         subjects = self.subjects()

         minus_lims = torch.min(z_p_mu - z_p_sigma * 2, dim=0)[0].tolist()
@@ -886,7 +893,7 @@ def scatter_subject_embedding(self, labeler=None, filename='', show=True,
             filename = self.common_name() + '_subject_embedding.pdf'
         hyperparams = self.variational.hyperparams.state_vardict()
         z_p_mu = hyperparams['subject']['mu'].data
-        z_p_sigma = softplus(hyperparams['subject']['sigma'].data)
+        z_p_sigma = torch.exp(hyperparams['subject']['log_sigma'].data)
         subjects = self.subjects()

         minus_lims = torch.min(z_p_mu - z_p_sigma * 2, dim=0)[0].tolist()
@@ -935,7 +942,7 @@ def scatter_task_embedding(self, labeler=None, filename='', show=True,
             filename = self.common_name() + '_task_embedding.pdf'
         hyperparams = self.variational.hyperparams.state_vardict()
         z_s_mu = hyperparams['task']['mu'].data
-        z_s_sigma = softplus(hyperparams['task']['sigma'].data)
+        z_s_sigma = torch.exp(hyperparams['task']['log_sigma'].data)
         tasks = self.tasks()

         minus_lims = torch.min(z_s_mu - z_s_sigma * 2, dim=0)[0].tolist()
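The three plotting helpers get the same `log_sigma` treatment, so their plus/minus 2-sigma axis limits are computed from `exp(log_sigma)`. A sketch of that limit computation, with placeholder shapes:

```python
import torch

z_mu = torch.randn(5, 2)                # e.g. 5 subjects, 2-D embedding
z_sigma = torch.exp(torch.randn(5, 2))  # exp of the stored log_sigma
minus_lims = torch.min(z_mu - z_sigma * 2, dim=0)[0].tolist()
plus_lims = torch.max(z_mu + z_sigma * 2, dim=0)[0].tolist()
```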
