@@ -42,6 +42,7 @@ limitations under the License.
4242#include " mlir/IR/Diagnostics.h"
4343#include " mlir/IR/IRMapping.h"
4444#include " mlir/IR/Matchers.h"
45+ #include " mlir/IR/OpDefinition.h"
4546#include " mlir/IR/OperationSupport.h"
4647#include " mlir/IR/PatternMatch.h"
4748#include " mlir/IR/Region.h"
@@ -58,6 +59,22 @@ namespace tpu {
5859
5960namespace {
6061
62+ // This should only be used to canonicalize away EraseLayoutOps that feed ops
63+ // that only consume memrefs and don't return them.
64+ LogicalResult propagateTiledLayoutToConsumer (Operation* op,
65+ PatternRewriter& rewriter) {
66+ bool modified = false ;
67+ for (unsigned int i = 0 ; i < op->getNumOperands (); ++i) {
68+ if (auto erase_layout_op =
69+ op->getOperand (i).getDefiningOp <tpu::EraseLayoutOp>()) {
70+ modified = true ;
71+ rewriter.modifyOpInPlace (
72+ op, [&]() { op->setOperand (i, erase_layout_op.getOperand ()); });
73+ }
74+ }
75+ return success (modified);
76+ }
77+
6178llvm::RoundingMode convertTpuRoundingModeToLLVMIR (tpu::RoundingMode mode) {
6279 switch (mode) {
6380 case tpu::RoundingMode::kToNearestEven :
@@ -268,6 +285,8 @@ struct MemRefSliceFoldConstantDynamicDim
268285 op.getResult ().setType (new_type);
269286 op.getDynamicSizesMutable ().assign (new_dynamic_sizes);
270287 });
288+ mlir::OpBuilder::InsertionGuard guard (rewriter);
289+ rewriter.setInsertionPointAfter (op);
271290 auto cast_op = memref::CastOp::create (rewriter, op.getLoc (), old_type, op);
272291 rewriter.replaceAllUsesExcept (op, cast_op, cast_op);
273292 return success ();
@@ -604,7 +623,7 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op,
604623 }
605624 auto layout_ref = erase_layout_op.getOperand ();
606625 auto layout_ty = layout_ref.getType ();
607- auto layout = dyn_cast <tpu::TiledLayoutAttr>(layout_ty.getLayout ());
626+ auto layout = cast <tpu::TiledLayoutAttr>(layout_ty.getLayout ());
608627 CHECK (!layout.getTiles ().empty ());
609628 auto tile = layout.getTiles ().front ().dimensions ();
610629 auto new_tile_strides = ComputeTileStrides (dst_ty, tile);
@@ -788,6 +807,11 @@ LogicalResult VectorStoreOp::verify() {
788807 return verifyStoreOp (*this );
789808}
790809
810+ LogicalResult VectorStoreOp::canonicalize (VectorStoreOp op,
811+ PatternRewriter& rewriter) {
812+ return propagateTiledLayoutToConsumer (op, rewriter);
813+ }
814+
791815template <typename Op>
792816LogicalResult verifyLoadOp (Op op) {
793817 MemRefType ref_ty = op.getBase ().getType ();
@@ -826,6 +850,11 @@ LogicalResult VectorLoadOp::verify() {
826850 return verifyLoadOp (*this );
827851}
828852
853+ LogicalResult VectorLoadOp::canonicalize (VectorLoadOp op,
854+ PatternRewriter& rewriter) {
855+ return propagateTiledLayoutToConsumer (op, rewriter);
856+ }
857+
829858LogicalResult VectorLoadIdxOp::verify () {
830859 VectorType value_ty = getResult ().getType ();
831860 MemRefType ref_ty = getBase ().getType ();
@@ -846,6 +875,11 @@ LogicalResult VectorLoadIdxOp::verify() {
846875 return verifyLoadOp (*this );
847876}
848877
878+ LogicalResult VectorLoadIdxOp::canonicalize (VectorLoadIdxOp op,
879+ PatternRewriter& rewriter) {
880+ return propagateTiledLayoutToConsumer (op, rewriter);
881+ }
882+
849883LogicalResult VectorStoreIdxOp::verify () {
850884 VectorType value_ty = getValueToStore ().getType ();
851885 MemRefType ref_ty = getBase ().getType ();
@@ -870,6 +904,11 @@ LogicalResult VectorStoreIdxOp::verify() {
870904 return verifyStoreOp (*this );
871905}
872906
907+ LogicalResult VectorStoreIdxOp::canonicalize (VectorStoreIdxOp op,
908+ PatternRewriter& rewriter) {
909+ return propagateTiledLayoutToConsumer (op, rewriter);
910+ }
911+
873912LogicalResult ReinterpretCastOp::verify () {
874913 auto source_type = getMemRefType (getInput ());
875914 auto target_type = getType ();
@@ -881,6 +920,17 @@ LogicalResult ReinterpretCastOp::verify() {
881920 return success ();
882921}
883922
923+ LogicalResult ReinterpretCastOp::canonicalize (ReinterpretCastOp op,
924+ PatternRewriter& rewriter) {
925+ if (auto erase_layout_op = op.getInput ().getDefiningOp <EraseLayoutOp>()) {
926+ rewriter.modifyOpInPlace (op, [&]() {
927+ op.getInputMutable ().assign (erase_layout_op.getOperand ());
928+ });
929+ return success ();
930+ }
931+ return failure ();
932+ }
933+
884934LogicalResult EraseLayoutOp::inferReturnTypes (
885935 MLIRContext* context, std::optional<Location> location,
886936 EraseLayoutOp::Adaptor adaptor,
@@ -891,6 +941,14 @@ LogicalResult EraseLayoutOp::inferReturnTypes(
891941 return success ();
892942}
893943
944+ OpFoldResult EraseLayoutOp::fold (FoldAdaptor op) {
945+ // If the operand has no interesting layout then there's no need to erase it.
946+ if (getOperand ().getType ().getLayout ().isIdentity ()) {
947+ return op.getOperand ();
948+ }
949+ return OpFoldResult ();
950+ }
951+
894952template <typename Op>
895953LogicalResult verifyRotateOp (Op op) {
896954 auto vty = op.getResult ().getType ();
@@ -1371,6 +1429,11 @@ LogicalResult EnqueueDMAOp::verify() {
13711429 return success ();
13721430}
13731431
1432+ LogicalResult EnqueueDMAOp::canonicalize (EnqueueDMAOp op,
1433+ PatternRewriter& rewriter) {
1434+ return propagateTiledLayoutToConsumer (op, rewriter);
1435+ }
1436+
13741437LogicalResult EnqueueIndirectDMAOp::verifyGather (
13751438 MemRefType operand_ty, ArrayRef<int64_t > offsets_shape,
13761439 MemRefType result_ty) {
@@ -1550,6 +1613,11 @@ LogicalResult EnqueueIndirectDMAOp::verify() {
15501613 /* operand_ty=*/ target_ty);
15511614}
15521615
1616+ LogicalResult EnqueueIndirectDMAOp::canonicalize (EnqueueIndirectDMAOp op,
1617+ PatternRewriter& rewriter) {
1618+ return propagateTiledLayoutToConsumer (op, rewriter);
1619+ }
1620+
15531621// TODO(b/395630795): Remove after 2025-08-10.
15541622LogicalResult WaitDMAOp::verify () {
15551623 auto sem_type = getMemRefType (getSemaphore ());
@@ -1573,6 +1641,11 @@ LogicalResult WaitDMA2Op::verify() {
15731641 return success ();
15741642}
15751643
1644+ LogicalResult WaitDMA2Op::canonicalize (WaitDMA2Op op,
1645+ PatternRewriter& rewriter) {
1646+ return propagateTiledLayoutToConsumer (op, rewriter);
1647+ }
1648+
15761649FailureOr<bool > WaitIndirectDMAOp::isGather () {
15771650 return mlir::tpu::isGather (*getOperation (), getSrc (), getDst ());
15781651}
@@ -1593,6 +1666,11 @@ LogicalResult WaitIndirectDMAOp::verify() {
15931666 return isGather ();
15941667}
15951668
1669+ LogicalResult WaitIndirectDMAOp::canonicalize (WaitIndirectDMAOp op,
1670+ PatternRewriter& rewriter) {
1671+ return propagateTiledLayoutToConsumer (op, rewriter);
1672+ }
1673+
15961674LogicalResult RegionOp::verify () {
15971675 for (auto result_type : getResultTypes ()) {
15981676 if (!isa<FloatType, IntegerType, VectorType, IndexType>(result_type)) {
0 commit comments