Skip to content

Commit aa30232

Browse files
committed
dtfa.DeepTFA: switch from training_blocks to testing_blocks in free_energy
Signed-off-by: Eli Sennesh <sennesh.e@husky.neu.edu>
1 parent c43449c commit aa30232

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

htfa_torch/dtfa.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,12 @@ def block_rv_weight(node, prior=True):
282282

283283
def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
284284
blocks_filter=lambda block: True, num_particles=1,
285-
sample_size=10):
286-
training_blocks = [(b, block) for (b, block) in enumerate(self._blocks)
287-
if blocks_filter(block)]
285+
sample_size=10, predictive=False):
286+
testing_blocks = [(b, block) for (b, block) in enumerate(self._blocks)
287+
if blocks_filter(block)]
288288
activations_loader = torch.utils.data.DataLoader(
289289
utils.TFADataset([block.activations.detach()
290-
for (_, block) in training_blocks]),
290+
for (_, block) in testing_blocks]),
291291
batch_size=batch_size,
292292
pin_memory=True,
293293
)
@@ -310,11 +310,11 @@ def free_energy(self, batch_size=64, use_cuda=True, blocks_batch_size=4,
310310

311311
for k in range(sample_size // num_particles):
312312
for (batch, data) in enumerate(activations_loader):
313-
block_batches = utils.chunks(list(range(len(training_blocks))),
313+
block_batches = utils.chunks(list(range(len(testing_blocks))),
314314
n=blocks_batch_size)
315315
for block_batch in block_batches:
316316
activations = [{'Y': data[:, b, :]} for b in block_batch]
317-
block_batch = [training_blocks[b][0] for b in block_batch]
317+
block_batch = [testing_blocks[b][0] for b in block_batch]
318318
if tfa.CUDA and use_cuda:
319319
for b in block_batch:
320320
generative.likelihoods[b].voxel_locations =\

0 commit comments

Comments
 (0)