Skip to content

Commit bd6e3f9

Browse files
authored
just overwrite concept_text_enc (i* in paper)
1 parent ef354a1 commit bd6e3f9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def forward(
9797

9898
# update using exponential moving average
9999

100-
ema_concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
100+
concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
101101

102102
self.initted[prompt_ids] = True
103-
self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
103+
self.ema_concept_text_enc[prompt_ids] = concept_text_enc
104104

105105
return einsum('b n i, o i -> b n o', text_enc, weights)

0 commit comments

Comments
 (0)