Skip to content

Commit 0679094

Browse files
authored
Merge pull request #31 from neu-spiral/develop
Develop
2 parents d6cf9d6 + 624dfa5 commit 0679094

20 files changed

+39006
-3394
lines changed

htfa_torch/dtfa.py

Lines changed: 222 additions & 94 deletions
Large diffs are not rendered by default.

htfa_torch/dtfa_models.py

Lines changed: 194 additions & 163 deletions
Large diffs are not rendered by default.

htfa_torch/htfa.py

Lines changed: 269 additions & 75 deletions
Large diffs are not rendered by default.

htfa_torch/htfa_models.py

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import collections
77
import numpy as np
8+
import scipy.spatial
89
import torch
910
from torch.autograd import Variable
1011
import probtorch
@@ -53,15 +54,12 @@ def __init__(self, hyper_means, num_times, num_blocks,
5354
'sigma': torch.sqrt(torch.rand(self._num_blocks, self._num_factors)),
5455
},
5556
'weights': {
56-
'mu': torch.randn(self._num_blocks, self._num_times,
57-
self._num_factors),
57+
'mu': hyper_means['weights'].mean(0).unsqueeze(0).expand(
58+
self._num_blocks, self._num_times, self._num_factors
59+
),
5860
'sigma': torch.ones(self._num_blocks, self._num_times,
5961
self._num_factors),
6062
},
61-
'voxel_noise': {
62-
'mu': torch.ones(self._num_blocks),
63-
'sigma': torch.sqrt(torch.rand(self._num_blocks))
64-
}
6563
})
6664

