@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
510510 return success ();
511511}
512512
513+ // ===----------------------------------------------------------------------===//
514+ // GatherToLDSOp
515+ // ===----------------------------------------------------------------------===//
516+
513517LogicalResult GatherToLDSOp::verify () {
514518 MemRefType srcType = cast<MemRefType>(getSrc ().getType ());
515519 MemRefType dstType = cast<MemRefType>(getDst ().getType ());
@@ -546,6 +550,59 @@ LogicalResult GatherToLDSOp::verify() {
546550 return success ();
547551}
548552
553+ namespace {
554+ // / If the source/target of a CopyOp is a CastOp that does not modify the shape
555+ // / and element type, the cast can be skipped. Such CastOps only cast the layout
556+ // / of the type.
557+ struct FoldGatherToLDSOfCast : public OpRewritePattern <GatherToLDSOp> {
558+ using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
559+
560+ LogicalResult matchAndRewrite (GatherToLDSOp gatherOp,
561+ PatternRewriter &rewriter) const override {
562+ bool modified = false ;
563+
564+ // Check source.
565+ if (auto castOp = gatherOp.getSrc ().getDefiningOp <memref::CastOp>()) {
566+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
567+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
568+
569+ if (fromType && toType &&
570+ fromType.getElementType () == toType.getElementType ()) {
571+ rewriter.modifyOpInPlace (gatherOp, [&] {
572+ gatherOp.getSrcMutable ().assign (castOp.getSource ());
573+ });
574+ modified = true ;
575+ }
576+ }
577+
578+ // Check target.
579+ if (auto castOp = gatherOp.getDst ().getDefiningOp <memref::CastOp>()) {
580+ auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
581+ auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
582+
583+ if (fromType && toType &&
584+ fromType.getElementType () == toType.getElementType ()) {
585+ rewriter.modifyOpInPlace (gatherOp, [&] {
586+ gatherOp.getDstMutable ().assign (castOp.getSource ());
587+ });
588+ modified = true ;
589+ }
590+ }
591+
592+ return success (modified);
593+ }
594+ };
595+ } // namespace
596+
597+ void GatherToLDSOp::getCanonicalizationPatterns (RewritePatternSet &results,
598+ MLIRContext *context) {
599+ results.add <FoldGatherToLDSOfCast>(context);
600+ }
601+
602+ // ===----------------------------------------------------------------------===//
603+ // TransposeLoadOp
604+ // ===----------------------------------------------------------------------===//
605+
549606LogicalResult TransposeLoadOp::verify () {
550607 MemRefType srcType = cast<MemRefType>(getSrc ().getType ());
551608
0 commit comments