
Commit f58043d

Merge pull request #20 from neu-spiral/feature/large_runs
Feature/large runs
2 parents 9174bd7 + 59693f9

5 files changed: +694 -518 lines


create_mean_images.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
#!/bin/sh

for file in $1/*.nii  ## path to relevant dataset group
do
    fslmaths "$file" -Tmean -bin "${file}_mean"
done

fslmerge -t $1/allmeanmasks4d $1/*.nii.gz
fslmaths $1/allmeanmasks4d -Tmean $1/propDatavox3d
fslmaths $1/propDatavox3d -thr 1 $1/wholebrain
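Usage (hypothetical path): sh create_mean_images.sh /path/to/group_dir, where the argument is the directory holding the group's .nii files; the script writes the per-file mean masks plus allmeanmasks4d, propDatavox3d, and wholebrain back into that directory.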

creating_mask.txt

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
What you want to do is first get a mean (across time) image for each 4D file and then binarize it*.

To do this, run fslmaths on each 4D file:

fslmaths 4D_inputVolume1 -Tmean -bin 3d_meanmask1
fslmaths 4D_inputVolume2 -Tmean -bin 3d_meanmask2
...
fslmaths 4D_inputVolumeN -Tmean -bin 3d_meanmaskN

Then we want the proportion of subjects who have data at each voxel. We get this by merging all the 3D masks into a single 4D file and taking the mean across the 4th dimension:

fslmerge -t allmeanmasks4d 3d_meanmask1 3d_meanmask2 ... 3d_meanmaskN

fslmaths allmeanmasks4d -Tmean propDatavox3d

Looking at this file gives a sense of how across-subject alignment went and where data drop-out is consistent or spotty.

Lastly, make this a binary mask that is 1 where ALL subjects have data and 0 elsewhere (saved as wholebrain.nii.gz):

fslmaths propDatavox3d -thr 1 wholebrain


*Note that if the data are already z-scored (they are not for greeneyes), this won't work: the mean will be ~0 at each voxel, so the binarize operation (turn non-zeros into 1) will behave badly. In that case you would probably have to binarize each volume first, take the mean, and then binarize again.
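As a hedged illustration only (not part of this commit), one way to carry out that z-scored alternative in Python with nibabel and numpy, using placeholder filenames:

import nibabel as nib
import numpy as np

# One reading of the z-scored alternative: binarize each timepoint,
# average over time, then binarize again so the mask is 1 wherever
# this subject ever has data.
img = nib.load('4D_inputVolume1.nii')                 # placeholder filename
data = img.get_fdata()                                # shape (x, y, z, t)
nonzero = (data != 0).astype(np.float32)              # turn non-zeros into 1
mean_mask = (nonzero.mean(axis=-1) > 0).astype(np.float32)
nib.save(nib.Nifti1Image(mean_mask, img.affine), '3d_meanmask1.nii.gz')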

htfa_torch/htfa.py

Lines changed: 52 additions & 28 deletions
@@ -29,10 +29,16 @@
 
 class HierarchicalTopographicFactorAnalysis:
     """Overall container for a run of TFA"""
-    def __init__(self, data_files, num_factors=tfa_models.NUM_FACTORS):
+    def __init__(self, data_files, num_factors=tfa_models.NUM_FACTORS,
+                 mask=None):
         self.num_factors = num_factors
         self.num_subjects = len(data_files)
-        datasets = [utils.load_dataset(data_file) for data_file in data_files]
+        if mask is None:
+            raise ValueError('please provide a mask')
+        else:
+            self.mask = mask
+        datasets = [utils.load_dataset(data_file, mask=mask)
+                    for data_file in data_files]
         self.voxel_activations = [dataset[0] for dataset in datasets]
         self._images = [dataset[1] for dataset in datasets]
         self.voxel_locations = [dataset[2] for dataset in datasets]
@@ -52,25 +58,30 @@ def __init__(self, data_files, num_factors=tfa_models.NUM_FACTORS):
 
     def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
               log_level=logging.WARNING, num_particles=tfa_models.NUM_PARTICLES,
-              use_cuda=True):
+              batch_size=64, use_cuda=True):
         """Optimize the variational guide to reflect the data for `num_steps`"""
         logging.basicConfig(format='%(asctime)s %(message)s',
                             datefmt='%m/%d/%Y %H:%M:%S',
                             level=log_level)
-
-        activations = [{'Y': Variable(self.voxel_activations[s])}
-                       for s in range(self.num_subjects)]
+        activations = torch.Tensor(max(self.num_times), max(self.num_voxels),
+                                   len(self.voxel_activations))
+        for s in range(self.num_subjects):
+            activations[:, :, s] = self.voxel_activations[s]
+        activations_loader = torch.utils.data.DataLoader(
+            torch.utils.data.TensorDataset(
+                activations,
+                torch.zeros(activations.shape[0])
+            ),
+            batch_size=batch_size
+        )
         if tfa.CUDA and use_cuda:
             enc = torch.nn.DataParallel(self.enc)
             dec = torch.nn.DataParallel(self.dec)
             enc.cuda()
-            dec.cuda()
-            for acts in activations:
-                acts['Y'] = acts['Y'].cuda()
+            dec.cuda(0)
         else:
             enc = self.enc
             dec = self.dec
-
         optimizer = torch.optim.Adam(list(self.enc.parameters()),
                                      lr=learning_rate)
         enc.train()
@@ -81,24 +92,37 @@ def train(self, num_steps=10, learning_rate=tfa.LEARNING_RATE,
 
         for epoch in range(num_steps):
             start = time.time()
-
-            optimizer.zero_grad()
-            q = probtorch.Trace()
-            enc(q, num_particles=num_particles)
-            p = probtorch.Trace()
-            dec(p, guide=q, observations=activations)
-
-            free_energies[epoch] = tfa.free_energy(q, p, num_particles=num_particles)
-            lls[epoch] = tfa.log_likelihood(q, p, num_particles=num_particles)
-
-            free_energies[epoch].backward()
-            optimizer.step()
-
-            if tfa.CUDA and use_cuda:
-                free_energies[epoch] = free_energies[epoch].cpu()
-                lls[epoch] = lls[epoch].cpu()
-            free_energies[epoch] = free_energies[epoch].data.numpy().sum(0)
-            lls[epoch] = lls[epoch].data.numpy().sum(0)
+            epoch_free_energies = list(range(len(activations_loader)))
+            epoch_lls = list(range(len(activations_loader)))
+
+            for (batch, (data, _)) in enumerate(activations_loader):
+                activations = [{'Y': Variable(data[:, :, s])}
+                               for s in range(self.num_subjects)]
+                trs = (batch * batch_size, None)
+                trs = (trs[0], trs[0] + activations[0]['Y'].shape[0])
+
+                optimizer.zero_grad()
+                q = probtorch.Trace()
+                enc(q, times=trs, num_particles=num_particles)
+                p = probtorch.Trace()
+                dec(p, times=trs, guide=q, observations=activations)
+
+                epoch_free_energies[batch] =\
+                    tfa.free_energy(q, p, num_particles=num_particles)
+                epoch_lls[batch] =\
+                    tfa.log_likelihood(q, p, num_particles=num_particles)
+                epoch_free_energies[batch].backward()
+                optimizer.step()
+                if tfa.CUDA and use_cuda:
+                    epoch_free_energies[batch] = epoch_free_energies[batch].cpu().data.numpy()
+                    epoch_lls[batch] = epoch_lls[batch].cpu().data.numpy()
+
+            free_energies[epoch] = np.array(epoch_free_energies).sum(0)
+            free_energies[epoch] = free_energies[epoch].sum(0)
+            lls[epoch] = np.array(epoch_lls).sum(0)
+            lls[epoch] = lls[epoch].sum(0)
 
             end = time.time()
             msg = tfa.EPOCH_MSG % (epoch + 1, (end - start) * 1000, free_energies[epoch])
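For orientation (not part of the diff), a hypothetical usage sketch of the updated interface, assuming the package imports as htfa_torch and using made-up paths and hyperparameters: the constructor now requires a mask, and train() accepts a batch_size controlling how many time points each mini-batch drawn from the internal DataLoader covers.

from htfa_torch.htfa import HierarchicalTopographicFactorAnalysis

# Hypothetical example; file names and settings are placeholders.
htfa = HierarchicalTopographicFactorAnalysis(
    ['sub01.nii', 'sub02.nii'],   # per-subject 4D NIfTI files
    num_factors=50,
    mask='wholebrain.nii.gz',     # e.g. the mask produced by create_mean_images.sh
)
htfa.train(num_steps=100, batch_size=64, use_cuda=True)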

htfa_torch/utils.py

Lines changed: 2 additions & 2 deletions
@@ -141,10 +141,10 @@ def cmu2nii(activations, locations, template):
 
     return nib.Nifti1Image(data, affine=sform)
 
-def load_dataset(data_file):
+def load_dataset(data_file, mask=None):
     name, ext = os.path.splitext(data_file)
     if ext == '.nii':
-        dataset, image = nii2cmu(data_file)
+        dataset, image = nii2cmu(data_file, mask_file=mask)
         template = data_file
     else:
         dataset = sio.loadmat(data_file)
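And a hedged sketch of calling the updated loader directly (paths are placeholders; the tuple indices mirror how htfa.py consumes the return value):

from htfa_torch import utils

# Hypothetical call against the new signature; `mask` is forwarded to nii2cmu as mask_file.
dataset = utils.load_dataset('sub01.nii', mask='wholebrain.nii.gz')
activations, image, locations = dataset[0], dataset[1], dataset[2]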

notebooks/example_htfa.ipynb

Lines changed: 606 additions & 488 deletions
Large diffs are not rendered by default.
