
Commit e518254

Merge pull request #21 from neu-spiral/develop: Develop
2 parents 4405498 + f58043d commit e518254

File tree

12 files changed: +2780 −544 lines

create_mean_images.sh

Lines changed: 10 additions & 0 deletions
#!/bin/sh

# Mean (across time) and binarize each subject's 4D image.
for file in $1/*.nii  ## path to relevant dataset group
do
    fslmaths "$file" -Tmean -bin "${file}_mean"
done

# Merge the 3D masks, compute the proportion of subjects with data per voxel,
# and keep only the voxels where all subjects have data.
fslmerge -t $1/allmeanmasks4d $1/*.nii.gz
fslmaths $1/allmeanmasks4d -Tmean $1/propDatavox3d
fslmaths $1/propDatavox3d -thr 1 $1/wholebrain

creating_mask.txt

Lines changed: 24 additions & 0 deletions
What you want to do is to first get a mean (across time) image for each 4D file and then binarize it*.

In order to do this, use fslmaths for each 4D file:

fslmaths 4D_inputVolume1 -Tmean -bin 3d_meanmask1
fslmaths 4D_inputVolume2 -Tmean -bin 3d_meanmask2
...
fslmaths 4D_inputVolumeN -Tmean -bin 3d_meanmaskN

Then, we'll want to get the proportion of subjects who have data for each voxel. We do this by creating a 4D file from all the 3D masks and then taking the mean across the 4th dimension:

fslmerge -t allmeanmasks4d 3d_meanmask1 3d_meanmask2 ... 3d_meanmaskN

fslmaths allmeanmasks4d -Tmean propDatavox3d

One can look at this file to get a sense of how across-subject alignment did, and where drop-out of data is consistent or spotty.
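For example (the choice of viewer is just an assumption here; any NIfTI viewer works):

fslview propDatavox3d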
Lastly, make this a binary mask which is 1 where ALL subjects have data and 0 elsewhere (save as wholebrain.nii.gz):

fslmaths propDatavox3d -thr 1 wholebrain

*Note that if the data is already z-scored (it isn't for greeneyes), this won't work: the mean will be ~0 for each voxel, so the binarize operation (which turns non-zeros into 1) will behave badly. In that case, you would probably have to binarize, take the mean, then binarize again.
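A sketch of that workaround (untested here; the -abs step is my assumption for handling negative z-values, since -bin only keeps values above zero). fslmaths applies its operations left to right, so per input something like

fslmaths 4D_inputVolume1 -abs -bin -Tmean -bin 3d_meanmask1

would binarize each timepoint, take the temporal mean, then binarize again.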

environment.yml

Lines changed: 19 additions & 6 deletions
@@ -3,10 +3,16 @@ channels:
 - pytorch
 - defaults
 dependencies:
+- _nb_ext_conf=0.4.0=py36_1
+- anaconda-client=1.6.12=py36_0
+- asn1crypto=0.24.0=py36_0
 - bleach=2.1.2=py36_0
 - ca-certificates=2017.08.26=h1d4fec5_0
 - certifi=2018.1.18=py36_0
 - cffi=1.11.4=py36h9745a5d_0
+- chardet=3.0.4=py36h0f667ec_1
+- clyent=1.2.2=py36h7e57e65_1
+- cryptography=2.1.4=py36hd09be54_0
 - cycler=0.10.0=py36h93f1223_0
 - dbus=1.12.2=hc3f9b76_1
 - decorator=4.2.1=py36_0
@@ -20,6 +26,7 @@ dependencies:
 - gstreamer=1.12.4=hb53b477_0
 - html5lib=1.0.1=py36h2f9c1c0_0
 - icu=58.2=h9c2bf20_1
+- idna=2.6=py36h82fb2a8_1
 - intel-openmp=2018.0.0=hc7b2577_8
 - ipykernel=4.8.2=py36_0
 - ipython=6.2.1=py36h88c514a_1
@@ -47,8 +54,12 @@ dependencies:
 - matplotlib=2.1.2=py36h0e671d2_0
 - mistune=0.8.3=py36_0
 - mkl=2018.0.1=h19d6760_4
+- nb_anacondacloud=1.4.0=py36_0
+- nb_conda=2.2.1=py36h8118bb2_0
+- nb_conda_kernels=2.1.0=py36_0
 - nbconvert=5.3.1=py36hb41ffb7_0
 - nbformat=4.4.0=py36h31c9010_0
+- nbpresent=3.0.2=py36h5f95a39_1
 - ncurses=6.0=h9df7e31_2
 - notebook=5.4.0=py36_0
 - numpy=1.14.0=py36h3dfced4_1
@@ -66,15 +77,19 @@ dependencies:
 - ptyprocess=0.5.2=py36h69acd42_0
 - pycparser=2.18=py36hf9f622e_1
 - pygments=2.2.0=py36h0d3125c_0
+- pyopenssl=17.5.0=py36h20ba746_0
 - pyparsing=2.2.0=py36hee85983_1
 - pyqt=5.6.0=py36h0386399_5
+- pysocks=1.6.8=py36_0
 - python=3.6.4=hc3d631a_1
 - python-dateutil=2.6.1=py36h88d3b88_1
 - pytz=2017.3=py36h63b9c63_0
+- pyyaml=3.12=py36hafb9ca4_1
 - pyzmq=16.0.3=py36he2533c7_0
 - qt=5.6.2=h974d657_12
 - qtconsole=4.3.1=py36h8f73b5b_0
 - readline=7.0=ha6073c6_4
+- requests=2.18.4=py36he2e5f8d_1
 - scikit-learn=0.19.1=py36h7aa7ec6_0
 - scipy=1.0.0=py36hbf646e7_0
 - send2trash=1.4.2=py36_0
@@ -88,34 +103,32 @@ dependencies:
 - tk=8.6.7=hc745277_3
 - tornado=4.5.3=py36_0
 - traitlets=4.3.2=py36h674d592_0
+- urllib3=1.22=py36hbe7ace6_0
 - wcwidth=0.1.7=py36hdf4376a_0
 - webencodings=0.5.1=py36h800622e_1
 - wheel=0.30.0=py36hfd4bba0_1
 - widgetsnbextension=3.1.0=py36_0
 - xz=5.2.3=h55aa19d_2
+- yaml=0.1.7=had09818_2
 - zeromq=4.2.2=hbedb6e5_2
 - zlib=1.2.11=ha838bed_2
 - cuda90=1.0=h6433d27_0
 - pytorch=0.3.1=py36_cuda9.0.176_cudnn7.0.5_2
 - torchvision=0.2.0=py36h17b6947_1
 - pip:
-  - chardet==3.0.4
   - deepdish==0.3.6
   - flatdict==2.0.1
   - future==0.16.0
+  - htfatorch==0.0.0
   - hypertools==0.4.2
-  - idna==2.6
   - nibabel==2.2.1
   - nilearn==0.4.0
   - numexpr==2.6.4
   - pandas==0.22.0
   - ppca==0.0.3
   - probtorch==0.0+5a2c637
-  - pyyaml==3.12
-  - requests==2.18.4
   - seaborn==0.8.1
   - tables==3.4.2
   - torch==0.3.1.post2
-  - urllib3==1.22
-prefix: /home/eli/anaconda3/envs/HTFATorch
+prefix: /home/work/anaconda3/envs/HTFATorch
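To recreate the environment from this file (the env's name: field is not shown in this hunk; HTFATorch is inferred from the prefix line and may need adjusting):

conda env create -f environment.yml
source activate HTFATorch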

htfa_torch/dtfa.py

Lines changed: 65 additions & 0 deletions
"""Sketch of Deep TFA architecture"""

__author__ = ('Jan-Willem van de Meent',
              'Eli Sennesh',
              'Zulqarnain Khan')
__email__ = ('j.vandemeent@northeastern.edu',
             'e.sennesh@northeastern.edu',
             'khan.zu@husky.neu.edu')

import torch
from torch.nn import Parameter
import probtorch

# NOTE: I am writing this as a model relative to PyTorch master,
# which no longer requires explicit wrapping in Variable(...)

def rbf(locations, centers, widths):
    """Radial basis factors exp(-||x - c||^2 / width).
    locations: V x 3, centers: K x 3, widths: K; returns K x V."""
    delta2s = (locations.unsqueeze(0) - centers.unsqueeze(1))**2
    return torch.exp(-delta2s.sum(-1) / widths.unsqueeze(-1))

class DeepTFA(torch.nn.Module):
    def __init__(self, N=50, T=200, D=2, E=2, K=24):
        super(DeepTFA, self).__init__()
        # generative model
        self.p_z_w_mean = torch.zeros(E)
        self.p_z_w_std = torch.ones(E)
        self.w = torch.nn.Sequential(
            torch.nn.Linear(E, K // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(K // 2, K))
        self.p_z_f_mean = torch.zeros(D)
        self.p_z_f_std = torch.ones(D)
        self.h_f = torch.nn.Sequential(
            torch.nn.Linear(D, K // 2),
            torch.nn.ReLU())
        self.x_f = torch.nn.Linear(K // 2, 3 * K)
        self.log_rho_f = torch.nn.Linear(K // 2, K)
        self.sigma_y = Parameter(torch.ones(1))
        # variational parameters
        self.q_z_f_mean = Parameter(torch.zeros(N, D))
        self.q_z_f_std = Parameter(torch.ones(N, D))
        self.q_z_w_mean = Parameter(torch.zeros(N, T, E))
        self.q_z_w_std = Parameter(torch.ones(N, T, E))

    def forward(self, x, y, n, t):
        p = probtorch.Trace()
        q = probtorch.Trace()
        # weight embedding for subject n at time t
        z_w = q.normal(self.q_z_w_mean[n, t],
                       self.q_z_w_std[n, t],
                       name='z_w')
        z_w = p.normal(self.p_z_w_mean,
                       self.p_z_w_std,
                       value=q['z_w'],
                       name='z_w')
        w = self.w(z_w)
        # factor embedding for subject n
        z_f = q.normal(self.q_z_f_mean[n],
                       self.q_z_f_std[n],
                       name='z_f')
        z_f = p.normal(self.p_z_f_mean,
                       self.p_z_f_std,
                       value=q['z_f'],
                       name='z_f')
        # shared hidden layer, then factor centers and widths
        h_f = self.h_f(z_f)
        x_f = self.x_f(h_f).view(-1, 3)
        rho_f = torch.exp(self.log_rho_f(h_f))
        f = rbf(x, x_f, rho_f)
        # observed activations as a linear combination of factors
        y = p.normal(w @ f,
                     self.sigma_y,
                     value=y,
                     name='y')
        return p, q
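A quick, hypothetical smoke test of the sketch above (the voxel count and random inputs are illustrative assumptions, and it presumes a probtorch build compatible with the one pinned in environment.yml):

V = 1000                                    # number of voxels (illustrative)
model = DeepTFA(N=50, T=200, D=2, E=2, K=24)
x = torch.randn(V, 3)                       # voxel locations
y = torch.randn(V)                          # activations for subject n at time t
p, q = model(x, y, n=0, t=0)                # generative and guide traces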
Lines changed: 99 additions & 0 deletions
from __future__ import division
import numpy as np
import scipy.optimize as optimize

def initialize_centers_widths_weights(activations, locations, num_factors):
    # convert torch tensors to numpy arrays
    temp_data = []
    temp_R = []
    for i in range(len(activations)):
        temp_data.append(activations[i].numpy())
        temp_R.append(locations[i].numpy())
    activations = temp_data
    locations = temp_R
    del temp_data, temp_R
    mean_activations, mean_locations = mean_image(activations, locations)
    template_center_mean, template_width_mean = hotspot_initialization(mean_activations,
                                                                       mean_locations,
                                                                       num_factors)
    # least-squares weights for the template: solve (F'F) w = F'y
    F = radial_basis(mean_locations, template_center_mean, template_width_mean).T
    trans_F = F.T.copy()
    template_weights_mean = np.linalg.solve(trans_F.dot(F), trans_F.dot(mean_activations))
    subject_weights_mean = []
    subject_center_mean = []
    subject_width_mean = []
    for i in range(len(activations)):
        # NB: solving against mean_activations assumes locations[i] lines up
        # with mean_locations (see the TODO in mean_image below)
        F = radial_basis(locations[i], template_center_mean, template_width_mean).T
        trans_F = F.T.copy()
        subject_weights_mean.append(np.linalg.solve(trans_F.dot(F),
                                                    trans_F.dot(mean_activations)))
        subject_center_mean.append(template_center_mean)
        subject_width_mean.append(template_width_mean)

    return template_center_mean, template_width_mean, template_weights_mean,\
           np.array(subject_center_mean), np.array(subject_width_mean),\
           np.array(subject_weights_mean)


def hotspot_initialization(activations, locations, num_factors):
    # greedily place factors at the largest remaining deviations from the mean
    mean_activations = abs(activations - np.nanmean(activations))
    centers = np.zeros(shape=(num_factors, locations.shape[1]))
    widths = np.zeros(shape=(num_factors,))

    for k in range(num_factors):
        ind = np.nanargmax(mean_activations)
        centers[k, :] = locations[ind, :]
        widths[k] = init_width(activations, locations, activations[ind], centers[k, :])
        mean_activations = mean_activations - radial_basis(locations, centers[k, :], widths[k])
    return centers, widths


def mean_image(activations, locations):
    mean_locations = locations[0]

    for i in range(1, len(locations)):
        mean_locations = np.vstack([mean_locations, locations[i]])  ## fix this to accommodate differing lengths

    mean_locations = np.unique(mean_locations, axis=0)
    mean_activations = np.zeros(shape=(mean_locations.shape[0],))
    n = np.zeros(shape=mean_activations.shape)

    for i in range(len(activations)):
        C = intersect(mean_locations, locations[i])
        mean_locations_ind = get_common_indices(mean_locations, C)
        subject_locations_ind = get_common_indices(locations[i], C)
        mean_locations_ind = np.sort(mean_locations_ind)
        subject_locations_ind = np.sort(subject_locations_ind)
        mean_activations[mean_locations_ind] = mean_activations[mean_locations_ind] + \
            np.mean(activations[i][:, subject_locations_ind], axis=0)
        n[mean_locations_ind] = n[mean_locations_ind] + 1
    mean_activations = mean_activations / n

    return mean_activations, mean_locations


def init_width(activations, locations, weight, c):
    # fit the (log) width that best explains the hotspot's activation profile
    start_width = 0
    objective = lambda w: np.sum(np.abs(activations - weight * radial_basis(locations, c, w)))
    result = optimize.minimize(objective, x0=start_width)

    return result.x


def radial_basis(locations, centers, log_widths):
    """The radial basis function used as the shape for the factors"""
    # V x 3 -> 1 x V x 3
    locations = np.expand_dims(locations, 0)
    if len(centers.shape) > 3:
        # 1 x V x 3 -> 1 x 1 x V x 3
        locations = np.expand_dims(locations, 0)
    # S x K x 3 -> S x K x 1 x 3
    centers = np.expand_dims(centers, len(centers.shape) - 1)
    # S x K x V x 3
    delta2s = (locations - centers)**2
    # S x K -> S x K x 1
    log_widths = np.expand_dims(log_widths, len(log_widths.shape))
    return np.exp(-delta2s.sum(len(delta2s.shape) - 1) / np.exp(log_widths))


def intersect(A, B):
    # rows present in both A and B
    return np.array([x for x in set(tuple(x) for x in A) & set(tuple(x) for x in B)])


def get_common_indices(X, C):
    # indices of the rows of X that appear in C
    return np.where((X == C[:, None]).all(-1))[1]
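A hypothetical shape check for radial_basis, the workhorse of this initialization (the names and sizes are illustrative, not from the repository):

import numpy as np
locations = np.random.randn(100, 3)   # V x 3 voxel locations
centers = np.random.randn(5, 3)       # K x 3 factor centers
log_widths = np.zeros(5)              # K log widths
F = radial_basis(locations, centers, log_widths)
assert F.shape == (5, 100)            # K x V factor matrix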
