@@ -551,10 +551,9 @@ LogicalResult GatherToLDSOp::verify() {
551551}
552552
553553namespace {
554- // / If the source/target of a CopyOp is a CastOp that does not modify the shape
555- // / and element type, the cast can be skipped. Such CastOps only cast the layout
556- // / of the type.
557- struct FoldGatherToLDSOfCast : public OpRewritePattern <GatherToLDSOp> {
554+ // / If the source/target of a GatherToLDSOp is a CastOp that only removes static
555+ // / information or changes layout, the cast can be skipped.
556+ struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
558557 using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
559558
560559 LogicalResult matchAndRewrite (GatherToLDSOp gatherOp,
@@ -563,10 +562,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
563562
564563 // Check source.
565564 if (auto castOp = gatherOp.getSrc ().getDefiningOp <memref::CastOp>()) {
566- auto fromType = llvm:: dyn_cast<MemRefType>(castOp.getSource ().getType ());
567- auto toType = llvm:: dyn_cast<MemRefType>(castOp.getSource ().getType ());
565+ auto fromType = dyn_cast<MemRefType>(castOp.getSource ().getType ());
566+ auto toType = dyn_cast<MemRefType>(castOp.getSource ().getType ());
568567
569- if (fromType && toType &&
568+ if (memref::CastOp::canFoldIntoConsumerOp (castOp) && fromType && toType &&
570569 fromType.getElementType () == toType.getElementType ()) {
571570 rewriter.modifyOpInPlace (gatherOp, [&] {
572571 gatherOp.getSrcMutable ().assign (castOp.getSource ());
@@ -577,10 +576,10 @@ struct FoldGatherToLDSOfCast : public OpRewritePattern<GatherToLDSOp> {
577576
578577 // Check target.
579578 if (auto castOp = gatherOp.getDst ().getDefiningOp <memref::CastOp>()) {
580- auto fromType = llvm:: dyn_cast<MemRefType>(castOp.getSource ().getType ());
581- auto toType = llvm:: dyn_cast<MemRefType>(castOp.getSource ().getType ());
579+ auto fromType = dyn_cast<MemRefType>(castOp.getSource ().getType ());
580+ auto toType = dyn_cast<MemRefType>(castOp.getSource ().getType ());
582581
583- if (fromType && toType &&
582+ if (memref::CastOp::canFoldIntoConsumerOp (castOp) && fromType && toType &&
584583 fromType.getElementType () == toType.getElementType ()) {
585584 rewriter.modifyOpInPlace (gatherOp, [&] {
586585 gatherOp.getDstMutable ().assign (castOp.getSource ());
0 commit comments