77
88
99class SpatialConvolution (Module ):
10- def __init__ (self , n_input_plane , n_output_plane , k_w , k_h , d_w = 1 , d_h = 1 , with_bias = True , border_mode = 'valid' , imshape = None , init = xavier (), init_b = zero ):
10+ def __init__ (self , n_input_plane , n_output_plane , k_w , k_h , d_w = 1 , d_h = 1 , with_bias = True , initW = xavier (), initB = zero , border_mode = 'valid' , imshape = None ):
1111 Module .__init__ (self )
1212 self .n_input_plane = n_input_plane
1313 self .n_output_plane = n_output_plane
@@ -22,9 +22,9 @@ def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, with_b
2222 self .w_shape = (n_output_plane , n_input_plane , k_h , k_w )
2323 w_fan = (n_input_plane * k_w * k_h , n_output_plane * k_w * k_h )
2424
25- self .weight , self .grad_weight = create_param_and_grad (self .w_shape , init , fan = w_fan , name = 'Wconv_{},{}@{}x{}' .format (n_input_plane , n_output_plane , k_w , k_h ))
25+ self .weight , self .grad_weight = create_param_and_grad (self .w_shape , initW , fan = w_fan , name = 'Wconv_{},{}@{}x{}' .format (n_input_plane , n_output_plane , k_w , k_h ))
2626 if self .with_bias :
27- self .bias , self .grad_bias = create_param_and_grad (n_output_plane , init_b , name = 'bconv_{}' .format (n_output_plane ))
27+ self .bias , self .grad_bias = create_param_and_grad (n_output_plane , initB , name = 'bconv_{}' .format (n_output_plane ))
2828
2929 def symb_forward (self , symb_input ):
3030 conv_output = _th .tensor .nnet .conv .conv2d (symb_input , self .weight ,
0 commit comments