@@ -33,6 +33,7 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = Non
3333
3434 def build_model (
3535 self ,
36+ op_type : str ,
3637 input_shape : ir .Shape ,
3738 weight_shape : typing .Sequence [int ],
3839 pad_inputs : typing .Sequence [ir .TensorProtocol | ir .Value | None ],
@@ -57,14 +58,17 @@ def build_model(
5758 raise ValueError (f"Unsupported type for pad input ({ x } ): { type (x )} ." )
5859
5960 # Register operations in the tape
60- x = ir .Input ("X" , shape = input_shape , type = ir .TensorType (ir .DataType .FLOAT ))
61+ idtype = ir .DataType .UINT8 if op_type == "ConvInteger" else ir .DataType .FLOAT
62+ x = ir .Input ("X" , shape = input_shape , type = ir .TensorType (idtype ))
6163 y = tape .op ("Pad" , inputs = [x , * pad_inputs ], attributes = pad_attributes )
6264 y = tape .op (
63- "Conv" ,
65+ op_type ,
6466 inputs = [y , self .get_conv_weights (weight_shape , tape )],
6567 attributes = conv_attributes ,
6668 output = ir .Input ("Y" , shape = output_shape , type = ir .TensorType (x .dtype )),
6769 )
70+ if op_type == "ConvInteger" :
71+ y .dtype = ir .DataType .INT32
6872
6973 # Build the model
7074 ir_model = ir .Model (
@@ -101,6 +105,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
101105 if axes is not None :
102106 pad_inputs .append (axes )
103107 base_model = self .build_model (
108+ op_type = "Conv" ,
104109 input_shape = ir .Shape (("N" , 32 , 14 , 16 )),
105110 weight_shape = (10 , 32 , 3 , 3 ),
106111 pad_inputs = pad_inputs ,
@@ -190,6 +195,7 @@ def test_unsupported_fuse_pad_into_conv(
190195 self , mode , pads , const_value , axes , auto_pad , err_msg
191196 ):
192197 base_model = self .build_model (
198+ op_type = "Conv" ,
193199 input_shape = ir .Shape (("N" , 32 , 14 , 16 , 12 )),
194200 weight_shape = (10 , 32 , 3 , 4 , 5 ),
195201 pad_inputs = [pads , const_value , axes ],
@@ -208,5 +214,51 @@ def test_unsupported_fuse_pad_into_conv(
208214 self .assertRegex (tracer_match .match_result .reason , err_msg )
209215
210216
class FusePadConvIntegerTest(FusePadConvBaseTest):
    """Tests fusing a Pad node into a following ConvInteger node.

    Mirrors the float Conv fusion tests, but builds models whose input and
    weights are uint8 (as required by ConvInteger).
    """

    def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = None):
        """Return random uint8 weights named "W"; register on *tape* when given."""
        w = ir.tensor(self.rng.integers(0, 256, shape).astype("uint8"), name="W")
        if tape is not None:
            w = tape.initializer(w)
        return w

    @parameterized.parameterized.expand(
        [
            (pad_pads, const_value, axes, conv_pads)
            for pad_pads, axes, conv_pads in [
                ([0, 0, 3, 2, 0, 0, 1, 4], None, [1, 1, 1, 1]),
                ([2, 2, 0, 2, 2, 0], ir.tensor([-2, -1, 1], name="axes"), None),
                ([1, 2, 2, 1], ir.tensor([-1, 2], name="axes"), [0, 1, 0, 1]),
            ]
            for const_value in [None, ir.tensor(np.array([0], "uint8"), name="const_value")]
        ]
    )
    def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads):
        # Pad's optional inputs are positional: const_value must be present
        # (possibly None) whenever axes is supplied.
        pad_inputs = [ir.tensor(pad_pads, name="pads")]
        if const_value is not None or axes is not None:
            pad_inputs.append(const_value)
        if axes is not None:
            pad_inputs.append(axes)
        base_model = self.build_model(
            op_type="ConvInteger",
            input_shape=ir.Shape(("N", 24, 19, 23)),
            weight_shape=(8, 24, 3, 3),
            pad_inputs=pad_inputs,
            conv_attributes={"pads": conv_pads},
        )
        updated_model = _clone_model(base_model)

        # Apply rule
        count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model)

        # Check that Pad was fused
        self.assertEqual(count, 1)
        self.assertEqual(updated_model.graph.num_nodes(), 1)
        onnx_checker.CheckerPass(True)(updated_model)

        # Check inference
        inputs = self.rng.integers(0, 255, (1, 24, 19, 23), dtype="uint8")
        testing.assert_numerically_equal(base_model, updated_model, (inputs,), atol=0, rtol=0)
261+
262+
if __name__ == "__main__":
    # Run the test suite when this module is executed directly.
    unittest.main()