From 75de3afe3720c7c4f1c2ae4f484dfa9b9467925a Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 07:27:25 -0400 Subject: [PATCH 1/9] Fix canonicalization pattern for shape.shape_of --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 18 ++++++++++--- mlir/test/Dialect/Shape/canonicalize.mlir | 33 +++++++++++++++++++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 10ba808cd26c2..b8eac7c86797b 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1734,10 +1734,22 @@ struct ShapeOfFromReshape : public OpRewritePattern { // Operand 'shape' of 'tensor.reshape' may now be used as the result of // 'shape.shape_of'. While its type is guaranteed to be compatible in well- // formed IR, it may not be identical (dynamically vs statically shaped), - // in which case it needs to be cast first. + // in which case it needs to be cast first using 'tensor.cast'. + // Additionally, it may not have identical element type (i32 vs index) + // while it has identical shaped type (dynamic vs static), in which case it needs + // to be cast first using 'arith.index_cast'. + // Note: 'shape.shape_of' op result must be shape or extent tensor. Value shape = tensorReshapeOp.getShape(); - if (op.getType() != shape.getType()) - shape = rewriter.create(op.getLoc(), op.getType(), shape); + + auto opTensorType = llvm::dyn_cast(op.getType()); + auto shapeTensorType = llvm::dyn_cast(shape.getType()); + + if (op.getType() != shape.getType()) { + if (opTensorType.getElementType() == shapeTensorType.getElementType()) + shape = rewriter.create(op.getLoc(), op.getType(), shape); + else if (!isExtentTensorType(shape.getType())) + shape = rewriter.create(op.getLoc(), op.getType(), shape); + } rewriter.replaceOp(op, shape); return success(); diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index cf439c9c1b854..9b25468b3ab1e 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1389,10 +1389,25 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor) - // ----- -// CHECK-LABEL: func @shape_of_from_reshape_compatible_types +// Check statically shaped types, with element types i32 to index. +// CHECK-LABEL: func @shape_of_from_reshape_compatible_types1 +// CHECK-SAME: %[[INPUT:.*]]: tensor +// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> +func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor, %arg1: tensor<3xi32>) -> tensor<3xindex> { + // CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex> + // CHECK: return %[[CAST_SHAPE]] : tensor<3xindex> + %0 = tensor.reshape %arg0(%arg1) : (tensor, tensor<3xi32>) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor<3xindex> + return %1 : tensor<3xindex> +} + +// ----- + +// Check similar element types, with statically shaped to dynamically shaped. +// CHECK-LABEL: func @shape_of_from_reshape_compatible_types2 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex> -func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor { +func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor { // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor // CHECK: return %[[CAST_SHAPE]] : tensor %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32> @@ -1402,6 +1417,20 @@ func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: t // ----- +// Check similar element types, with dynamically shaped to statically shaped. +// CHECK-LABEL: func @shape_of_from_reshape_compatible_types3 +// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> +// CHECK-SAME: %[[SHAPE:.*]]: tensor +func.func @shape_of_from_reshape_compatible_types3(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<5xindex> { + // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor to tensor<5xindex> + // CHECK: return %[[CAST_SHAPE]] : tensor<5xindex> + %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex> + return %1 : tensor<5xindex> +} + +// ----- + // CHECK-LABEL: func @shape_of_from_reshape_nofold // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor From 394735f79035ae8586521302b1b89fc99462d26d Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 08:34:15 -0400 Subject: [PATCH 2/9] dyn_cast check --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index b8eac7c86797b..f9302256eefe2 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1741,11 +1741,13 @@ struct ShapeOfFromReshape : public OpRewritePattern { // Note: 'shape.shape_of' op result must be shape or extent tensor. Value shape = tensorReshapeOp.getShape(); - auto opTensorType = llvm::dyn_cast(op.getType()); - auto shapeTensorType = llvm::dyn_cast(shape.getType()); + auto opTensorTy = llvm::dyn_cast(op.getType()); + auto shapeTensorTy = llvm::dyn_cast(shape.getType()); + if (!opTensorTy || !shapeTensorTy) + return failure(); if (op.getType() != shape.getType()) { - if (opTensorType.getElementType() == shapeTensorType.getElementType()) + if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) shape = rewriter.create(op.getLoc(), op.getType(), shape); else if (!isExtentTensorType(shape.getType())) shape = rewriter.create(op.getLoc(), op.getType(), shape); From e12e2e4534e059f11070b3b5901d37c969031f47 Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 17:03:43 -0400 Subject: [PATCH 3/9] use llvm::cast --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index f9302256eefe2..052b6cdb3eee7 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1741,10 +1741,8 @@ struct ShapeOfFromReshape : public OpRewritePattern { // Note: 'shape.shape_of' op result must be shape or extent tensor. Value shape = tensorReshapeOp.getShape(); - auto opTensorTy = llvm::dyn_cast(op.getType()); - auto shapeTensorTy = llvm::dyn_cast(shape.getType()); - if (!opTensorTy || !shapeTensorTy) - return failure(); + auto opTensorTy = llvm::cast(op.getType()); + auto shapeTensorTy = llvm::cast(shape.getType()); if (op.getType() != shape.getType()) { if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) From 137dcd06ccb214698bd3f19f19ed3d55bf19fdfc Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 18:11:52 -0400 Subject: [PATCH 4/9] Update mlir/lib/Dialect/Shape/IR/Shape.cpp Co-authored-by: Mehdi Amini --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 052b6cdb3eee7..d0b064e6fc1df 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1741,8 +1741,8 @@ struct ShapeOfFromReshape : public OpRewritePattern { // Note: 'shape.shape_of' op result must be shape or extent tensor. Value shape = tensorReshapeOp.getShape(); - auto opTensorTy = llvm::cast(op.getType()); - auto shapeTensorTy = llvm::cast(shape.getType()); + auto opTensorTy = cast(op.getType()); + auto shapeTensorTy = cast(shape.getType()); if (op.getType() != shape.getType()) { if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) From 5cf4a388840d55b64ae5fb32eee42f7f02a6603d Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 18:16:45 -0400 Subject: [PATCH 5/9] Update mlir/lib/Dialect/Shape/IR/Shape.cpp Co-authored-by: Mehdi Amini --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index d0b064e6fc1df..f66a589c72f7e 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1744,11 +1744,11 @@ struct ShapeOfFromReshape : public OpRewritePattern { auto opTensorTy = cast(op.getType()); auto shapeTensorTy = cast(shape.getType()); - if (op.getType() != shape.getType()) { + if (opTensorTy != shapeTensorTy) { if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) - shape = rewriter.create(op.getLoc(), op.getType(), shape); - else if (!isExtentTensorType(shape.getType())) - shape = rewriter.create(op.getLoc(), op.getType(), shape); + shape = rewriter.create(op.getLoc(), opTensorTy, shape); + else if (!isExtentTensorType(shapeTensorTy)) + shape = rewriter.create(op.getLoc(), opTensorTy, shape); } rewriter.replaceOp(op, shape); From d2db005b3497dc1c4c9e51b7a6e42a81edaa70c8 Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 18:41:55 -0400 Subject: [PATCH 6/9] fix code formatting issue --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index f66a589c72f7e..f670614806dbd 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1735,20 +1735,21 @@ struct ShapeOfFromReshape : public OpRewritePattern { // 'shape.shape_of'. While its type is guaranteed to be compatible in well- // formed IR, it may not be identical (dynamically vs statically shaped), // in which case it needs to be cast first using 'tensor.cast'. - // Additionally, it may not have identical element type (i32 vs index) - // while it has identical shaped type (dynamic vs static), in which case it needs - // to be cast first using 'arith.index_cast'. - // Note: 'shape.shape_of' op result must be shape or extent tensor. + // Additionally, it may not have identical element type (i32 vs index) + // while it has identical shaped type (dynamic vs static), in which case it + // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of' + // op result must be shape or extent tensor. Value shape = tensorReshapeOp.getShape(); auto opTensorTy = cast(op.getType()); auto shapeTensorTy = cast(shape.getType()); if (opTensorTy != shapeTensorTy) { - if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) - shape = rewriter.create(op.getLoc(), opTensorTy, shape); - else if (!isExtentTensorType(shapeTensorTy)) - shape = rewriter.create(op.getLoc(), opTensorTy, shape); + if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) + shape = rewriter.create(op.getLoc(), opTensorTy, shape); + else if (!isExtentTensorType(shapeTensorTy)) + shape = + rewriter.create(op.getLoc(), opTensorTy, shape); } rewriter.replaceOp(op, shape); From 89a8ffad8aa20efd2258292eec30e08745fa2aa4 Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 20:56:36 -0400 Subject: [PATCH 7/9] update LIT test names --- mlir/test/Dialect/Shape/canonicalize.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 9b25468b3ab1e..4a65edb3bc1bc 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1393,7 +1393,7 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor) - // CHECK-LABEL: func @shape_of_from_reshape_compatible_types1 // CHECK-SAME: %[[INPUT:.*]]: tensor // CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> -func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor, %arg1: tensor<3xi32>) -> tensor<3xindex> { +func.func @shape_of_from_reshape_int_to_index(%arg0: tensor, %arg1: tensor<3xi32>) -> tensor<3xindex> { // CHECK: %[[CAST_SHAPE:.*]] = arith.index_cast %[[SHAPE]] : tensor<3xi32> to tensor<3xindex> // CHECK: return %[[CAST_SHAPE]] : tensor<3xindex> %0 = tensor.reshape %arg0(%arg1) : (tensor, tensor<3xi32>) -> tensor @@ -1407,7 +1407,7 @@ func.func @shape_of_from_reshape_compatible_types1(%arg0: tensor, %arg1 // CHECK-LABEL: func @shape_of_from_reshape_compatible_types2 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex> -func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor { +func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor { // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor // CHECK: return %[[CAST_SHAPE]] : tensor %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32> @@ -1421,7 +1421,7 @@ func.func @shape_of_from_reshape_compatible_types2(%arg0: tensor<*xf32>, %arg1: // CHECK-LABEL: func @shape_of_from_reshape_compatible_types3 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor -func.func @shape_of_from_reshape_compatible_types3(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<5xindex> { +func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<5xindex> { // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor to tensor<5xindex> // CHECK: return %[[CAST_SHAPE]] : tensor<5xindex> %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor) -> tensor<*xf32> From 245263ca2d85a4e906fe2ce27cf4366f99a4d4c5 Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 21:04:06 -0400 Subject: [PATCH 8/9] minor fix --- mlir/test/Dialect/Shape/canonicalize.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 4a65edb3bc1bc..71a80de8adfb9 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1390,7 +1390,7 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor) - // ----- // Check statically shaped types, with element types i32 to index. -// CHECK-LABEL: func @shape_of_from_reshape_compatible_types1 +// CHECK-LABEL: func @shape_of_from_reshape_int_to_index // CHECK-SAME: %[[INPUT:.*]]: tensor // CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> func.func @shape_of_from_reshape_int_to_index(%arg0: tensor, %arg1: tensor<3xi32>) -> tensor<3xindex> { @@ -1404,7 +1404,7 @@ func.func @shape_of_from_reshape_int_to_index(%arg0: tensor, %arg1: ten // ----- // Check similar element types, with statically shaped to dynamically shaped. -// CHECK-LABEL: func @shape_of_from_reshape_compatible_types2 +// CHECK-LABEL: func @shape_of_from_reshape_static_to_dynamic // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex> func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor { @@ -1418,7 +1418,7 @@ func.func @shape_of_from_reshape_static_to_dynamic(%arg0: tensor<*xf32>, %arg1: // ----- // Check similar element types, with dynamically shaped to statically shaped. -// CHECK-LABEL: func @shape_of_from_reshape_compatible_types3 +// CHECK-LABEL: func @shape_of_from_reshape_dynamic_to_static // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<5xindex> { From 34b5dbdc733d95b6219653dcbe3fb25a42f63b8b Mon Sep 17 00:00:00 2001 From: Alaa Ali Date: Thu, 3 Apr 2025 21:34:28 -0400 Subject: [PATCH 9/9] add LIT test shape_of_from_reshape_identical_types --- mlir/test/Dialect/Shape/canonicalize.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 71a80de8adfb9..b42fa75e4112d 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1431,6 +1431,19 @@ func.func @shape_of_from_reshape_dynamic_to_static(%arg0: tensor<*xf32>, %arg1: // ----- +// Check similar element types and similar static shape. +// CHECK-LABEL: func @shape_of_from_reshape_identical_types +// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> +// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex> +func.func @shape_of_from_reshape_identical_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<5xindex> { + // CHECK: return %[[SHAPE]] : tensor<5xindex> + %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32> + %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<5xindex> + return %1 : tensor<5xindex> +} + +// ----- + // CHECK-LABEL: func @shape_of_from_reshape_nofold // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32> // CHECK-SAME: %[[SHAPE:.*]]: tensor