Skip to content

Commit f69e914

Browse files
committed
i* needs gradients
1 parent 75105da commit f69e914

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def forward(
312312

313313
# exponential moving average for concept input encoding
314314

315-
concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
315+
ema_concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
316316

317317
# store
318318

@@ -321,7 +321,7 @@ def forward(
321321

322322
# update ema i_*
323323

324-
self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
324+
self.ema_concept_text_encs[concept_id].data.copy_(ema_concept_text_enc)
325325

326326
else:
327327
assert self.initted[concept_id_tensor].all(), 'you have not initialized or trained this module for the concepts id given'
@@ -330,6 +330,12 @@ def forward(
330330

331331
i, o, W = self.ema_concept_text_encs[concept_id_tensor], self.concept_outputs[concept_id_tensor], weights
332332

333+
# if training, i* needs gradients. use straight-through?
334+
# check with author about this
335+
336+
if self.training:
337+
i = i + concept_text_enc - concept_text_enc.detach()
338+
333339
# main contribution eq (3)
334340

335341
i_energy = opt_einsum('c o, o i, c i -> c', i, Ci, i)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.1',
6+
version = '0.1.2',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)