Commit 624dfa5

Author: Eli Sennesh (committed)
Merge branch 'feature/deeptfa_gradient_flow' into develop
2 parents: 5fb5524 + f63dd58

6 files changed: +1471 −1413 lines


htfa_torch/dtfa.py

Lines changed: 78 additions & 13 deletions
```diff
@@ -38,6 +38,8 @@
 from . import tfa_models
 from . import utils
 
+EPOCH_MSG = '[Epoch %d] (%dms) Posterior free-energy %.8e = KL from prior %.8e - log-likelihood %.8e'
+
 class DeepTFA:
     """Overall container for a run of Deep TFA"""
     def __init__(self, query, mask, num_factors=tfa_models.NUM_FACTORS,
@@ -123,7 +125,7 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
         optimizer = torch.optim.Adam(list(variational.parameters()),
                                      lr=learning_rate, weight_decay=1e-2)
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            optimizer, factor=1e-1, min_lr=5e-5, patience=patience,
+            optimizer, factor=0.5, min_lr=1e-5, patience=patience,
             verbose=True
         )
         variational.train()
@@ -136,9 +138,13 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
         for epoch in range(num_steps):
             start = time.time()
             epoch_free_energies = list(range(len(activations_loader)))
+            epoch_lls = list(range(len(activations_loader)))
+            epoch_prior_kls = list(range(len(activations_loader)))
 
             for (batch, data) in enumerate(activations_loader):
                 epoch_free_energies[batch] = 0.0
+                epoch_lls[batch] = 0.0
+                epoch_prior_kls[batch] = 0.0
                 block_batches = utils.chunks(list(range(self.num_blocks)),
                                              n=blocks_batch_size)
                 for block_batch in block_batches:
@@ -161,13 +167,13 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
                     generative(p, times=trs, guide=q, observations=activations,
                                blocks=block_batch)
 
-                    def block_rv_weight(node):
+                    def block_rv_weight(node, prior=True):
                         result = 1.0
                         if measure_occurrences:
                             rv_occurrences[node] += 1
                         result /= rv_occurrences[node]
                         return result
-                    free_energy = tfa.hierarchical_free_energy(
+                    free_energy, ll, prior_kl = tfa.hierarchical_free_energy(
                         q, p,
                         rv_weight=block_rv_weight,
                         num_particles=num_particles
@@ -176,6 +182,8 @@ def block_rv_weight(node):
                     free_energy.backward()
                     optimizer.step()
                     epoch_free_energies[batch] += free_energy
+                    epoch_lls[batch] += ll
+                    epoch_prior_kls[batch] += prior_kl
 
                     if tfa.CUDA and use_cuda:
                         del activations
@@ -185,8 +193,12 @@ def block_rv_weight(node):
                         torch.cuda.empty_cache()
                 if tfa.CUDA and use_cuda:
                     epoch_free_energies[batch] = epoch_free_energies[batch].cpu().data.numpy()
+                    epoch_lls[batch] = epoch_lls[batch].cpu().data.numpy()
+                    epoch_prior_kls[batch] = epoch_prior_kls[batch].cpu().data.numpy()
                 else:
                     epoch_free_energies[batch] = epoch_free_energies[batch].data.numpy()
+                    epoch_lls[batch] = epoch_lls[batch].data.numpy()
+                    epoch_prior_kls[batch] = epoch_prior_kls[batch].data.numpy()
 
             free_energies[epoch] = np.array(epoch_free_energies).sum(0)
             free_energies[epoch] = free_energies[epoch].sum(0)
@@ -195,7 +207,9 @@ def block_rv_weight(node):
             measure_occurrences = False
 
             end = time.time()
-            msg = tfa.EPOCH_MSG % (epoch + 1, (end - start) * 1000, free_energies[epoch])
+            msg = EPOCH_MSG % (epoch + 1, (end - start) * 1000,
+                               free_energies[epoch], sum(epoch_prior_kls),
+                               sum(epoch_lls))
             logging.info(msg)
             if checkpoint_steps is not None and epoch % checkpoint_steps == 0:
                 now = datetime.datetime.now()
@@ -214,22 +228,33 @@ def block_rv_weight(node):
 
         return np.vstack([free_energies])
 
-    def results(self, block):
+    def results(self, block, hist_weights=False):
         hyperparams = self.variational.hyperparams.state_vardict()
         subject = self.generative.block_subjects[block]
+        task = self.generative.block_tasks[block]
 
         factors_embed = hyperparams['factors']['mu'][subject]
 
-        weights = hyperparams['block']['weights']['mu'][block]\
-                  [self._blocks[block].start_time:
-                   self._blocks[block].end_time]
-        factor_params = self.variational.factors_embedding(factors_embed).view(
-            self.num_factors, 8
+        factor_params = self.variational.factors_embedding(factors_embed)
+        factor_centers = self.variational.centers_embedding(factor_params).view(
+            self.num_factors, 3
         )
-        factor_centers = factor_params[:, :3]
-        factor_log_widths = factor_params[:, 6].contiguous().view(
-            self.num_factors
+        factor_log_widths = self.variational.log_widths_embedding(factor_params)
+
+        weight_deltas = hyperparams['block']['weights']['mu'][block]\
+                        [self._blocks[block].start_time:
+                         self._blocks[block].end_time]
+        subject_embed = hyperparams['subject']['mu'][subject]
+        task_embed = hyperparams['task']['mu'][task]
+        weights_embed = torch.cat((subject_embed, task_embed), dim=-1)
+        weight_params = self.variational.weights_embedding(weights_embed).view(
+            self.num_factors, 2
         )
+        weights = weight_params[:, 0] + weight_deltas
+
+        if hist_weights:
+            plt.hist(weights.view(weights.numel()).data.numpy())
+            plt.show()
 
         return {
             'weights': weights.data,
@@ -356,6 +381,46 @@ def plot_reconstruction(self, block=None, filename=None, show=True,
 
         return plot
 
+    def visualize_factor_embedding(self, filename=None, show=True,
+                                   num_samples=100, hist_log_widths=True,
+                                   **kwargs):
+        hyperprior = self.generative.hyperparams.state_vardict()
+
+        factor_prior = utils.unsqueeze_and_expand_vardict({
+            'mu': hyperprior['factors']['mu'][0],
+            'sigma': hyperprior['factors']['sigma'][0]
+        }, 0, num_samples, True)
+
+        embedding = torch.normal(factor_prior['mu'], factor_prior['sigma'] * 2)
+        factor_params = self.variational.factors_embedding(embedding)
+        centers = self.variational.centers_embedding(factor_params).view(
+            -1, self.num_factors, 3
+        ).data
+        widths = torch.exp(self.variational.log_widths_embedding(factor_params))
+        widths = widths.view(-1, self.num_factors).data
+
+        plot = niplot.plot_connectome(
+            np.eye(num_samples * self.num_factors),
+            centers.view(num_samples * self.num_factors, 3).numpy(),
+            node_size=widths.view(num_samples * self.num_factors).numpy(),
+            title="$z^F$ std-dev %.8e, $x^F$ std-dev %.8e, $\\rho^F$ std-dev %.8e" %
+                  (embedding.std(0).norm(), centers.std(0).norm(),
+                   torch.log(widths).std(0).norm()),
+            **kwargs
+        )
+
+        if filename is not None:
+            plot.savefig(filename)
+        if show:
+            niplot.show()
+
+        if hist_log_widths:
+            log_widths = torch.log(widths)
+            plt.hist(log_widths.view(log_widths.numel()).numpy())
+            plt.show()
+
+        return plot, centers, torch.log(widths)
+
     def scatter_factor_embedding(self, labeler=None, filename=None, show=True,
                                  xlims=None, ylims=None, figsize=(3.75, 2.75),
                                  colormap='Set1'):
```
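The new `EPOCH_MSG` spells out the objective being logged: the posterior free energy decomposes as F = KL(q || p) - E_q[log p(y | z)], which is why `train` now accumulates `epoch_lls` and `epoch_prior_kls` alongside `epoch_free_energies`. A minimal sketch of the reporting, assuming the triple returned by `tfa.hierarchical_free_energy` satisfies the decomposition the message format states (the numbers below are hypothetical stand-ins):

```python
# Hypothetical per-epoch totals standing in for the (free_energy, ll, prior_kl)
# triple that tfa.hierarchical_free_energy now returns.
prior_kl = 1.234e3        # KL divergence of the posterior from the prior
ll = -4.567e4             # expected log-likelihood of the data
free_energy = prior_kl - ll

EPOCH_MSG = ('[Epoch %d] (%dms) Posterior free-energy %.8e = '
             'KL from prior %.8e - log-likelihood %.8e')
print(EPOCH_MSG % (1, 250, free_energy, prior_kl, ll))
```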

htfa_torch/dtfa_models.py

Lines changed: 48 additions & 44 deletions
```diff
@@ -76,16 +76,19 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
 
         params = utils.vardict({
             'factors': {
-                'mu': torch.zeros(self.num_blocks, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'mu': torch.zeros(self.num_subjects, self.embedding_dim),
+                'sigma': torch.ones(self.num_subjects, self.embedding_dim) *\
+                         tfa_models.SOURCE_LOG_WIDTH_STD_DEV,
             },
             'subject': {
                 'mu': torch.zeros(self.num_subjects, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'sigma': torch.ones(self.num_subjects, self.embedding_dim) *\
+                         tfa_models.SOURCE_WEIGHT_STD_DEV,
             },
             'task': {
                 'mu': torch.zeros(self.num_tasks, self.embedding_dim),
-                'sigma': torch.sqrt(torch.rand(self.num_blocks, self.embedding_dim)),
+                'sigma': torch.ones(self.num_tasks, self.embedding_dim) *\
+                         tfa_models.SOURCE_WEIGHT_STD_DEV,
             },
             'template': {
                 'factor_centers': {
@@ -95,7 +98,8 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
                 'factor_log_widths': {
                     'mu': hyper_means['factor_log_widths'] *\
                           torch.ones(self._num_factors),
-                    'sigma': torch.sqrt(torch.rand(self._num_factors))
+                    'sigma': torch.ones(self._num_factors) *
+                             tfa_models.SOURCE_LOG_WIDTH_STD_DEV,
                 }
             },
             'block': {
@@ -136,32 +140,28 @@ def __init__(self, num_factors, block_subjects, block_tasks, num_blocks=1,
                                          embedding_dim)
         self.factors_embedding = nn.Sequential(
             nn.Linear(self._embedding_dim, self._num_factors),
-            nn.Tanhshrink(),
-            nn.Linear(self._num_factors, self._num_factors * 8),
+            nn.Softsign(),
         )
+        self.centers_embedding = nn.Linear(self._num_factors,
+                                           self._num_factors * 3)
+        self.log_widths_embedding = nn.Linear(self._num_factors,
+                                              self._num_factors)
         self.weights_embedding = nn.Sequential(
             nn.Linear(self._embedding_dim * 2, self._num_factors),
-            nn.Tanhshrink(),
+            nn.Softsign(),
             nn.Linear(self._num_factors, self._num_factors * 2),
         )
         self.softplus = nn.Softplus()
 
         self.epsilon = nn.Parameter(torch.Tensor([tfa_models.VOXEL_NOISE]))
 
         if hyper_means is not None:
-            self.weights_embedding[-1].bias = nn.Parameter(torch.cat(
-                (hyper_means['weights'].mean(0),
-                 torch.sqrt(torch.rand(self._num_factors))),
-                dim=0
-            ))
-            self.factors_embedding[-1].bias = nn.Parameter(torch.cat(
-                (hyper_means['factor_centers'],
-                 torch.ones(self._num_factors, 3),
-                 torch.ones(self._num_factors, 1) *
-                 hyper_means['factor_log_widths'],
-                 torch.sqrt(torch.rand(self._num_factors, 1))),
-                dim=1,
-            ).view(self._num_factors * 8))
+            self.centers_embedding.bias = nn.Parameter(
+                hyper_means['factor_centers'].view(self._num_factors * 3)
+            )
+            self.log_widths_embedding.bias = nn.Parameter(
+                torch.ones(self._num_factors) * hyper_means['factor_log_widths']
+            )
 
     def forward(self, trace, times=None, blocks=None,
                 num_particles=tfa_models.NUM_PARTICLES):
@@ -187,48 +187,53 @@ def forward(self, trace, times=None, blocks=None,
             if ('z^F_%d' % subject) not in trace:
                 factors_embed = trace.normal(
                     params['factors']['mu'][:, subject, :],
-                    params['factors']['sigma'][:, subject, :],
+                    self.softplus(params['factors']['sigma'][:, subject, :]),
                     name='z^F_%d' % subject
                 )
             if ('z^P_%d' % subject) not in trace:
                 subject_embed = trace.normal(
                     params['subject']['mu'][:, subject, :],
-                    params['subject']['sigma'][:, subject, :],
+                    self.softplus(params['subject']['sigma'][:, subject, :]),
                     name='z^P_%d' % subject
                 )
             if ('z^S_%d' % task) not in trace:
-                task_embed = trace.normal(params['task']['mu'][:, task],
-                                          params['task']['sigma'][:, task],
-                                          name='z^S_%d' % task)
+                task_embed = trace.normal(
+                    params['task']['mu'][:, task],
+                    self.softplus(params['task']['sigma'][:, task]),
+                    name='z^S_%d' % task
+                )
 
             factor_params = self.factors_embedding(factors_embed)
-            factor_params = factor_params.view(-1, self._num_factors, 8)
+            centers_predictions = self.centers_embedding(factor_params).view(
+                -1, self._num_factors, 3
+            )
+            log_widths_predictions = self.log_widths_embedding(factor_params).\
+                view(-1, self._num_factors)
             weights_embed = torch.cat((subject_embed, task_embed), dim=-1)
-            weight_params = self.weights_embedding(weights_embed).view(
+            weight_predictions = self.weights_embedding(weights_embed).view(
                 -1, self._num_factors, 2
             )
 
-            trace.normal(weight_params[:, :, 0], self.epsilon[0],
-                         name='mu^W_%d' % b)
-            trace.normal(self.softplus(weight_params[:, :, 1]), self.epsilon[0],
-                         name='sigma^W_%d' % b)
+            weights_mu = trace.normal(weight_predictions[:, :, 0],
+                                      self.epsilon[0], name='mu^W_%d' % b)
+            weights_sigma = trace.normal(weight_predictions[:, :, 1],
+                                         self.epsilon[0], name='sigma^W_%d' % b)
+            weights_params = params['block']['weights']
             weights[i] = trace.normal(
-                params['block']['weights']['mu'][:, b, ts[0]:ts[1], :],
-                params['block']['weights']['sigma'][:, b, ts[0]:ts[1], :],
+                weights_params['mu'][:, b, ts[0]:ts[1], :] +
+                weights_mu.unsqueeze(1),
+                self.softplus(weights_params['sigma'][:, b, ts[0]:ts[1], :] +
+                              weights_sigma.unsqueeze(1)),
                 name='Weights%dt%d-%d' % (b, ts[0], ts[1])
            )
             factor_centers[i] = trace.normal(
-                factor_params[:, :, 0:3],
-                self.softplus(factor_params[:, :, 3:6]),
+                centers_predictions,
+                self.epsilon[0],
                 name='FactorCenters%d' % b
             )
             factor_log_widths[i] = trace.normal(
-                factor_params[:, :, 6].contiguous().view(
-                    -1, self._num_factors
-                ),
-                self.softplus(factor_params[:, :, 7].contiguous().view(
-                    -1, self._num_factors
-                )), name='FactorLogWidths%d' % b
+                log_widths_predictions,
+                self.epsilon[0], name='FactorLogWidths%d' % b
             )
 
         return weights, factor_centers, factor_log_widths
@@ -297,5 +302,4 @@ def forward(self, trace, times=None, guide=probtorch.Trace(),
         }
 
         return self.htfa_model(trace, times, guide, blocks=blocks,
-                               observations=observations,
-                               weights_params=weight_params)
+                               observations=observations)
```
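The central change in this file is architectural: the old factor head (one `Linear` producing 8 columns per factor, split by slicing) becomes a shared `Linear` + `Softsign` trunk feeding two dedicated linear heads, one for the K x 3 centers and one for the K log-widths. A self-contained sketch with hypothetical sizes, to show the shapes involved:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
embedding_dim, num_factors = 2, 4

# Shared trunk: replaces Tanhshrink + a single (num_factors * 8) output layer.
factors_embedding = nn.Sequential(
    nn.Linear(embedding_dim, num_factors),
    nn.Softsign(),
)
# Separate heads for centers (K x 3) and log-widths (K).
centers_embedding = nn.Linear(num_factors, num_factors * 3)
log_widths_embedding = nn.Linear(num_factors, num_factors)

z_f = torch.randn(10, embedding_dim)      # a batch of factor embeddings z^F
trunk = factors_embedding(z_f)
centers = centers_embedding(trunk).view(-1, num_factors, 3)
log_widths = log_widths_embedding(trunk).view(-1, num_factors)
print(centers.shape, log_widths.shape)    # [10, 4, 3] and [10, 4]
```

Keeping the heads separate also lets `hyper_means` seed each bias directly (`centers_embedding.bias`, `log_widths_embedding.bias`) instead of packing the initial centers, widths, and noise terms into one concatenated 8-column bias.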

htfa_torch/htfa.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -130,13 +130,13 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
                     dec(p, times=trs, guide=q, observations=activations,
                         blocks=block_batch)
 
-                    def block_rv_weight(node):
+                    def block_rv_weight(node, prior=True):
                         result = 1.0
                         if measure_occurrences:
                             rv_occurrences[node] += 1
                         result /= rv_occurrences[node]
                         return result
-                    free_energy = tfa.hierarchical_free_energy(
+                    free_energy, _, _ = tfa.hierarchical_free_energy(
                         q, p,
                         rv_weight=block_rv_weight,
                         num_particles=num_particles
```

htfa_torch/htfa_models.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -179,7 +179,7 @@ def __init__(self, brain_center, brain_center_std_dev, num_blocks,
             'factor_center_noise': torch.ones(self._num_blocks),
             'factor_log_width_noise': torch.ones(self._num_blocks),
             'weights': {
-                'mu': torch.randn(self._num_blocks, self._num_factors),
+                'mu': torch.zeros(self._num_blocks, self._num_factors),
                 'sigma': tfa_models.SOURCE_WEIGHT_STD_DEV *\
                          torch.ones(self._num_blocks, self._num_factors)
             },
```
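Initializing the per-block weight means at zero instead of `torch.randn` gives the guide a neutral starting point, and it mirrors the residual parameterization introduced in dtfa_models.py above, where block weight parameters act as deltas around the network's predictions and a softplus keeps the scale positive. A sketch of that combination, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: 5 time points, 4 factors.
num_times, num_factors = 5, 4
block_mu = torch.zeros(num_times, num_factors)   # per-block delta, zero-initialized
block_sigma = torch.zeros(num_times, num_factors)
weights_mu = torch.randn(num_factors)            # stand-in for the mu^W prediction
weights_sigma = torch.randn(num_factors)         # stand-in for the sigma^W prediction

# weights ~ Normal(block_mu + mu^W, softplus(block_sigma + sigma^W))
weights = torch.normal(block_mu + weights_mu,
                       F.softplus(block_sigma + weights_sigma))
print(weights.shape)                             # torch.Size([5, 4])
```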
