Skip to content

Commit 25a134a

Browse files
authored
Fix handling of tensor rank in concat data propagation (onnx#6570)
### Description Fix issue onnx#6276 (data propagation fails on Concat when first tensor is initialized empty). --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 96a0ca4 commit 25a134a

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

onnx/defs/data_propagators.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,21 @@ inline bool axisIsZero(DataPropagationContext& ctx, bool defaultZero = false) {
3333
}
3434
}
3535
int axis = static_cast<int>(axisAttr->i());
36-
auto input_data_0 = ctx.getInputData(0);
37-
if (input_data_0 == nullptr) {
36+
if (axis >= 0) {
37+
return axis == 0;
38+
}
39+
// For negative axes, we need rank information to determine if it is equivalent to axis 0
40+
const TypeProto* type = ctx.getInputType(0);
41+
if ((type == nullptr) || (!type->has_tensor_type()) || (!type->tensor_type().has_shape())) {
3842
return false;
3943
}
40-
int rank = input_data_0->dim_size();
44+
45+
int rank = type->tensor_type().shape().dim_size();
4146
if (axis < -rank || axis >= rank) {
4247
fail_shape_inference("axis must be in [-rank, rank-1].");
4348
return false;
4449
}
45-
if (axis < 0) {
46-
axis += rank;
47-
}
50+
axis += rank;
4851
// Only supports axis = 0 since the data comes from Shape
4952
return axis == 0;
5053
}

onnx/test/data_propagation_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,40 @@ def test_shape_arithmetic_with_zero_broadcast(self) -> None:
212212
data_prop=True,
213213
) # type: ignore
214214

215+
def test_empty_tensor(self) -> None:
216+
"""Test that a Concat with an empty tensor as input is handled correctly by data-propagation."""
217+
model = onnx.parser.parse_model(
218+
"""
219+
<ir_version: 7, opset_import: [ "" : 17]>
220+
agraph (float[256] y) => (float[N] z)
221+
<float[0] x = {}>
222+
{
223+
z = Concat <axis=0> (x, y)
224+
}
225+
"""
226+
)
227+
inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True)
228+
output = inferred_model.graph.output[0]
229+
self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256)
230+
231+
def test_empty_tensor_negative_axis(self) -> None:
232+
"""Test that a Concat with an empty tensor as input is handled correctly by data-propagation.
233+
This time with a negative axis.
234+
"""
235+
model = onnx.parser.parse_model(
236+
"""
237+
<ir_version: 7, opset_import: [ "" : 17]>
238+
agraph (float[256] y) => (float[N] z)
239+
<float[0] x = {}>
240+
{
241+
z = Concat <axis=-1> (x, y)
242+
}
243+
"""
244+
)
245+
inferred_model = onnx.shape_inference.infer_shapes(model, True, True, True)
246+
output = inferred_model.graph.output[0]
247+
self.assertEqual(output.type.tensor_type.shape.dim[0].dim_value, 256)
248+
215249

216250
if __name__ == "__main__":
217251
unittest.main()

0 commit comments

Comments
 (0)