
Commit 1c69eda

Merge pull request #23 from neu-spiral/develop
Develop
2 parents e518254 + 243a9a0 commit 1c69eda

5 files changed: +1606 −59 lines changed


htfa_torch/dtfa.py

Lines changed: 304 additions & 52 deletions
@@ -1,4 +1,4 @@
-"""Sketch of Deep TFA architecture"""
+"""Perform deep topographic factor analysis on fMRI data"""
 
 __author__ = ('Jan-Willem van de Meent',
               'Eli Sennesh',
@@ -7,59 +7,311 @@
              'e.sennesh@northeastern.edu',
              'khan.zu@husky.neu.edu')
 
-from collections import defaultdict
+import logging
+import os
+import pickle
+import time
+
+try:
+    if __name__ == '__main__':
+        import matplotlib
+        matplotlib.use('TkAgg')
+finally:
+    import matplotlib.pyplot as plt
+import nilearn.image
+import nilearn.plotting as niplot
+import numpy as np
+import scipy.io as sio
 import torch
+import torch.distributions as dists
+from torch.autograd import Variable
+import torch.nn as nn
+from torch.nn import Parameter
+import torch.utils.data
+
 import probtorch
 
-# NOTE: I am writing this as a model relative to PyTorch master,
-# which no longer requires explicit wrapping in Variable(...)
-
-class DeepTFA(torch.nn.Module):
-    def __init__(self, N=50, T=200, D=2, E=2, K=24):
-        # generative model
-        self.p_z_w_mean = torch.zeros(E)
-        self.p_z_w_std = torch.ones(E)
-        self.w = torch.nn.Sequential(
-            torch.nn.Linear(E, K/2),
-            torch.nn.ReLU(),
-            torch.nn.Linear(K/2, K))
-        self.q_z_f_mean = torch.zeros(D)
-        self.q_z_f_std = torch.ones(D)
-        self.h_f = torch.nn.Sequential(
-            torch.nn.Linear(D, K/2),
-            torch.nn.ReLU())
-        self.x_f = torch.nn.Linear(K/2, 3*K)
-        self.log_rho_f = torch.nn.Linear(K/2, K)
-        self.sigma_y = Parameter(1.0)
-        # variational parameters
-        self.q_z_f_mean = Parameter(torch.zeros(N, D))
-        self.q_z_f_std = Parameter(torch.ones(N, D))
-        self.q_z_w_mean = Parameter(torch.zeros(N, T, E))
-        self.q_z_w_std = Parameter(torch.ones(N, T, E))
-
-    def forward(self, x, y, n, t):
-        p = probtorch.Trace()
+from . import dtfa_models
+from . import tfa
+from . import tfa_models
+from . import utils
+
+class DeepTFA:
+    """Overall container for a run of Deep TFA"""
+    def __init__(self, data_files, mask, num_factors=tfa_models.NUM_FACTORS,
+                 embedding_dim=2, tasks=[]):
+        self.num_factors = num_factors
+        self.num_subjects = len(data_files)
+        self.mask = mask
+        datasets = [utils.load_dataset(data_file, mask=mask)
+                    for data_file in data_files]
+        self.voxel_activations = [dataset[0] for dataset in datasets]
+        self._images = [dataset[1] for dataset in datasets]
+        self.voxel_locations = [dataset[2] for dataset in datasets]
+        self._names = [dataset[3] for dataset in datasets]
+        self._templates = [dataset[4] for dataset in datasets]
+        self._tasks = tasks
+
+        # Pull out relevant dimensions: the number of time instants and the
+        # number of voxels in each timewise "slice"
+        self.num_times = [acts.shape[0] for acts in self.voxel_activations]
+        self.num_voxels = [acts.shape[1] for acts in self.voxel_activations]
+
+        self.generative = dtfa_models.DeepTFAModel(
+            self.voxel_locations, self.voxel_activations, self.num_factors,
+            self.num_subjects, self.num_times, embedding_dim
+        )
+        self.variational = dtfa_models.DeepTFAGuide(self.num_subjects,
+                                                    self.num_times,
+                                                    embedding_dim)
+
+    def sample(self, posterior_predictive=False, num_particles=1):
         q = probtorch.Trace()
-        z_w = q.normal(self.q_z_w_mean[n, t],
-                       self.q_z_w_std[n, t],
-                       name='z_w')
-        z_w = p.normal(self.p_z_w_mean,
-                       self.p_z_w_std,
-                       value=q['z_w'],
-                       name='z_w')
-        w = self.w(z_w)
-        z_f = q.normal(self.q_z_f_mean[n],
-                       self.q_z_f_std[n],
-                       name='z_f')
-        z_f = p.normal(self.z_f_mean,
-                       self.z_f_std,
-                       value=q['z_f']
-                       name='z_f')
-        x_f = self.x_f(z_f)
-        rho_f = torch.exp(self.log_rho_f(z_f))
-        f = rbf(x, x_f, rho_f)
-        y = p.normal(w * f,
-                     self.sigma_y,
-                     value='y',
-                     name='y')
+        if posterior_predictive:
+            self.variational(q, self.generative.embedding,
+                             num_particles=num_particles)
+        p = probtorch.Trace()
+        self.generative(p, guide=q,
+                        observations=[q for s in range(self.num_subjects)])
         return p, q
+
+    def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
+              log_level=logging.WARNING, num_particles=tfa_models.NUM_PARTICLES,
+              batch_size=64, use_cuda=True):
+        """Optimize the variational guide to reflect the data for `num_steps`"""
+        logging.basicConfig(format='%(asctime)s %(message)s',
+                            datefmt='%m/%d/%Y %H:%M:%S',
+                            level=log_level)
+        activations = torch.Tensor(self.num_times[0], self.num_voxels[0],
+                                   len(self.voxel_activations))
+        for s in range(self.num_subjects):
+            activations[:, :, s] = self.voxel_activations[s]
+        activations_loader = torch.utils.data.DataLoader(
+            torch.utils.data.TensorDataset(
+                activations,
+                torch.zeros(activations.shape[0])
+            ),
+            batch_size=batch_size
+        )
+        if tfa.CUDA and use_cuda:
+            variational = torch.nn.DataParallel(self.variational)
+            generative = torch.nn.DataParallel(self.generative)
+            variational.cuda()
+            generative.cuda()
+        else:
+            variational = self.variational
+            generative = self.generative
+
+        optimizer = torch.optim.Adam(list(variational.parameters()) +
+                                     list(generative.parameters()),
+                                     lr=learning_rate)
+        variational.train()
+        generative.train()
+
+        free_energies = list(range(num_steps))
+        lls = list(range(num_steps))
+
+        for epoch in range(num_steps):
+            start = time.time()
+            epoch_free_energies = list(range(len(activations_loader)))
+            epoch_lls = list(range(len(activations_loader)))
+
+            for (batch, (data, _)) in enumerate(activations_loader):
+                activations = [{'Y': Variable(data[:, :, s])}
+                               for s in range(self.num_subjects)]
+                trs = (batch * batch_size, None)
+                trs = (trs[0], trs[0] + activations[0]['Y'].shape[0])
+
+                optimizer.zero_grad()
+                q = probtorch.Trace()
+                variational(q, self.generative.embedding, times=trs,
+                            num_particles=num_particles)
+                p = probtorch.Trace()
+                generative(p, times=trs, guide=q, observations=activations)
+
+                epoch_free_energies[batch] = \
+                    tfa.free_energy(q, p, num_particles=num_particles)
+                epoch_lls[batch] = \
+                    tfa.log_likelihood(q, p, num_particles=num_particles)
+                epoch_free_energies[batch].backward()
+                optimizer.step()
+                if tfa.CUDA and use_cuda:
+                    epoch_free_energies[batch] = epoch_free_energies[batch].cpu().data.numpy()
+                    epoch_lls[batch] = epoch_lls[batch].cpu().data.numpy()
+
+            free_energies[epoch] = np.array(epoch_free_energies).sum(0)
+            free_energies[epoch] = free_energies[epoch].sum(0)
+            lls[epoch] = np.array(epoch_lls).sum(0)
+            lls[epoch] = lls[epoch].sum(0)
+
+            end = time.time()
+            msg = tfa.EPOCH_MSG % (epoch + 1, (end - start) * 1000, free_energies[epoch])
+            logging.info(msg)
+
+        if tfa.CUDA and use_cuda:
+            variational.cpu()
+            generative.cpu()
+
+        return np.vstack([free_energies, lls])
+
+    def results(self, subject):
+        hyperparams = self.variational.hyperparams.state_vardict()
+
+        z_f = hyperparams['embedding']['factors']['mu'][subject]
+        z_f_embedded = self.generative.embedding.embedder(z_f)
+
+        factors = self.generative.embedding.factors_generator(z_f_embedded)
+        factors_shape = (self.num_factors, 4)
+        if len(factors.shape) > 1:
+            factors_shape = (-1,) + factors_shape
+        factors = factors.view(*factors_shape)
+        if len(factors.shape) > 2:
+            centers = factors[:, :, 0:3]
+            log_widths = factors[:, :, 3]
+        else:
+            centers = factors[:, 0:3]
+            log_widths = factors[:, 3]
+
+        z_w = hyperparams['embedding']['weights']['mu'][subject]
+        weights = self.generative.embedding.weights_generator(z_w)
+
+        return {
+            'weights': weights[0:self.voxel_activations[subject].shape[0], :],
+            'factors': tfa_models.radial_basis(self.voxel_locations[subject],
+                                               centers.data, log_widths.data),
+            'factor_centers': centers.data,
+            'factor_log_widths': log_widths.data,
+        }
+
+    def embeddings(self):
+        hyperparams = self.variational.hyperparams.state_vardict()
+
+        return {
+            'factors': hyperparams['embedding']['factors']['mu'],
+            'weights': hyperparams['embedding']['weights']['mu'],
+        }
+
+    def plot_factor_centers(self, subject, filename=None, show=True,
+                            trace=None):
+        hyperparams = self.variational.hyperparams.state_vardict()
+        z_f_std_dev = hyperparams['embedding']['factors']['sigma'][subject]
+
+        if trace:
+            z_f = trace['z_f%d' % subject].value
+            if len(z_f.shape) > 1:
+                if z_f.shape[0] > 1:
+                    z_f_std_dev = z_f.std(0)
+                z_f = z_f.mean(0)
+        else:
+            z_f = hyperparams['embedding']['factors']['mu'][subject]
+
+        z_f_embedded = self.generative.embedding.embedder(z_f)
+
+        factors = self.generative.embedding.factors_generator(z_f_embedded)
+        factors_shape = (self.num_factors, 4)
+        if len(factors.shape) > 1:
+            factors_shape = (-1,) + factors_shape
+        factors = factors.view(*factors_shape)
+        if len(factors.shape) > 2:
+            factor_centers = factors[:, :, 0:3]
+            factor_log_widths = factors[:, :, 3]
+        else:
+            factor_centers = factors[:, 0:3]
+            factor_log_widths = factors[:, 3]
+
+        factor_uncertainties = z_f_std_dev.norm().expand(self.num_factors, 1)
+
+        plot = niplot.plot_connectome(
+            np.eye(self.num_factors),
+            factor_centers.data.numpy(),
+            node_color=utils.uncertainty_palette(factor_uncertainties.data),
+            node_size=np.exp(factor_log_widths.data.numpy() - np.log(2))
+        )
+
+        if filename is not None:
+            plot.savefig(filename)
+        if show:
+            niplot.show()
+
+        return plot
+
+    def plot_original_brain(self, subject=None, filename=None, show=True,
+                            plot_abs=False, t=0):
+        if subject is None:
+            subject = np.random.choice(self.num_subjects, 1)[0]
+        image = nilearn.image.index_img(self._images[subject], t)
+        plot = niplot.plot_glass_brain(image, plot_abs=plot_abs)
+
+        if filename is not None:
+            plot.savefig(filename)
+        if show:
+            niplot.show()
+
+        return plot
+
+    def plot_reconstruction(self, subject=None, filename=None, show=True,
+                            plot_abs=False, t=0):
+        if subject is None:
+            subject = np.random.choice(self.num_subjects, 1)[0]
+
+        results = self.results(subject)
+
+        reconstruction = results['weights'].data @ results['factors']
+
+        image = utils.cmu2nii(reconstruction.numpy(),
+                              self.voxel_locations[subject].numpy(),
+                              self._templates[subject])
+        image_slice = nilearn.image.index_img(image, t)
+        plot = niplot.plot_glass_brain(image_slice, plot_abs=plot_abs)
+
+        logging.info(
+            'Reconstruction Error (Frobenius Norm): %.8e',
+            np.linalg.norm(
+                (reconstruction - self.voxel_activations[subject]).numpy()
+            )
+        )
+
+        if filename is not None:
+            plot.savefig(filename)
+        if show:
+            niplot.show()
+
+        return plot
+
+    def scatter_factor_embedding(self, filename=None, show=True):
+        hyperparams = self.variational.hyperparams.state_vardict()
+        z_f = hyperparams['embedding']['factors']['mu'].data
+
+        tasks = self._tasks
+        if tasks is None or len(tasks) == 0:
+            tasks = list(range(self.num_subjects))
+        palette = dict(zip(tasks, utils.compose_palette(len(tasks))))
+        subject_colors = np.array([palette[task] for task in tasks])
+
+        plt.scatter(x=z_f[:, 0], y=z_f[:, 1], c=subject_colors)
+        utils.palette_legend(list(palette.keys()), list(palette.values()))
+
+        if filename is not None:
+            plt.savefig(filename)
+        if show:
+            plt.show()
+
+    def scatter_weights_embedding(self, t=0, filename=None, show=True):
+        hyperparams = self.variational.hyperparams.state_vardict()
+        z_f = hyperparams['embedding']['weights']['mu'][:, t, :].data
+
+        tasks = self._tasks
+        if tasks is None or len(tasks) == 0:
+            tasks = list(range(self.num_subjects))
+        palette = dict(zip(tasks, utils.compose_palette(len(tasks))))
+        subject_colors = np.array([palette[task] for task in tasks])
+
+        plt.scatter(x=z_f[:, 0], y=z_f[:, 1], c=subject_colors)
+        utils.palette_legend(list(palette.keys()), list(palette.values()))
+
+        if filename is not None:
+            plt.savefig(filename)
+        if show:
+            plt.show()
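For anyone trying out this branch, below is a minimal usage sketch of the DeepTFA interface the diff above introduces. The file paths, task labels, and numeric settings are hypothetical placeholders, not part of the commit; only the class and method signatures come from the diff.

    from htfa_torch.dtfa import DeepTFA

    # Hypothetical inputs: per-subject fMRI data files and a brain mask.
    data_files = ['subject01.nii', 'subject02.nii']
    dtfa = DeepTFA(data_files, mask='mask.nii', num_factors=50,
                   embedding_dim=2, tasks=['motor', 'visual'])

    # Optimize the variational guide; per train(), this returns a
    # 2 x num_steps array stacking per-epoch free energies and
    # log-likelihoods (np.vstack([free_energies, lls])).
    losses = dtfa.train(num_steps=100, batch_size=64, use_cuda=False)

    # Pull out learned factors/weights for subject 0 and plot diagnostics.
    results = dtfa.results(0)
    dtfa.plot_factor_centers(0, filename='centers.png', show=False)
    dtfa.plot_reconstruction(0, filename='recon.png', show=False)
    dtfa.scatter_factor_embedding(filename='embedding.png', show=False)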
