Skip to content

Commit d66723a

Browse files
committed
Warn when switching to eval with BN but no stats.
1 parent 58a8710 commit d66723a

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

DeepFried2/layers/BatchNormalization.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from DeepFried2.utils import flatten
33

44
import numpy as _np
5+
from warnings import warn as _warn
56

67

78
class BatchNormalization(df.Module):
@@ -84,6 +85,10 @@ def evaluate(self):
8485
self.Winf.set_value(self.W.get_value() / _np.sqrt(self.buf_var.get_value() + self.eps))
8586
self.binf.set_value(self.b.get_value() - self.Winf.get_value() * self.buf_mean.get_value())
8687

88+
# This check saved me from WTF'ing countless times!
89+
if self.buf_count.get_value() == 0:
90+
_warn("You're switching a BN-net to eval mode without having collected any statistics, that can't go well!")
91+
8792
def __getstate__(self):
8893
regular = df.Module.__getstate__(self)
8994
return [buf.get_value() for buf in (self.buf_mean, self.buf_var, self.buf_count)] + regular

0 commit comments

Comments
 (0)