@@ -11,13 +11,13 @@ def symb_forward(self, symb_inputs):
1111 return s
1212
1313
14- def block (nchan , fs = (3 ,3 ), body = None ):
14+ def block (nchan , fs = (3 ,3 ), body = None , bnmom = False ):
1515 return df .Sequential (
1616 df .RepeatInput (
1717 df .Sequential (
18- df .BatchNormalization (nchan ), df .ReLU (),
18+ df .BatchNormalization (nchan , bnmom ), df .ReLU (),
1919 df .SpatialConvolutionCUDNN (nchan , nchan , fs , border = 'same' , init = df .init .prelu (), bias = False ),
20- df .BatchNormalization (nchan ), df .ReLU (),
20+ df .BatchNormalization (nchan , bnmom ), df .ReLU (),
2121 df .SpatialConvolutionCUDNN (nchan , nchan , fs , border = 'same' , init = df .init .prelu (), bias = False )
2222 ) if body is None else body ,
2323 df .Identity ()
@@ -26,13 +26,13 @@ def block(nchan, fs=(3,3), body=None):
2626 )
2727
2828
29- def block_proj (nin , nout , fs = (3 ,3 ), body = None ):
29+ def block_proj (nin , nout , fs = (3 ,3 ), body = None , bnmom = False ):
3030 return df .Sequential (
3131 df .RepeatInput (
3232 df .Sequential (
33- df .BatchNormalization (nin ), df .ReLU (),
33+ df .BatchNormalization (nin , bnmom ), df .ReLU (),
3434 df .SpatialConvolutionCUDNN (nin , nout , fs , border = 'same' , init = df .init .prelu (), bias = False ),
35- df .BatchNormalization (nout ), df .ReLU (),
35+ df .BatchNormalization (nout , bnmom ), df .ReLU (),
3636 df .SpatialConvolutionCUDNN (nout , nout , fs , border = 'same' , init = df .init .prelu (), bias = False )
3737 ) if body is None else body ,
3838 df .SpatialConvolutionCUDNN (nin , nout , (1 ,)* len (fs )),
0 commit comments