@@ -37,6 +37,9 @@ static LogicalResult foldMemrefViewOp(
3737 Value &memrefBase, StringRef role)
3838{
3939 Operation *defOp = view.getDefiningOp ();
40+ if (!defOp) {
41+ return failure ();
42+ }
4043 return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
4144 .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
4245 mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
@@ -81,12 +84,19 @@ struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
8184 auto foldSrcResult = foldMemrefViewOp (
8285 rewriter, loc, op.getSrc (), op.getSrcIndices (), sourceIndices, memrefSource, " source" );
8386
87+ if (failed (foldSrcResult)) {
88+ memrefSource = op.getSrc ();
89+ sourceIndices = op.getSrcIndices ();
90+ }
91+
8492 auto foldDstResult = foldMemrefViewOp (
8593 rewriter, loc, op.getDst (), op.getDstIndices (), destIndices, memrefDest, " destination" );
8694
87- if (failed (foldSrcResult) || failed (foldDstResult)) {
88- return failure ();
89- }
95+ if (failed (foldDstResult)) {
96+ memrefDest = op.getDst ();
97+ destIndices = op.getDstIndices ();
98+ }
99+
90100
91101 rewriter.replaceOpWithNewOp <GatherToLDSOp>(op, memrefSource, sourceIndices,
92102 memrefDest, destIndices,
0 commit comments