File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -42,7 +42,7 @@ def symb_forward(self, symb_input):
4242 else :
4343 return symb_input * self .inference_weight .dimshuffle (* d_shuffle ) + self .inference_bias .dimshuffle (* d_shuffle )
4444
45- def get_stat_updates (self , ):
45+ def get_stat_updates (self ):
4646 assert (self .batch_mean is not None ) and (self .batch_var is not None ), "You need to do a forward pass first"
4747
4848 stat_updates = list ()
@@ -66,4 +66,4 @@ def training(self):
6666 def evaluate (self ):
6767 Module .evaluate (self )
6868 self .inference_weight .set_value (self .weight .get_value () / _np .sqrt (self .buffer_variance .get_value () + self .eps ))
69- self .inference_bias .set_value (self .bias .get_value () - self .inference_weight .get_value () * self .buffer_mean .get_value ())
69+ self .inference_bias .set_value (self .bias .get_value () - self .inference_weight .get_value () * self .buffer_mean .get_value ())
Original file line number Diff line number Diff line change @@ -81,6 +81,10 @@ def accumulate_statistics(self, data_in):
8181 self .symb_forward (symb_in )
8282
8383 stat_updates = self .get_stat_updates ()
84+ if not stat_updates :
85+ # If there's no layer collecting statistics, we don't need to
86+ # compile and call a function. This prevents theano errors.
87+ return
8488
8589 self .fn_accum_stats = _th .function (
8690 inputs = [symb_in ],
You can’t perform that action at this time.
0 commit comments