@@ -202,7 +202,10 @@ void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
202202 Value memref, SmallVector<Value> &resultOffsets,
203203 Value &resultRootMemref) {
204204 auto fillVal = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
205- auto origShape = dyn_cast<MemRefType>(memref.getType ()).getShape ();
205+ auto type = dyn_cast<MemRefType>(memref.getType ());
206+ assert (type && " Expected a memref type" );
207+
208+ auto origShape = type.getShape ();
206209
207210 resultOffsets.clear ();
208211 resultOffsets.append (origShape.size (), fillVal);
@@ -235,8 +238,8 @@ SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
235238 return strides;
236239}
237240
238- FailureOr<Value> squeezeMemref (PatternRewriter &rewriter, Location loc,
239- Value memref, size_t maxDims = 2 ) {
241+ FailureOr<Value> reduceMemrefDims (PatternRewriter &rewriter, Location loc,
242+ Value memref, size_t maxDims = 2 ) {
240243 auto type = dyn_cast<MemRefType>(memref.getType ());
241244 auto shape = type.getShape ();
242245
@@ -293,7 +296,7 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
293296 if (type.getShape ().size () <= maxDims)
294297 continue ;
295298
296- auto res = squeezeMemref (rewriter, loc, operand, maxDims);
299+ auto res = reduceMemrefDims (rewriter, loc, operand, maxDims);
297300 if (failed (res)) {
298301 return rewriter.notifyMatchFailure (
299302 linalgOp, " Can't squeeze memref to the desired number of dimensions" );
@@ -303,9 +306,10 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
303306 newOperands.emplace_back (i, flatSubview);
304307 }
305308
306- for (auto [i, operand] : newOperands)
307- linalgOp->setOperand (i, operand);
308-
309+ rewriter.modifyOpInPlace (linalgOp, [&] {
310+ for (auto [i, operand] : newOperands)
311+ linalgOp->setOperand (i, operand);
312+ });
309313 return success ();
310314}
311315
0 commit comments