Skip to content

Commit fe373da

Browse files
committed
address review comments
Signed-off-by: dchigarev <[email protected]>
1 parent ef96fbe commit fe373da

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
6464
Location loc, Value memref);
6565

6666
// Squeeze the leading dimensions of a given memref up to 'maxDims'.
67-
FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
68-
Value memref, size_t maxDims = 2);
67+
FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
68+
Value memref, size_t maxDims = 2);
6969

7070
// Squeeze the leading dimensions of memref operands of a given 'linalgOp'.
7171
LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,

lib/gc/Transforms/Utils/ValueUtils.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ void computeSubviewOffsets(PatternRewriter &rewriter, Location loc,
202202
Value memref, SmallVector<Value> &resultOffsets,
203203
Value &resultRootMemref) {
204204
auto fillVal = rewriter.create<arith::ConstantIndexOp>(loc, 0);
205-
auto origShape = dyn_cast<MemRefType>(memref.getType()).getShape();
205+
auto type = dyn_cast<MemRefType>(memref.getType());
206+
assert(type && "Expected a memref type");
207+
208+
auto origShape = type.getShape();
206209

207210
resultOffsets.clear();
208211
resultOffsets.append(origShape.size(), fillVal);
@@ -235,8 +238,8 @@ SmallVector<OpFoldResult> getMemrefStrides(PatternRewriter &rewriter,
235238
return strides;
236239
}
237240

238-
FailureOr<Value> squeezeMemref(PatternRewriter &rewriter, Location loc,
239-
Value memref, size_t maxDims = 2) {
241+
FailureOr<Value> reduceMemrefDims(PatternRewriter &rewriter, Location loc,
242+
Value memref, size_t maxDims = 2) {
240243
auto type = dyn_cast<MemRefType>(memref.getType());
241244
auto shape = type.getShape();
242245

@@ -293,7 +296,7 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
293296
if (type.getShape().size() <= maxDims)
294297
continue;
295298

296-
auto res = squeezeMemref(rewriter, loc, operand, maxDims);
299+
auto res = reduceMemrefDims(rewriter, loc, operand, maxDims);
297300
if (failed(res)) {
298301
return rewriter.notifyMatchFailure(
299302
linalgOp, "Can't squeeze memref to the desired number of dimensions");
@@ -303,9 +306,10 @@ LogicalResult maybeSqueezeDims(PatternRewriter &rewriter,
303306
newOperands.emplace_back(i, flatSubview);
304307
}
305308

306-
for (auto [i, operand] : newOperands)
307-
linalgOp->setOperand(i, operand);
308-
309+
rewriter.modifyOpInPlace(linalgOp, [&] {
310+
for (auto [i, operand] : newOperands)
311+
linalgOp->setOperand(i, operand);
312+
});
309313
return success();
310314
}
311315

0 commit comments

Comments
 (0)