6765
super(self.__class__, self).__init__(params, guide=True)
@@ -95,15 +93,6 @@ def forward(self, trace, params, times=None, blocks=None,
9593
# We only expand the parameters for which we're actually going to sample
9694
# values in this very method, and thus want to expand to get multiple
9795
# particles.
98-
voxel_noise_params = params['block']['voxel_noise']
99-
if num_particles and num_particles > 0:
100-
voxel_noise_params = utils.unsqueeze_and_expand_vardict(
101-
params['block']['voxel_noise'], 0, num_particles, True
102-
)
103-
voxel_noise = trace.normal(voxel_noise_params['mu'],
104-
voxel_noise_params['sigma'],
105-
name='voxel_noise')
106-
10796
if blocks is None:
10897
blocks = list(range(self._num_blocks))
10998

@@ -125,16 +114,17 @@ def forward(self, trace, params, times=None, blocks=None,
125114
factor_centers += [fc]
126115
factor_log_widths += [flw]
127116

128-
return weights, factor_centers, factor_log_widths, voxel_noise
117+
return weights, factor_centers, factor_log_widths
129118

130119
class HTFAGuide(nn.Module):
131120
"""Variational guide for hierarchical topographic factor analysis"""
132121
def __init__(self, query, num_factors=tfa_models.NUM_FACTORS):
133122
super(self.__class__, self).__init__()
134123
self._num_blocks = len(query)
135-
self._num_times = niidb.query_min_time(query)
124+
self._num_times = niidb.query_max_time(query)
136125

137-
b = np.random.choice(self._num_blocks, 1)[0]
126+
b = max(range(self._num_blocks), key=lambda b: query[b].end_time -
127+
query[b].start_time)
138128
query[b].load()
139129
centers, widths, weights = utils.initial_hypermeans(
140130
query[b].activations.numpy().T, query[b].locations.numpy(),
@@ -161,7 +151,7 @@ def forward(self, trace, times=None, blocks=None,
161151

162152
class HTFAGenerativeHyperParams(tfa_models.HyperParams):
163153
def __init__(self, brain_center, brain_center_std_dev, num_blocks,
164-
num_factors=tfa_models.NUM_FACTORS):
154+
num_factors=tfa_models.NUM_FACTORS, volume=None):
165155
self._num_factors = num_factors
166156
self._num_blocks = num_blocks
167157

@@ -177,20 +167,22 @@ def __init__(self, brain_center, brain_center_std_dev, num_blocks,
177167
params['template']['factor_centers']['sigma'] =\
178168
brain_center_std_dev.expand(self._num_factors, 3)
179169

170+
coefficient = 1.0
171+
if volume is not None:
172+
coefficient = np.log(np.cbrt(volume / self._num_factors))
180173
params['template']['factor_log_widths']['mu'] =\
181-
torch.ones(self._num_factors)
174+
coefficient * torch.ones(self._num_factors)
182175
params['template']['factor_log_widths']['sigma'] =\
183176
tfa_models.SOURCE_LOG_WIDTH_STD_DEV * torch.ones(self._num_factors)
184177

185178
params['block'] = {
186179
'factor_center_noise': torch.ones(self._num_blocks),
187180
'factor_log_width_noise': torch.ones(self._num_blocks),
188181
'weights': {
189-
'mu': torch.rand(self._num_blocks, self._num_factors),
182+
'mu': torch.zeros(self._num_blocks, self._num_factors),
190183
'sigma': tfa_models.SOURCE_WEIGHT_STD_DEV *\
191184
torch.ones(self._num_blocks, self._num_factors)
192185
},
193-
'voxel_noise': utils.gaussian_populator(self._num_blocks)
194186
}
195187
super(self.__class__, self).__init__(params, guide=False)
196188

@@ -216,19 +208,16 @@ def __init__(self, num_blocks, num_times):
216208
for b in range(self._num_blocks)]
217209

218210
def forward(self, trace, params, template, times=None, blocks=None,
219-
guide=probtorch.Trace()):
220-
voxel_noise = trace.normal(params['block']['voxel_noise']['mu'],
221-
params['block']['voxel_noise']['sigma'],
222-
value=guide['voxel_noise'],
223-
name='voxel_noise')
224-
211+
guide=probtorch.Trace(), weights_params=None):
225212
if blocks is None:
226213
blocks = list(range(self._num_blocks))
214+
if times is None:
215+
times = (0, self._num_times)
227216

228217
weights = []
229218
factor_centers = []
230219
factor_log_widths = []
231-
for b in blocks:
220+
for (i, b) in enumerate(blocks):
232221
sparams = utils.vardict({
233222
'factor_centers': {
234223
'mu': template['factor_centers'],
@@ -243,56 +232,58 @@ def forward(self, trace, params, template, times=None, blocks=None,
243232
'sigma': params['block']['weights']['sigma'][b],
244233
}
245234
})
235+
if weights_params is not None:
236+
sparams['weights'] = weights_params[i]
246237
w, fc, flw = self._tfa_priors[b](trace, sparams, times=times,
247238
guide=guide)
248239
weights += [w]
249240
factor_centers += [fc]
250241
factor_log_widths += [flw]
251242

252-
return weights, factor_centers, factor_log_widths, voxel_noise
243+
return weights, factor_centers, factor_log_widths
253244

254245
class HTFAModel(nn.Module):
255246
"""Generative model for hierarchical topographic factor analysis"""
256-
def __init__(self, query, num_blocks, num_times,
257-
num_factors=tfa_models.NUM_FACTORS):
247+
def __init__(self, locations, num_blocks, num_times,
248+
num_factors=tfa_models.NUM_FACTORS, volume=None):
258249
super(self.__class__, self).__init__()
259250

260251
self._num_factors = num_factors
261252
self._num_blocks = num_blocks
262253
self._num_times = num_times
263254

264-
b = np.random.choice(self._num_blocks, 1)[0]
265-
query[b].load()
266-
center, center_sigma = utils.brain_centroid(query[b].locations)
255+
center, center_sigma = utils.brain_centroid(locations)
256+
hull = scipy.spatial.ConvexHull(locations)
257+
if volume is not None:
258+
volume = hull.volume
267259

268260
self._hyperparams = HTFAGenerativeHyperParams(center, center_sigma,
269261
self._num_blocks,
270-
self._num_factors)
262+
self._num_factors,
263+
volume=volume)
271264
self._template_prior = HTFAGenerativeTemplatePrior()
272265
self._subject_prior = HTFAGenerativeSubjectPrior(
273266
self._num_blocks, self._num_times
274267
)
275-
for block in query:
276-
block.load()
277268
self.likelihoods = [tfa_models.TFAGenerativeLikelihood(
278-
query[b].locations, self._num_times[b], tfa_models.VOXEL_NOISE,
269+
locations, self._num_times[b], tfa_models.VOXEL_NOISE,
279270
block=b, register_locations=False
280271
) for b in range(self._num_blocks)]
281272
for b, block_likelihood in enumerate(self.likelihoods):
282273
self.add_module('likelihood' + str(b), block_likelihood)
283274

284275
def forward(self, trace, times=None, guide=probtorch.Trace(), blocks=None,
285-
observations=[]):
276+
observations=[], weights_params=None):
286277
if blocks is None:
287278
blocks = list(range(self._num_blocks))
288279
params = self._hyperparams.state_vardict()
289280

290281
template = self._template_prior(trace, params, guide=guide)
291-
weights, centers, log_widths, voxel_noise = self._subject_prior(
292-
trace, params, template, times=times, blocks=blocks, guide=guide
282+
weights, centers, log_widths = self._subject_prior(
283+
trace, params, template, times=times, blocks=blocks, guide=guide,
284+
weights_params=weights_params
293285
)
294286

295287
return [self.likelihoods[b](trace, weights[i], centers[i], log_widths[i],
296-
times=times, observations=observations[i],
297-
voxel_noise=voxel_noise)
288+
times=times, observations=observations[i])
298289
for (i, b) in enumerate(blocks)]

htfa_torch/niidb.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,18 @@ def wrapped_table_method(*args, **kwargs):
114114
return wrapped_table_method
115115
return attr
116116

117+
def query_max_time(qiter):
118+
result = -1
119+
for block in qiter:
120+
load = block.activations is None
121+
if load:
122+
block.load()
123+
if result < 0 or block.end_time > result:
124+
result = block.end_time
125+
if load:
126+
block.unload()
127+
return result
128+
117129
def query_min_time(qiter):
118130
result = -1
119131
for block in qiter:

htfa_torch/tfa.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,54 @@ def free_energy(q, p, num_particles=tfa_models.NUM_PARTICLES):
4444
sample_dim = None
4545
return -probtorch.objectives.montecarlo.elbo(q, p, sample_dim=sample_dim)
4646

47-
def hierarchical_elbo(q, p, rv_weight=lambda x: 1.0,
47+
def hierarchical_elbo(q, p, rv_weight=lambda x, prior=True: 1.0,
4848
num_particles=tfa_models.NUM_PARTICLES,
4949
sample_dim=None, batch_dim=None):
5050
if num_particles and num_particles > 0:
5151
sample_dim = 0
5252
else:
5353
sample_dim = None
5454

55-
weight_rvs = utils.inverse(rv_weight, p)
56-
weighted_elbo = 0.0
57-
for weight, rvs in weight_rvs.items():
55+
weighted_log_likelihood = 0.0
56+
weighted_prior_kl = 0.0
57+
for rv in p:
5858
local_elbo = p.log_joint(sample_dim=sample_dim, batch_dim=batch_dim,
59-
nodes=rvs) -\
59+
nodes=[rv]) -\
6060
q.log_joint(sample_dim=sample_dim, batch_dim=batch_dim,
61-
nodes=rvs)
62-
weighted_elbo += weight * local_elbo.sum()
63-
return weighted_elbo
61+
nodes=[rv])
62+
if sample_dim is not None:
63+
local_elbo = local_elbo.mean(dim=sample_dim)
64+
if p[rv].observed and rv not in q:
65+
weighted_log_likelihood += rv_weight(rv, False) * local_elbo
66+
else:
67+
weighted_prior_kl -= rv_weight(rv, True) * local_elbo
68+
weighted_elbo = weighted_log_likelihood - weighted_prior_kl
69+
return weighted_elbo, weighted_log_likelihood, weighted_prior_kl
70+
71+
def componentized_elbo(q, p, rv_weight=lambda x: 1.0,
72+
num_particles=tfa_models.NUM_PARTICLES, sample_dim=None,
73+
batch_dim=None):
74+
if num_particles and num_particles > 0:
75+
sample_dim = 0
76+
else:
77+
sample_dim = None
78+
79+
trace_elbos = {}
80+
for rv in p:
81+
local_elbo = p.log_joint(sample_dim=sample_dim, batch_dim=batch_dim,
82+
nodes=[rv]) -\
83+
q.log_joint(sample_dim=sample_dim, batch_dim=batch_dim,
84+
nodes=[rv])
85+
if sample_dim is not None:
86+
local_elbo = local_elbo.mean(dim=sample_dim)
87+
trace_elbos[rv] = rv_weight(rv) * local_elbo
88+
89+
weighted_elbo = sum([trace_elbos[rv] for rv in trace_elbos])
90+
return weighted_elbo, trace_elbos
6491

