Skip to content

Commit 8eb0949

Browse files
committed
Merge pull request #11 from lucasb-eyer/empty-stats
Don't compile/call function when there's no stats to be accumulated.
2 parents 5d86833 + f5858f1 commit 8eb0949

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

beacon8/layers/BatchNormalization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def symb_forward(self, symb_input):
4242
else:
4343
return symb_input * self.inference_weight.dimshuffle(*d_shuffle) + self.inference_bias.dimshuffle(*d_shuffle)
4444

45-
def get_stat_updates(self,):
45+
def get_stat_updates(self):
4646
assert (self.batch_mean is not None) and (self.batch_var is not None), "You need to do a forward pass first"
4747

4848
stat_updates = list()
@@ -66,4 +66,4 @@ def training(self):
6666
def evaluate(self):
6767
Module.evaluate(self)
6868
self.inference_weight.set_value(self.weight.get_value() / _np.sqrt(self.buffer_variance.get_value() + self.eps))
69-
self.inference_bias.set_value(self.bias.get_value() - self.inference_weight.get_value() * self.buffer_mean.get_value())
69+
self.inference_bias.set_value(self.bias.get_value() - self.inference_weight.get_value() * self.buffer_mean.get_value())

beacon8/layers/Module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def accumulate_statistics(self, data_in):
8181
self.symb_forward(symb_in)
8282

8383
stat_updates = self.get_stat_updates()
84+
if not stat_updates:
85+
# If there's no layer collecting statistics, we don't need to
86+
# compile and call a function. This prevents theano errors.
87+
return
8488

8589
self.fn_accum_stats = _th.function(
8690
inputs=[symb_in],

0 commit comments

Comments
 (0)