Skip to content

Commit 14e4586

Browse files
authored
Merge pull request #427 from Xilinx/matthias.fix_concat_zero_fold
TOSA: concat: fix canonicalization that would result in concat with no operands
2 parents c9c2863 + ae05589 commit 14e4586

File tree

2 files changed

+29
-9
lines changed

2 files changed

+29
-9
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,19 +1349,13 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
13491349
return {};
13501350
}
13511351

1352-
static bool hasZeroSize(Type ty) {
1353-
auto ranked = dyn_cast<RankedTensorType>(ty);
1354-
if (!ranked)
1355-
return false;
1356-
return any_of(ranked.getShape(), [](auto d) { return d == 0; });
1357-
}
1358-
13591352
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
13601353
/// Remove operands that have zero elements.
13611354
bool changed = false;
13621355
for (size_t i = 0; i < getInput1().size(); ) {
1363-
auto input = getInput1()[i];
1364-
if (hasZeroSize(input.getType())) {
1356+
auto input = cast<RankedTensorType>(getInput1()[i].getType());
1357+
// Ensure that we have at least one operand left.
1358+
if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) {
13651359
getInput1Mutable().erase(i);
13661360
changed = true;
13671361
} else {

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,32 @@ func.func @concat_fold_zero(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg
204204
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
205205
return %0 : tensor<?x3xf32>
206206
}
207+
// -----
208+
209+
// CHECK-LABEL: @concat_fold_zero
210+
func.func @concat_fold_zero_all(%arg0: tensor<?x0xf32>, %arg1: tensor<?x0xf32>) -> tensor<?x0xf32> {
211+
// CHECK: return %arg1
212+
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x0xf32>) -> tensor<?x0xf32>
213+
return %0 : tensor<?x0xf32>
214+
}
215+
216+
// -----
217+
218+
// CHECK-LABEL: @concat_fold_zero
219+
func.func @concat_fold_zero_different_axis(%arg0: tensor<0x2xf32>, %arg1: tensor<0x4xf32>) -> tensor<0x6xf32> {
220+
// CHECK: tosa.concat %arg0, %arg1
221+
%0 = tosa.concat %arg0, %arg1 {axis = 1 : i32}: (tensor<0x2xf32>, tensor<0x4xf32>) -> tensor<0x6xf32>
222+
return %0 : tensor<0x6xf32>
223+
}
224+
225+
// -----
226+
227+
// CHECK-LABEL: @concat_fold_zero_size
228+
func.func @concat_fold_zero_size(%arg0: tensor<?x0xf32>, %arg1: tensor<?x1xf32>, %arg2: tensor<?x2xf32>) -> tensor<?x3xf32> {
229+
// CHECK: tosa.concat %arg1, %arg2 {axis = 1 : i32}
230+
%0 = tosa.concat %arg0, %arg1, %arg2 {axis = 1 : i32}: (tensor<?x0xf32>, tensor<?x1xf32>, tensor<?x2xf32>) -> tensor<?x3xf32>
231+
return %0 : tensor<?x3xf32>
232+
}
207233

208234
// -----
209235

0 commit comments

Comments
 (0)