Skip to content

Commit cbcdb33

Browse files
committed
Simplify
1 parent b3e5afe commit cbcdb33

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,30 +2285,14 @@ LogicalResult IndexOp::verify() {
22852285

22862286
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
22872287
auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
2288-
int64_t flatDimPos =
2289-
cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
2290-
.getPosition();
2291-
2292-
// Find the flat dimension position among the operands.
2293-
int64_t flatPosOffset = 0;
2294-
for (Value operand : linalgOp->getOperands()) {
2295-
assert(flatDimPos >= flatPosOffset && "invalid position");
2296-
auto shapedType = dyn_cast<ShapedType>(operand.getType());
2297-
if (!shapedType)
2298-
break;
22992288

2300-
int64_t rank = shapedType.getRank();
2301-
if (flatDimPos < flatPosOffset + rank) {
2302-
// Found the dimension within this shape. Now we can either fold if the
2303-
// dim size is 1, or bail out otherwise.
2304-
int64_t pos = flatDimPos - flatPosOffset;
2305-
if (shapedType.getDimSize(pos) != 1)
2306-
break;
2289+
// Index of unit dims is always 0.
2290+
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2291+
uint64_t dim = getDim();
2292+
assert(dim < loopBounds.size());
2293+
if (loopBounds[dim] == 1)
2294+
return IntegerAttr::get(IndexType::get(getContext()), 0);
23072295

2308-
return IntegerAttr::get(IndexType::get(getContext()), 0);
2309-
}
2310-
flatPosOffset += rank;
2311-
}
23122296
return OpFoldResult{};
23132297
}
23142298

0 commit comments

Comments
 (0)