55import numpy as np
66
77class SpatialConvolutionCUDNN (df .Module ):
8- def __init__ (self , nchan_in , nchan_out , filter_size , stride = 1 , border = 0 , mode = 'cross' , init = df .init .xavier (), bias = df . init . const ( 0 ) ):
8+ def __init__ (self , nchan_in , nchan_out , filter_size , stride = 1 , border = 0 , mode = 'cross' , init = df .init .xavier (), bias = 0 ):
99 # mode='cross' is the default in Lasagne[1], Torch[2], matConvNet[3], Caffee[4].
1010 #
1111 # 1: https://github.com/Lasagne/Lasagne/blob/63d44a0d/lasagne/layers/dnn.py#L299
1212 # 2: https://github.com/soumith/cudnn.torch/blob/840f0228/SpatialConvolution.lua#L83
1313 # 3: https://github.com/vlfeat/matconvnet/blob/b7dd9c96/matlab/src/bits/impl/nnconv_cudnn.cu#L133
1414 # 4: https://github.com/BVLC/caffe/blob/50ab52cb/include/caffe/util/cudnn.hpp#L104
1515 df .Module .__init__ (self )
16-
16+
1717 # Catch a probably common bug while we transition the API.
1818 assert isinstance (filter_size , (list , tuple )), "New conv API: filter_size needs to be a tuple!"
1919
@@ -33,11 +33,7 @@ def __init__(self, nchan_in, nchan_out, filter_size, stride=1, border=0, mode='c
3333 w_fan = (np .prod (self .filter_size )* nchan_in , np .prod (self .filter_size )* nchan_out )
3434 w_name = ('Wconv_{},{}@{}' + 'x{}' * (len (w_shape ) - 3 )).format (* w_shape )
3535 self .W = self ._addparam (w_shape , init , fan = w_fan , name = w_name )
36-
37- if bias not in (None , False ):
38- self .b = self ._addparam (nchan_out , bias , decay = False , name = 'bconv_{}' .format (nchan_out ))
39- else :
40- self .b = None
36+ self .b = self ._addparam_optional (nchan_out , bias , decay = False , name = 'bconv_{}' .format (nchan_out ))
4137
4238
4339 def symb_forward (self , symb_input ):
0 commit comments