6592
def hierarchical_free_energy(*args, **kwargs):
66-
return -hierarchical_elbo(*args, **kwargs)
93+
elbo, ll, kl = hierarchical_elbo(*args, **kwargs)
94+
return -elbo, ll, kl
6795

6896
def log_likelihood(q, p, num_particles=tfa_models.NUM_PARTICLES):
6997
"""The expected log-likelihood of observed data under the proposal distribution"""
@@ -75,11 +103,12 @@ def log_likelihood(q, p, num_particles=tfa_models.NUM_PARTICLES):
75103

76104
class TopographicalFactorAnalysis:
77105
"""Overall container for a run of TFA"""
78-
def __init__(self, data_file, num_factors=tfa_models.NUM_FACTORS):
106+
def __init__(self, data_file, num_factors=tfa_models.NUM_FACTORS,
107+
zscore=True):
79108
self.num_factors = num_factors
80109

81110
self.voxel_activations, self.voxel_locations, self._name,\
82-
self._template = utils.load_dataset(data_file)
111+
self._template = utils.load_dataset(data_file, zscore=zscore)
83112

84113
# Pull out relevant dimensions: the number of times-of-recording, and
85114
# the number of voxels in each timewise "slice"

htfa_torch/tfa_models.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def forward(self, trace, params, times=None, num_particles=NUM_PARTICLES):
131131

