@@ -212,6 +212,40 @@ def test_shape_arithmetic_with_zero_broadcast(self) -> None:
212212 data_prop = True ,
213213 ) # type: ignore
214214
215+ def test_empty_tensor (self ) -> None :
216+ """Test that a Concat with an empty tensor as input is handled correctly by data-propagation."""
217+ model = onnx .parser .parse_model (
218+ """
219+ <ir_version: 7, opset_import: [ "" : 17]>
220+ agraph (float[256] y) => (float[N] z)
221+ <float[0] x = {}>
222+ {
223+ z = Concat <axis=0> (x, y)
224+ }
225+ """
226+ )
227+ inferred_model = onnx .shape_inference .infer_shapes (model , True , True , True )
228+ output = inferred_model .graph .output [0 ]
229+ self .assertEqual (output .type .tensor_type .shape .dim [0 ].dim_value , 256 )
230+
231+ def test_empty_tensor_negative_axis (self ) -> None :
232+ """Test that a Concat with an empty tensor as input is handled correctly by data-propagation.
233+ This time with a negative axis.
234+ """
235+ model = onnx .parser .parse_model (
236+ """
237+ <ir_version: 7, opset_import: [ "" : 17]>
238+ agraph (float[256] y) => (float[N] z)
239+ <float[0] x = {}>
240+ {
241+ z = Concat <axis=-1> (x, y)
242+ }
243+ """
244+ )
245+ inferred_model = onnx .shape_inference .infer_shapes (model , True , True , True )
246+ output = inferred_model .graph .output [0 ]
247+ self .assertEqual (output .type .tensor_type .shape .dim [0 ].dim_value , 256 )
248+
215249
216250if __name__ == "__main__" :
217251 unittest .main ()
0 commit comments