@@ -11,16 +11,16 @@ class BatchNormalization(Module):
     def __init__(self, n_features, eps=None):
         Module.__init__(self)

-        self.weight, self.grad_weight = create_param_and_grad(n_features, const(1), 'W_BN')
-        self.bias, self.grad_bias = create_param_and_grad(n_features, const(0), 'b_BN')
+        self.weight, self.grad_weight = create_param_and_grad(n_features, const(1), name='W_BN_{}'.format(n_features))
+        self.bias, self.grad_bias = create_param_and_grad(n_features, const(0), name='b_BN_{}'.format(n_features))

-        self.inference_weight = create_param(n_features, const(1), 'W_BN_inf')
-        self.inference_bias = create_param(n_features, const(0), 'b_BN_inf')
+        self.inference_weight = create_param(n_features, const(1), name='W_BN_{}_inf'.format(n_features))
+        self.inference_bias = create_param(n_features, const(0), name='b_BN_{}_inf'.format(n_features))

         # These are buffers for collecting the minibatch statistics.
-        self.buffer_variance = create_param(n_features, const(1), 'BN_var')
-        self.buffer_mean = create_param(n_features, const(0), 'BN_mean')
-        self.buffer_counts = _th.shared(_np.asarray(0, dtype=_th.config.floatX))
+        self.buffer_variance = create_param(n_features, const(1), name='BN_var_{}'.format(n_features))
+        self.buffer_mean = create_param(n_features, const(0), name='BN_mean_{}'.format(n_features))
+        self.buffer_counts = _th.shared(_np.asarray(0, dtype=_th.config.floatX), name='BN_count_{}'.format(n_features))

         self.eps = eps or 1e-5

0 commit comments