Skip to content

Commit 3d0ba4a

Browse files
committed
tweak
1 parent df6169d commit 3d0ba4a

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

perfusion_pytorch/perfusion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,13 @@ def forward(
174174
# for the prompt ids not initialized yet, hard copy over the initial superclass outputs
175175
self.superclass_output.data.copy_(superclass_output)
176176

177+
elif exists(superclass_output):
178+
# if text enc with superclass is passed in for more than 1 batch
179+
# just take the opportunity to exponentially average it a bit more
180+
181+
ema_superclass_output = self.superclass_output * decay + superclass_output * (1. - decay)
182+
self.superclass_output.data.copy_(ema_superclass_output)
183+
177184
# if any in the batch is not initialized, initialize
178185

179186
if not initted:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.17',
6+
version = '0.0.18',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)