@@ -282,12 +282,12 @@ def block_rv_weight(node, prior=True):
282282
283283 def free_energy (self , batch_size = 64 , use_cuda = True , blocks_batch_size = 4 ,
284284 blocks_filter = lambda block : True , num_particles = 1 ,
285- sample_size = 10 ):
286- training_blocks = [(b , block ) for (b , block ) in enumerate (self ._blocks )
287- if blocks_filter (block )]
285+ sample_size = 10 , predictive = False ):
286+ testing_blocks = [(b , block ) for (b , block ) in enumerate (self ._blocks )
287+ if blocks_filter (block )]
288288 activations_loader = torch .utils .data .DataLoader (
289289 utils .TFADataset ([block .activations .detach ()
290- for (_ , block ) in training_blocks ]),
290+ for (_ , block ) in testing_blocks ]),
291291 batch_size = batch_size ,
292292 pin_memory = True ,
293293 )
@@ -310,11 +310,11 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
310310
311311 for k in range (sample_size // num_particles ):
312312 for (batch , data ) in enumerate (activations_loader ):
313- block_batches = utils .chunks (list (range (len (training_blocks ))),
313+ block_batches = utils .chunks (list (range (len (testing_blocks ))),
314314 n = blocks_batch_size )
315315 for block_batch in block_batches :
316316 activations = [{'Y' : data [:, b , :]} for b in block_batch ]
317- block_batch = [training_blocks [b ][0 ] for b in block_batch ]
317+ block_batch = [testing_blocks [b ][0 ] for b in block_batch ]
318318 if tfa .CUDA and use_cuda :
319319 for b in block_batch :
320320 generative .likelihoods [b ].voxel_locations = \
0 commit comments