@@ -32,14 +32,11 @@ def __init__(self, num_subjects, num_tasks, embedding_dim=2):
3232 params = utils .vardict ({
3333 'subject' : {
3434 'mu' : torch .zeros (self .num_subjects , self .embedding_dim ),
35- 'sigma' : torch .ones (self .num_subjects , self .embedding_dim ) * \
36- np .sqrt (tfa_models .SOURCE_WEIGHT_STD_DEV ** 2 + \
37- tfa_models .SOURCE_LOG_WIDTH_STD_DEV ** 2 ),
35+ 'sigma' : torch .ones (self .num_subjects , self .embedding_dim ),
3836 },
3937 'task' : {
4038 'mu' : torch .zeros (self .num_tasks , self .embedding_dim ),
41- 'sigma' : torch .ones (self .num_tasks , self .embedding_dim ) * \
42- tfa_models .SOURCE_WEIGHT_STD_DEV ,
39+ 'sigma' : torch .ones (self .num_tasks , self .embedding_dim ),
4340 },
4441 'voxel_noise' : torch .ones (1 ) * tfa_models .VOXEL_NOISE ,
4542 })
@@ -48,7 +45,7 @@ def __init__(self, num_subjects, num_tasks, embedding_dim=2):
4845
4946class DeepTFAGuideHyperparams (tfa_models .HyperParams ):
5047 def __init__ (self , num_blocks , num_times , num_factors , num_subjects ,
51- num_tasks , hyper_means , embedding_dim = 2 ):
48+ num_tasks , hyper_means , embedding_dim = 2 , time_series = True ):
5249 self .num_blocks = num_blocks
5350 self .num_subjects = num_subjects
5451 self .num_tasks = num_tasks
@@ -78,23 +75,26 @@ def __init__(self, num_blocks, num_times, num_factors, num_subjects,
7875 'sigma' : torch .ones (self .num_subjects , self ._num_factors ) * \
7976 hyper_means ['factor_log_widths' ].std (),
8077 },
81- 'weights' : {
78+ })
79+ if time_series :
80+ params ['weights' ] = {
8281 'mu' : torch .zeros (self .num_blocks , self .num_times ,
8382 self ._num_factors ),
8483 'sigma' : torch .ones (self .num_blocks , self .num_times ,
8584 self ._num_factors ),
86- },
87- })
85+ }
8886
8987 super (self .__class__ , self ).__init__ (params , guide = True )
9088
9189class DeepTFADecoder (nn .Module ):
9290 """Neural network module mapping from embeddings to a topographic factor
9391 analysis"""
94- def __init__ (self , num_factors , hyper_means , embedding_dim = 2 ):
92+ def __init__ (self , num_factors , hyper_means , embedding_dim = 2 ,
93+ time_series = True ):
9594 super (DeepTFADecoder , self ).__init__ ()
9695 self ._embedding_dim = embedding_dim
9796 self ._num_factors = num_factors
97+ self ._time_series = time_series
9898
9999 self .factors_embedding = nn .Sequential (
100100 nn .Linear (self ._embedding_dim , self ._embedding_dim * 2 ),
@@ -200,7 +200,8 @@ def predict(self, trace, params, guide, subject, task, times=(0, 1),
200200 weight_predictions = self ._predict_param (
201201 params , 'weights' , block , weight_predictions ,
202202 'Weights%d_%d-%d' % (block , times [0 ], times [1 ]), trace ,
203- predict = generative or block < 0 , guide = guide ,
203+ predict = generative or block < 0 or not self ._time_series ,
204+ guide = guide ,
204205 )
205206
206207 return centers_predictions , log_widths_predictions , weight_predictions
@@ -237,12 +238,14 @@ def forward(self, trace, blocks, block_subjects, block_tasks, params, times,
237238class DeepTFAGuide (nn .Module ):
238239 """Variational guide for deep topographic factor analysis"""
239240 def __init__ (self , num_factors , block_subjects , block_tasks , num_blocks = 1 ,
240- num_times = [1 ], embedding_dim = 2 , hyper_means = None ):
241+ num_times = [1 ], embedding_dim = 2 , hyper_means = None ,
242+ time_series = True ):
241243 super (self .__class__ , self ).__init__ ()
242244 self ._num_blocks = num_blocks
243245 self ._num_times = num_times
244246 self ._num_factors = num_factors
245247 self ._embedding_dim = embedding_dim
248+ self ._time_series = time_series
246249
247250 self .block_subjects = block_subjects
248251 self .block_tasks = block_tasks
@@ -254,7 +257,7 @@ def __init__(self, num_factors, block_subjects, block_tasks, num_blocks=1,
254257 self ._num_factors ,
255258 num_subjects , num_tasks ,
256259 hyper_means ,
257- embedding_dim )
260+ embedding_dim , time_series )
258261
259262 def forward (self , decoder , trace , times = None , blocks = None ,
260263 num_particles = tfa_models .NUM_PARTICLES ):
@@ -269,7 +272,7 @@ def forward(self, decoder, trace, times=None, blocks=None,
269272 if b in blocks ]
270273 block_tasks = [self .block_tasks [b ] for b in range (self ._num_blocks )
271274 if b in blocks ]
272- if times :
275+ if times and self . _time_series :
273276 for k , v in params ['weights' ].items ():
274277 params ['weights' ][k ] = v [:, :, times [0 ]:times [1 ], :]
275278
0 commit comments