File tree Expand file tree Collapse file tree 1 file changed +6
-22
lines changed
mlir/lib/Dialect/Linalg/IR Expand file tree Collapse file tree 1 file changed +6
-22
lines changed Original file line number Diff line number Diff line change @@ -2285,30 +2285,14 @@ LogicalResult IndexOp::verify() {
22852285
22862286OpFoldResult IndexOp::fold (FoldAdaptor adaptor) {
22872287 auto linalgOp = cast<LinalgOp>((*this )->getParentOp ());
2288- int64_t flatDimPos =
2289- cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap ().getResult (getDim ()))
2290- .getPosition ();
2291-
2292- // Find the flat dimension position among the operands.
2293- int64_t flatPosOffset = 0 ;
2294- for (Value operand : linalgOp->getOperands ()) {
2295- assert (flatDimPos >= flatPosOffset && " invalid position" );
2296- auto shapedType = dyn_cast<ShapedType>(operand.getType ());
2297- if (!shapedType)
2298- break ;
22992288
2300- int64_t rank = shapedType.getRank ();
2301- if (flatDimPos < flatPosOffset + rank) {
2302- // Found the dimension within this shape. Now we can either fold if the
2303- // dim size is 1, or bail out otherwise.
2304- int64_t pos = flatDimPos - flatPosOffset;
2305- if (shapedType.getDimSize (pos) != 1 )
2306- break ;
2289+ // Index of unit dims is always 0.
2290+ SmallVector<int64_t , 4 > loopBounds = linalgOp.getStaticLoopRanges ();
2291+ uint64_t dim = getDim ();
2292+ assert (dim < loopBounds.size ());
2293+ if (loopBounds[dim] == 1 )
2294+ return IntegerAttr::get (IndexType::get (getContext ()), 0 );
23072295
2308- return IntegerAttr::get (IndexType::get (getContext ()), 0 );
2309- }
2310- flatPosOffset += rank;
2311- }
23122296 return OpFoldResult{};
23132297}
23142298
You can’t perform that action at this time.
0 commit comments