132132
weights = trace.normal(weight_params['mu'],
133133
weight_params['sigma'],
134-
name='Weights' + str(self.block))
134+
name='Weights%dt%d-%d' % (self.block, times[0], times[1]))
135135

136136
centers = trace.normal(params['factor_centers']['mu'],
137137
params['factor_centers']['sigma'],
@@ -196,8 +196,8 @@ def forward(self, trace, params, times=None, guide=probtorch.Trace()):
196196

197197
weights = trace.normal(weight_params['mu'],
198198
weight_params['sigma'],
199-
value=guide['Weights' + str(self.block)],
200-
name='Weights' + str(self.block))
199+
value=guide['Weights%dt%d-%d' % (self.block, times[0], times[1])],
200+
name='Weights%dt%d-%d' % (self.block, times[0], times[1]))
201201

202202
factor_centers = trace.normal(params['factor_centers']['mu'],
203203
params['factor_centers']['sigma'],
@@ -224,18 +224,17 @@ def __init__(self, locations, num_times, voxel_noise=VOXEL_NOISE, block=0,
224224
self.block = block
225225

226226
def forward(self, trace, weights, centers, log_widths, times=None,
227-
observations=collections.defaultdict(), voxel_noise=None):
227+
observations=collections.defaultdict()):
228228
if times is None:
229229
times = (0, self._num_times)
230-
if voxel_noise is None:
231-
voxel_noise = self._voxel_noise
232230

233-
factors = radial_basis(Variable(self.voxel_locations), centers,
234-
log_widths)
231+
factors = radial_basis(Variable(self.voxel_locations,
232+
requires_grad=True),
233+
centers, log_widths)
235234

236235
activations = trace.normal(weights @ factors,
237236
self._voxel_noise, value=observations['Y'],
238-
name='Y' + str(self.block))
237+
name='Y%dt%d-%d' % (self.block, times[0], times[1]))
239238
return activations
240239

241240
class TFAModel(nn.Module):

0 commit comments

Comments
 (0)