Skip to content

Commit fdd0732

Browse files
authored
Merge pull request #38 from neu-spiral/develop
Merge for ICML 2020 code submission
2 parents a2d8415 + aa30232 commit fdd0732

File tree

62 files changed

+210064
-5140
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+210064
-5140
lines changed

htfa_torch/dtfa.py

Lines changed: 309 additions & 152 deletions
Large diffs are not rendered by default.

htfa_torch/dtfa_models.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,11 @@ def __init__(self, num_subjects, num_tasks, embedding_dim=2):
3232
params = utils.vardict({
3333
'subject': {
3434
'mu': torch.zeros(self.num_subjects, self.embedding_dim),
35-
'sigma': torch.ones(self.num_subjects, self.embedding_dim) *\
36-
np.sqrt(tfa_models.SOURCE_WEIGHT_STD_DEV**2 +\
37-
tfa_models.SOURCE_LOG_WIDTH_STD_DEV**2),
35+
'sigma': torch.ones(self.num_subjects, self.embedding_dim),
3836
},
3937
'task': {
4038
'mu': torch.zeros(self.num_tasks, self.embedding_dim),
41-
'sigma': torch.ones(self.num_tasks, self.embedding_dim) *\
42-
tfa_models.SOURCE_WEIGHT_STD_DEV,
39+
'sigma': torch.ones(self.num_tasks, self.embedding_dim),
4340
},
4441
'voxel_noise': torch.ones(1) * tfa_models.VOXEL_NOISE,
4542
})
@@ -48,7 +45,7 @@ def __init__(self, num_subjects, num_tasks, embedding_dim=2):
4845

4946
class DeepTFAGuideHyperparams(tfa_models.HyperParams):
5047
def __init__(self, num_blocks, num_times, num_factors, num_subjects,
51-
num_tasks, hyper_means, embedding_dim=2):
48+
num_tasks, hyper_means, embedding_dim=2, time_series=True):
5249
self.num_blocks = num_blocks
5350
self.num_subjects = num_subjects
5451
self.num_tasks = num_tasks
@@ -78,23 +75,26 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
7875
'sigma': torch.ones(self.num_subjects, self._num_factors) *\
7976
hyper_means['factor_log_widths'].std(),
8077
},
81-
'weights': {
78+
})
79+
if time_series:
80+
params['weights'] = {
8281
'mu': torch.zeros(self.num_blocks, self.num_times,
8382
self._num_factors),
8483
'sigma': torch.ones(self.num_blocks, self.num_times,
8584
self._num_factors),
86-
},
87-
})
85+
}
8886

8987
super(self.__class__, self).__init__(params, guide=True)
9088

9189
class DeepTFADecoder(nn.Module):
9290
"""Neural network module mapping from embeddings to a topographic factor
9391
analysis"""
94-
def __init__(self, num_factors, hyper_means, embedding_dim=2):
92+
def __init__(self, num_factors, hyper_means, embedding_dim=2,
93+
time_series=True):
9594
super(DeepTFADecoder, self).__init__()
9695
self._embedding_dim = embedding_dim
9796
self._num_factors = num_factors
97+
self._time_series = time_series
9898

9999
self.factors_embedding = nn.Sequential(
100100
nn.Linear(self._embedding_dim, self._embedding_dim * 2),
@@ -200,7 +200,8 @@ def predict(self, trace, params, guide, subject, task, times=(0, 1),
200200
weight_predictions = self._predict_param(
201201
params, 'weights', block, weight_predictions,
202202
'Weights%d_%d-%d' % (block, times[0], times[1]), trace,
203-
predict=generative or block < 0, guide=guide,
203+
predict=generative or block < 0 or not self._time_series,
204+
guide=guide,
204205
)
205206

206207
return centers_predictions, log_widths_predictions, weight_predictions
@@ -237,12 +238,14 @@ def forward(self, trace, blocks, block_subjects, block_tasks, params, times,
237238
class DeepTFAGuide(nn.Module):
238239
"""Variational guide for deep topographic factor analysis"""
239240
def __init__(self, num_factors, block_subjects, block_tasks, num_blocks=1,
240-
num_times=[1], embedding_dim=2, hyper_means=None):
241+
num_times=[1], embedding_dim=2, hyper_means=None,
242+
time_series=True):
241243
super(self.__class__, self).__init__()
242244
self._num_blocks = num_blocks
243245
self._num_times = num_times
244246
self._num_factors = num_factors
245247
self._embedding_dim = embedding_dim
248+
self._time_series = time_series
246249

247250
self.block_subjects = block_subjects
248251
self.block_tasks = block_tasks
@@ -254,7 +257,7 @@ def __init__(self, num_factors, block_subjects, block_tasks, num_blocks=1,
254257
self._num_factors,
255258
num_subjects, num_tasks,
256259
hyper_means,
257-
embedding_dim)
260+
embedding_dim, time_series)
258261

259262
def forward(self, decoder, trace, times=None, blocks=None,
260263
num_particles=tfa_models.NUM_PARTICLES):
@@ -269,7 +272,7 @@ def forward(self, decoder, trace, times=None, blocks=None,
269272
if b in blocks]
270273
block_tasks = [self.block_tasks[b] for b in range(self._num_blocks)
271274
if b in blocks]
272-
if times:
275+
if times and self._time_series:
273276
for k, v in params['weights'].items():
274277
params['weights'][k] = v[:, :, times[0]:times[1], :]
275278

0 commit comments

Comments
 (0)