@@ -6,10 +6,8 @@ class InsertZeroPaddingBeforeConv1D(OptimizerPass):
66 name = 'insert_zero_padding_before_conv1d'
77
88 def match (self , node ):
9- is_match = (
10- isinstance (node , (Conv1D , SeparableConv1D ))
11- and ((node .get_attr ('padding' ) == 'same' ) or (node .get_attr ('padding' ) == 'causal' ))
12- and node .get_attr ('filt_width' ) != 1
9+ is_match = isinstance (node , (Conv1D , SeparableConv1D )) and (
10+ (node .get_attr ('pad_left' ) != 0 ) or (node .get_attr ('pad_right' ) != 0 )
1311 )
1412 return is_match
1513
@@ -37,7 +35,6 @@ def transform(self, model, node):
3735 }
3836
3937 # Switch Conv1D layer padding to 'valid'
40- node .set_attr ('padding' , 'valid' )
4138 node .set_attr ('pad_left' , 0 )
4239 node .set_attr ('pad_right' , 0 )
4340 node .set_attr ('in_width' , out_width )
@@ -54,11 +51,11 @@ class InsertZeroPaddingBeforeConv2D(OptimizerPass):
5451 name = 'insert_zero_padding_before_conv2d'
5552
5653 def match (self , node ):
57- is_match = (
58- isinstance (node , ( Conv2D , SeparableConv2D ) )
59- and node .get_attr ('padding ' ) == 'same'
60- and node .get_attr ('filt_height ' ) != 1
61- and node .get_attr ('filt_width ' ) != 1
54+ is_match = isinstance ( node , ( Conv2D , SeparableConv2D )) and (
55+ (node . get_attr ( 'pad_left' ) != 0 )
56+ or ( node .get_attr ('pad_right ' ) != 0 )
57+ or ( node .get_attr ('pad_top ' ) != 0 )
58+ or ( node .get_attr ('pad_bottom ' ) != 0 )
6259 )
6360 return is_match
6461
@@ -93,7 +90,6 @@ def transform(self, model, node):
9390 }
9491
9592 # Switch Conv2D layer padding to 'valid'
96- node .set_attr ('padding' , 'valid' )
9793 node .set_attr ('pad_top' , 0 )
9894 node .set_attr ('pad_bottom' , 0 )
9995 node .set_attr ('pad_left' , 0 )
0 commit comments