@@ -30,7 +30,7 @@ def make_node(self, data, kernel):
3030 elif self .mode == "valid" :
3131 out_shape = (max (n , k ) - min (n , k ) + 1 ,)
3232 elif self .mode == "same" :
33- out_shape = (max ( n , k ) ,)
33+ out_shape = (n ,)
3434
3535 out = pt .tensor (dtype = dtype , shape = out_shape )
3636 return Apply (self , [data , kernel ], [out ])
@@ -48,29 +48,34 @@ def infer_shape(self, fgraph, node, shapes):
4848 elif self .mode == "valid" :
4949 shape = pt .maximum (n , k ) - pt .minimum (n , k ) + 1
5050 elif self .mode == "same" :
51- shape = pt . maximum ( n , k )
51+ shape = n
5252 return [[shape ]]
5353
5454 def L_op (self , inputs , outputs , output_grads ):
5555 data , kernel = inputs
5656 [grad ] = output_grads
5757
5858 if self .mode == "full" :
59- valid_conv = type (self )(mode = "valid" )
60- data_bar = valid_conv (grad , kernel [::- 1 ])
61- kernel_bar = valid_conv (grad , data [::- 1 ])
59+ data_bar = convolve (grad , kernel [::- 1 ], mode = "valid" )
60+ kernel_bar = convolve (grad , data [::- 1 ], mode = "valid" )
6261
6362 elif self .mode == "valid" :
64- full_conv = type (self )(mode = "full" )
6563 n = data .shape [0 ]
6664 k = kernel .shape [0 ]
6765 kmn = pt .maximum (0 , k - n )
6866 nkm = pt .maximum (0 , n - k )
6967 # We need mode="full" if k >= n else "valid" for data_bar (opposite for kernel_bar), but mode is not symbolic.
7068 # Instead we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
71- data_bar = full_conv (grad , kernel [::- 1 ])
69+ data_bar = convolve (grad , kernel [::- 1 ], mode = "full" )
7270 data_bar = data_bar [kmn : data_bar .shape [0 ] - kmn ]
73- kernel_bar = full_conv (grad , data [::- 1 ])
71+ kernel_bar = convolve (grad , data [::- 1 ], mode = "full" )
7472 kernel_bar = kernel_bar [nkm : kernel_bar .shape [0 ] - nkm ]
7573
74+ else : # self.mode == "same"
75+ raise NotImplementedError ("L_op not implemented for mode='same'" )
76+
7677 return [data_bar , kernel_bar ]
78+
79+
80+ def convolve (data , kernel , mode = "full" ):
81+ return Conv1d (mode )(data , kernel )
0 commit comments