@@ -68,19 +68,13 @@ def __init__(self, as_function: bool = False):
6868 def rewrite (
6969 self , op : ir .tape .Tape , x : ir .Value , pad : ir .Value , conv : ir .Value
7070 ) -> ir .Value :
71- pad_node = pad .producer ()
7271 conv_node = conv .producer ()
7372
7473 # Retrieve the padding and axes
7574 x_rank = len (x .shape )
76- pad_pads = pad_node .inputs [1 ].const_value .numpy ().tolist ()
77- if len (pad_node .inputs ) > 3 and (axes := pad_node .inputs [3 ]) is not None :
78- axes = [x if x >= 0 else x_rank + x for x in axes .const_value .numpy ()]
79- else :
80- axes = list (range (x_rank ))
8175
82- # Fulfill pad_pads in every dimension (filling with zero the other ones )
83- pad_pads = fill_pads_with_axes ( pad_pads , axes , x_rank )
76+ # Get computed pads in check( )
77+ pad_pads = self . _pads_list
8478
8579 # Get only spatial pads
8680 new_pads = pad_pads [2 :x_rank ] + pad_pads [x_rank + 2 :]
@@ -145,8 +139,9 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
145139 axes_list = list (range (x_rank ))
146140
147141 # Pad constraints: values
148- pads_list = fill_pads_with_axes (pads .const_value .numpy (), axes_list , x_rank )
149- if np .any (pads_list [:2 ] + pads_list [x_rank : x_rank + 2 ]):
142+ self ._pads_list = fill_pads_with_axes (pads .const_value .numpy (), axes_list , x_rank )
143+ if np .any (self ._pads_list [:2 ] + self ._pads_list [x_rank : x_rank + 2 ]):
144+ self ._pads_list = None
150145 return check_result .fail (f"{ pads .name } must be zero in non-spatial dimensions." )
151146
152147 return check_result
@@ -164,14 +159,12 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
164159
165160 def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
166161 check_result = super ().check (context , x , pad , conv )
167- if check_result . reason :
162+ if not check_result :
168163 return check_result
169164
170165 # Conv constraints: attributes
171166 conv_node = conv .producer ()
172- if (
173- apad := conv_node .attributes .get ("auto_pad" , None )
174- ) and apad .as_string () != "NOTSET" :
167+ if conv_node .attributes .get_string ("auto_pad" , "NOTSET" ) != "NOTSET" :
175168 return check_result .fail (
176169 f"{ conv_node .name } ({ conv_node .op_type } ) auto_pad must be 'NOTSET'."
177170 )
0 commit comments