11from .Module import Module
2+ from beacon8 .init import zero , xavier
3+ from beacon8 .utils import create_param_and_grad
24
35import theano as _th
46import numpy as _np
57
68
79class SpatialConvolution (Module ):
8- def __init__ (self , n_input_plane , n_output_plane , k_w , k_h , d_w = 1 , d_h = 1 , with_bias = True , border_mode = 'valid' , imshape = None ):
10+ def __init__ (self , n_input_plane , n_output_plane , k_w , k_h , d_w = 1 , d_h = 1 , with_bias = True , border_mode = 'valid' , imshape = None , init = xavier , init_b = zero ):
911 Module .__init__ (self )
1012 self .n_input_plane = n_input_plane
1113 self .n_output_plane = n_output_plane
@@ -17,19 +19,17 @@ def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, with_b
1719 self .border_mode = border_mode
1820 self .imshape = imshape
1921
20- w_bound = _np .sqrt (4. / ((self .n_input_plane + self .n_output_plane ) * self .k_w * self .k_h ))
21- W = _np .random .uniform (low = - w_bound , high = w_bound , size = (n_output_plane , n_input_plane , k_h , k_w ))
22- self .weight = _th .shared (W .astype (dtype = _th .config .floatX ))
23- self .grad_weight = _th .shared ((W * 0 ).astype (dtype = _th .config .floatX ))
22+ self .w_shape = (n_output_plane , n_input_plane , k_h , k_w )
23+ w_fan = (n_input_plane * k_w * k_h , n_output_plane * k_w * k_h )
2424
25+ self .weight , self .grad_weight = create_param_and_grad (self .w_shape , init , fan = w_fan , name = 'Wconv_{},{}@{}x{}' .format (n_input_plane , n_output_plane , k_w , k_h ))
2526 if self .with_bias :
26- self .bias = _th .shared (_np .zeros (shape = (n_output_plane , ), dtype = _th .config .floatX ))
27- self .grad_bias = _th .shared (_np .zeros (shape = (n_output_plane , ), dtype = _th .config .floatX ))
27+ self .bias , self .grad_bias = create_param_and_grad (n_output_plane , init_b , name = 'bconv_{}' .format (n_output_plane ))
2828
2929 def symb_forward (self , symb_input ):
3030 conv_output = _th .tensor .nnet .conv .conv2d (symb_input , self .weight ,
3131 image_shape = (None , self .n_input_plane ) + (self .imshape or (None , None )),
32- filter_shape = ( self .n_output_plane , self . n_input_plane , self . k_h , self . k_w ) ,
32+ filter_shape = self .w_shape ,
3333 border_mode = self .border_mode ,
3434 subsample = (self .d_h , self .d_w )
3535 )
@@ -38,4 +38,3 @@ def symb_forward(self, symb_input):
3838 return conv_output + self .bias .dimshuffle ('x' , 0 , 'x' , 'x' )
3939 else :
4040 return conv_output
41-
0 commit comments