1 | | -"""Sketch of Deep TFA architecture""" |
| 1 | +"""Perform deep topographic factor analysis on fMRI data""" |
2 | 2 |
|
3 | 3 | __author__ = ('Jan-Willem van de Meent', |
4 | 4 | 'Eli Sennesh', |
|
7 | 7 | 'e.sennesh@northeastern.edu', |
8 | 8 | 'khan.zu@husky.neu.edu') |
9 | 9 |
|
10 | | -from collections import defaultdict |
| 10 | +import logging |
| 11 | +import os |
| 12 | +import pickle |
| 13 | +import time |
| 14 | + |
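| | +# Select the TkAgg backend only when this module runs as a script; the |
| | +# finally-clause import of pyplot executes in either case. |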
| 15 | +try: |
| 16 | + if __name__ == '__main__': |
| 17 | + import matplotlib |
| 18 | + matplotlib.use('TkAgg') |
| 19 | +finally: |
| 20 | + import matplotlib.pyplot as plt |
| 21 | +import nilearn.image |
| 22 | +import nilearn.plotting as niplot |
| 23 | +import numpy as np |
| 24 | +import scipy.io as sio |
11 | 25 | import torch |
| 26 | +import torch.distributions as dists |
| 27 | +from torch.autograd import Variable |
| 28 | +import torch.nn as nn |
| 29 | +from torch.nn import Parameter |
| 30 | +import torch.utils.data |
| 31 | + |
12 | 32 | import probtorch |
13 | 33 |
|
14 | | -# NOTE: I am writing this as a model relative to PyTorch master, |
15 | | -# which no longer requires explicit wrapping in Variable(...) |
16 | | - |
17 | | -class DeepTFA(torch.nn.Module): |
18 | | - def __init__(self, N=50, T=200, D=2, E=2, K=24): |
19 | | - # generative model |
20 | | - self.p_z_w_mean = torch.zeros(E) |
21 | | - self.p_z_w_std = torch.ones(E) |
22 | | - self.w = torch.nn.Sequential( |
23 | | - torch.nn.Linear(E, K/2), |
24 | | - torch.nn.ReLU(), |
25 | | - torch.nn.Linear(K/2, K)) |
26 | | - self.q_z_f_mean = torch.zeros(D) |
27 | | - self.q_z_f_std = torch.ones(D) |
28 | | - self.h_f = torch.nn.Sequential( |
29 | | - torch.nn.Linear(D, K/2), |
30 | | - torch.nn.ReLU()) |
31 | | - self.x_f = torch.nn.Linear(K/2, 3*K) |
32 | | - self.log_rho_f = torch.nn.Linear(K/2, K) |
33 | | - self.sigma_y = Parameter(1.0) |
34 | | - # variational parameters |
35 | | - self.q_z_f_mean = Parameter(torch.zeros(N, D)) |
36 | | - self.q_z_f_std = Parameter(torch.ones(N, D)) |
37 | | - self.q_z_w_mean = Parameter(torch.zeros(N, T, E)) |
38 | | - self.q_z_w_std = Parameter(torch.ones(N, T, E)) |
39 | | - |
40 | | - def forward(self, x, y, n, t): |
41 | | - p = probtorch.Trace() |
| 34 | +from . import dtfa_models |
| 35 | +from . import tfa |
| 36 | +from . import tfa_models |
| 37 | +from . import utils |
| 38 | + |
| 39 | +class DeepTFA: |
| 40 | + """Overall container for a run of Deep TFA""" |
| 41 | + def __init__(self, data_files, mask, num_factors=tfa_models.NUM_FACTORS, |
| 42 | +                 embedding_dim=2, tasks=None): |
| 43 | + self.num_factors = num_factors |
| 44 | + self.num_subjects = len(data_files) |
| 45 | + self.mask = mask |
| 46 | + datasets = [utils.load_dataset(data_file, mask=mask) |
| 47 | + for data_file in data_files] |
| 48 | + self.voxel_activations = [dataset[0] for dataset in datasets] |
| 49 | + self._images = [dataset[1] for dataset in datasets] |
| 50 | + self.voxel_locations = [dataset[2] for dataset in datasets] |
| 51 | + self._names = [dataset[3] for dataset in datasets] |
| 52 | + self._templates = [dataset[4] for dataset in datasets] |
| 53 | + self._tasks = tasks |
| 54 | + |
| 55 | + # Pull out relevant dimensions: the number of time instants and the |
| 56 | + # number of voxels in each timewise "slice" |
| 57 | + self.num_times = [acts.shape[0] for acts in self.voxel_activations] |
| 58 | + self.num_voxels = [acts.shape[1] for acts in self.voxel_activations] |
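| | +        # e.g. a run of 300 TRs over 60,000 in-mask voxels gives |
| | +        # acts.shape == (300, 60000) |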
| 59 | + |
| 60 | + self.generative = dtfa_models.DeepTFAModel( |
| 61 | + self.voxel_locations, self.voxel_activations, self.num_factors, |
| 62 | + self.num_subjects, self.num_times, embedding_dim |
| 63 | + ) |
| 64 | + self.variational = dtfa_models.DeepTFAGuide(self.num_subjects, |
| 65 | + self.num_times, |
| 66 | + embedding_dim) |
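| | +        # The generative model defines the deep TFA likelihood; the guide is |
| | +        # the variational family whose parameters train() optimizes. |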
| 67 | + |
| 68 | + def sample(self, posterior_predictive=False, num_particles=1): |
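| | +        # With posterior_predictive=True, the guide trace q is sampled first |
| | +        # and the generative trace p is conditioned on it; otherwise q stays |
| | +        # empty and p samples from its prior. |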
42 | 69 | q = probtorch.Trace() |
43 | | - z_w = q.normal(self.q_z_w_mean[n, t], |
44 | | - self.q_z_w_std[n, t], |
45 | | - name='z_w') |
46 | | - z_w = p.normal(self.p_z_w_mean, |
47 | | - self.p_z_w_std, |
48 | | - value=q['z_w'], |
49 | | - name='z_w') |
50 | | - w = self.w(z_w) |
51 | | - z_f = q.normal(self.q_z_f_mean[n], |
52 | | - self.q_z_f_std[n], |
53 | | - name='z_f') |
54 | | - z_f = p.normal(self.z_f_mean, |
55 | | - self.z_f_std, |
56 | | - value=q['z_f'] |
57 | | - name='z_f') |
58 | | - x_f = self.x_f(z_f) |
59 | | - rho_f = torch.exp(self.log_rho_f(z_f)) |
60 | | - f = rbf(x, x_f, rho_f) |
61 | | - y = p.normal(w * f, |
62 | | - self.sigma_y, |
63 | | - value='y', |
64 | | - name='y') |
| 70 | + if posterior_predictive: |
| 71 | + self.variational(q, self.generative.embedding, |
| 72 | + num_particles=num_particles) |
| 73 | + p = probtorch.Trace() |
| 74 | + self.generative(p, guide=q, |
| 75 | + observations=[q for s in range(self.num_subjects)]) |
65 | 76 | return p, q |
| 77 | + |
| 78 | + def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE, |
| 79 | + log_level=logging.WARNING, num_particles=tfa_models.NUM_PARTICLES, |
| 80 | + batch_size=64, use_cuda=True): |
| 81 | +        """Optimize the variational guide to reflect the data for `num_steps` epochs""" |
| 82 | + logging.basicConfig(format='%(asctime)s %(message)s', |
| 83 | + datefmt='%m/%d/%Y %H:%M:%S', |
| 84 | + level=log_level) |
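| | +        # Stack every subject's activations into one (times, voxels, subjects) |
| | +        # tensor; this assumes all subjects share subject 0's scan length and |
| | +        # voxel count. |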
| 85 | + activations = torch.Tensor(self.num_times[0], self.num_voxels[0], |
| 86 | + len(self.voxel_activations)) |
| 87 | + for s in range(self.num_subjects): |
| 88 | + activations[:, :, s] = self.voxel_activations[s] |
| 89 | + activations_loader = torch.utils.data.DataLoader( |
| 90 | + torch.utils.data.TensorDataset( |
| 91 | + activations, |
| 92 | + torch.zeros(activations.shape[0]) |
| 93 | + ), |
| 94 | + batch_size=batch_size |
| 95 | + ) |
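| | +        # Minibatches range over time points; the time-range bookkeeping in |
| | +        # trs below relies on the loader's default shuffle=False ordering. |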
| 96 | + if tfa.CUDA and use_cuda: |
| 97 | + variational = torch.nn.DataParallel(self.variational) |
| 98 | + generative = torch.nn.DataParallel(self.generative) |
| 99 | + variational.cuda() |
| 100 | + generative.cuda() |
| 101 | + else: |
| 102 | + variational = self.variational |
| 103 | + generative = self.generative |
| 104 | + |
| 105 | +        optimizer = torch.optim.Adam(list(variational.parameters()) + |
| 106 | + list(generative.parameters()), |
| 107 | + lr=learning_rate) |
| 108 | + variational.train() |
| 109 | + generative.train() |
| 110 | + |
| 111 | + free_energies = list(range(num_steps)) |
| 112 | + lls = list(range(num_steps)) |
| 113 | + |
| 114 | + for epoch in range(num_steps): |
| 115 | + start = time.time() |
| 116 | + epoch_free_energies = list(range(len(activations_loader))) |
| 117 | + epoch_lls = list(range(len(activations_loader))) |
| 118 | + |
| 119 | + for (batch, (data, _)) in enumerate(activations_loader): |
| 120 | + activations = [{'Y': Variable(data[:, :, s])} |
| 121 | + for s in range(self.num_subjects)] |
| 122 | +                start_time = batch * batch_size |
| 123 | +                trs = (start_time, start_time + activations[0]['Y'].shape[0]) |
| 124 | + |
| 125 | + optimizer.zero_grad() |
| 126 | + q = probtorch.Trace() |
| 127 | + variational(q, self.generative.embedding, times=trs, |
| 128 | + num_particles=num_particles) |
| 129 | + p = probtorch.Trace() |
| 130 | + generative(p, times=trs, guide=q, observations=activations) |
| 131 | + |
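| | +                # The loss is the variational free energy, i.e. the negative |
| | +                # evidence lower bound estimated with num_particles samples. |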
| 132 | + epoch_free_energies[batch] =\ |
| 133 | + tfa.free_energy(q, p, num_particles=num_particles) |
| 134 | + epoch_lls[batch] =\ |
| 135 | + tfa.log_likelihood(q, p, num_particles=num_particles) |
| 136 | + epoch_free_energies[batch].backward() |
| 137 | + optimizer.step() |
| 138 | + if tfa.CUDA and use_cuda: |
| 139 | + epoch_free_energies[batch] = epoch_free_energies[batch].cpu().data.numpy() |
| 140 | + epoch_lls[batch] = epoch_lls[batch].cpu().data.numpy() |
| 141 | + |
| 144 | + free_energies[epoch] = np.array(epoch_free_energies).sum(0) |
| 145 | + free_energies[epoch] = free_energies[epoch].sum(0) |
| 146 | + lls[epoch] = np.array(epoch_lls).sum(0) |
| 147 | + lls[epoch] = lls[epoch].sum(0) |
| 148 | + |
| 149 | + end = time.time() |
| 150 | + msg = tfa.EPOCH_MSG % (epoch + 1, (end - start) * 1000, free_energies[epoch]) |
| 151 | + logging.info(msg) |
| 152 | + |
| 153 | + if tfa.CUDA and use_cuda: |
| 154 | + variational.cpu() |
| 155 | + generative.cpu() |
| 156 | + |
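| | +        # Row 0 holds the per-epoch free energies, row 1 the log-likelihoods. |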
| 157 | + return np.vstack([free_energies, lls]) |
| 158 | + |
| 159 | + def results(self, subject): |
| 160 | + hyperparams = self.variational.hyperparams.state_vardict() |
| 161 | + |
| 162 | + z_f = hyperparams['embedding']['factors']['mu'][subject] |
| 163 | + z_f_embedded = self.generative.embedding.embedder(z_f) |
| 164 | + |
| 165 | + factors = self.generative.embedding.factors_generator(z_f_embedded) |
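| | +        # Each factor is parameterized by four numbers: an (x, y, z) center |
| | +        # and a log-width, hence the trailing dimension of 4. |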
| 166 | + factors_shape = (self.num_factors, 4) |
| 167 | + if len(factors.shape) > 1: |
| 168 | + factors_shape = (-1,) + factors_shape |
| 169 | + factors = factors.view(*factors_shape) |
| 170 | + if len(factors.shape) > 2: |
| 171 | + centers = factors[:, :, 0:3] |
| 172 | + log_widths = factors[:, :, 3] |
| 173 | + else: |
| 174 | + centers = factors[:, 0:3] |
| 175 | + log_widths = factors[:, 3] |
| 176 | + |
| 177 | + z_w = hyperparams['embedding']['weights']['mu'][subject] |
| 178 | + weights = self.generative.embedding.weights_generator(z_w) |
| 179 | + |
| 180 | + return { |
| 181 | + 'weights': weights[0:self.voxel_activations[subject].shape[0], :], |
| 182 | + 'factors': tfa_models.radial_basis(self.voxel_locations[subject], |
| 183 | + centers.data, log_widths.data), |
| 184 | + 'factor_centers': centers.data, |
| 185 | + 'factor_log_widths': log_widths.data, |
| 186 | + } |
| 187 | + |
| 188 | + def embeddings(self): |
| 189 | + hyperparams = self.variational.hyperparams.state_vardict() |
| 190 | + |
| 191 | + return { |
| 192 | + 'factors': hyperparams['embedding']['factors']['mu'], |
| 193 | + 'weights': hyperparams['embedding']['weights']['mu'], |
| 194 | + } |
| 195 | + |
| 196 | + def plot_factor_centers(self, subject, filename=None, show=True, |
| 197 | + trace=None): |
| 198 | + hyperparams = self.variational.hyperparams.state_vardict() |
| 199 | + z_f_std_dev = hyperparams['embedding']['factors']['sigma'][subject] |
| 200 | + |
| 201 | + if trace: |
| 202 | + z_f = trace['z_f%d' % subject].value |
| 203 | + if len(z_f.shape) > 1: |
| 204 | + if z_f.shape[0] > 1: |
| 205 | + z_f_std_dev = z_f.std(0) |
| 206 | + z_f = z_f.mean(0) |
| 207 | + else: |
| 208 | + z_f = hyperparams['embedding']['factors']['mu'][subject] |
| 209 | + |
| 210 | + z_f_embedded = self.generative.embedding.embedder(z_f) |
| 211 | + |
| 212 | + factors = self.generative.embedding.factors_generator(z_f_embedded) |
| 213 | + factors_shape = (self.num_factors, 4) |
| 214 | + if len(factors.shape) > 1: |
| 215 | + factors_shape = (-1,) + factors_shape |
| 216 | + factors = factors.view(*factors_shape) |
| 217 | + if len(factors.shape) > 2: |
| 218 | + factor_centers = factors[:, :, 0:3] |
| 219 | + factor_log_widths = factors[:, :, 3] |
| 220 | + else: |
| 221 | + factor_centers = factors[:, 0:3] |
| 222 | + factor_log_widths = factors[:, 3] |
| 223 | + |
| 224 | + factor_uncertainties = z_f_std_dev.norm().expand(self.num_factors, 1) |
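| | +        # Node sizes are half the factor widths: exp(log_width - log(2)). |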
| 225 | + |
| 226 | + plot = niplot.plot_connectome( |
| 227 | + np.eye(self.num_factors), |
| 228 | + factor_centers.data.numpy(), |
| 229 | + node_color=utils.uncertainty_palette(factor_uncertainties.data), |
| 230 | + node_size=np.exp(factor_log_widths.data.numpy() - np.log(2)) |
| 231 | + ) |
| 232 | + |
| 233 | + if filename is not None: |
| 234 | + plot.savefig(filename) |
| 235 | + if show: |
| 236 | + niplot.show() |
| 237 | + |
| 238 | + return plot |
| 239 | + |
| 240 | + def plot_original_brain(self, subject=None, filename=None, show=True, |
| 241 | + plot_abs=False, t=0): |
| 242 | + if subject is None: |
| 243 | + subject = np.random.choice(self.num_subjects, 1)[0] |
| 244 | + image = nilearn.image.index_img(self._images[subject], t) |
| 245 | + plot = niplot.plot_glass_brain(image, plot_abs=plot_abs) |
| 246 | + |
| 247 | + if filename is not None: |
| 248 | + plot.savefig(filename) |
| 249 | + if show: |
| 250 | + niplot.show() |
| 251 | + |
| 252 | + return plot |
| 253 | + |
| 254 | + def plot_reconstruction(self, subject=None, filename=None, show=True, |
| 255 | + plot_abs=False, t=0): |
| 256 | + if subject is None: |
| 257 | + subject = np.random.choice(self.num_subjects, 1)[0] |
| 258 | + |
| 259 | + results = self.results(subject) |
| 260 | + |
| 261 | + reconstruction = results['weights'].data @ results['factors'] |
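| | +        # TFA reconstruction: the (times x factors) weight matrix times the |
| | +        # (factors x voxels) factor matrix approximates the activations. |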
| 262 | + |
| 263 | + image = utils.cmu2nii(reconstruction.numpy(), |
| 264 | + self.voxel_locations[subject].numpy(), |
| 265 | + self._templates[subject]) |
| 266 | + image_slice = nilearn.image.index_img(image, t) |
| 267 | + plot = niplot.plot_glass_brain(image_slice, plot_abs=plot_abs) |
| 268 | + |
| 269 | + logging.info( |
| 270 | + 'Reconstruction Error (Frobenius Norm): %.8e', |
| 271 | + np.linalg.norm( |
| 272 | + (reconstruction - self.voxel_activations[subject]).numpy() |
| 273 | + ) |
| 274 | + ) |
| 275 | + |
| 276 | + if filename is not None: |
| 277 | + plot.savefig(filename) |
| 278 | + if show: |
| 279 | + niplot.show() |
| 280 | + |
| 281 | + return plot |
| 282 | + |
| 283 | + def scatter_factor_embedding(self, filename=None, show=True): |
| 284 | + hyperparams = self.variational.hyperparams.state_vardict() |
| 285 | + z_f = hyperparams['embedding']['factors']['mu'].data |
| 286 | + |
| 287 | + tasks = self._tasks |
| 288 | + if tasks is None or len(tasks) == 0: |
| 289 | + tasks = list(range(self.num_subjects)) |
| 290 | + palette = dict(zip(tasks, utils.compose_palette(len(tasks)))) |
| 291 | + subject_colors = np.array([palette[task] for task in tasks]) |
| 292 | + |
| 293 | + plt.scatter(x=z_f[:, 0], y=z_f[:, 1], c=subject_colors) |
| 294 | + utils.palette_legend(list(palette.keys()), list(palette.values())) |
| 295 | + |
| 296 | + if filename is not None: |
| 297 | + plt.savefig(filename) |
| 298 | + if show: |
| 299 | + plt.show() |
| 300 | + |
| 301 | + def scatter_weights_embedding(self, t=0, filename=None, show=True): |
| 302 | + hyperparams = self.variational.hyperparams.state_vardict() |
| 303 | +        z_w = hyperparams['embedding']['weights']['mu'][:, t, :].data |
| 304 | + |
| 305 | + tasks = self._tasks |
| 306 | + if tasks is None or len(tasks) == 0: |
| 307 | + tasks = list(range(self.num_subjects)) |
| 308 | + palette = dict(zip(tasks, utils.compose_palette(len(tasks)))) |
| 309 | + subject_colors = np.array([palette[task] for task in tasks]) |
| 310 | + |
| 311 | +        plt.scatter(x=z_w[:, 0], y=z_w[:, 1], c=subject_colors) |
| 312 | + utils.palette_legend(list(palette.keys()), list(palette.values())) |
| 313 | + |
| 314 | + if filename is not None: |
| 315 | + plt.savefig(filename) |
| 316 | + if show: |
| 317 | + plt.show() |
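
A minimal usage sketch for the class above (the module path, data files, and mask are hypothetical; the constructor and method signatures follow the diff):

    from dtfa import DeepTFA  # hypothetical import path for this module

    # Fit a 50-factor deep TFA model to two subjects' runs (hypothetical files)
    dtfa = DeepTFA(['s1.nii', 's2.nii'], mask='mask.nii',
                   num_factors=50, embedding_dim=2)

    # train() returns per-epoch free energies and log-likelihoods stacked
    # into a 2 x num_steps array
    losses = dtfa.train(num_steps=100, batch_size=64, use_cuda=False)

    # Inspect and plot the fit for subject 0
    results = dtfa.results(0)
    dtfa.plot_factor_centers(0, filename='centers.png', show=False)
    dtfa.plot_reconstruction(0, filename='recon.png', show=False)
    dtfa.scatter_factor_embedding(filename='embedding.png', show=False)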