@@ -681,28 +681,21 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
681681 builder.getContext ());
682682}
683683
684- // / Return the type of the operand/result to use in the expanded op given the
685- // / type in the original op.
684+ // / Return the shape and type of the operand/result to use in the expanded op
685+ // / given the type in the original op.
686686static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
687687getExpandedShapeAndType (RankedTensorType originalType, AffineMap indexingMap,
688688 const ExpansionInfo &expansionInfo) {
689- SmallVector<int64_t > expandedStaticShape;
690689 SmallVector<OpFoldResult> expandedShape;
691690 for (AffineExpr expr : indexingMap.getResults ()) {
692691 unsigned dim = cast<AffineDimExpr>(expr).getPosition ();
693692 ArrayRef<OpFoldResult> dimExpansion =
694693 expansionInfo.getExpandedShapeOfDim (dim);
695- llvm::append_range (expandedStaticShape,
696- llvm::map_range (dimExpansion, [](OpFoldResult ofr) {
697- std::optional<int64_t > staticShape =
698- getConstantIntValue (ofr);
699- if (staticShape) {
700- return staticShape.value ();
701- }
702- return ShapedType::kDynamic ;
703- }));
704694 expandedShape.append (dimExpansion.begin (), dimExpansion.end ());
705695 }
696+ SmallVector<int64_t > expandedStaticShape;
697+ std::tie (expandedStaticShape, std::ignore) =
698+ decomposeMixedValues (expandedShape);
706699 return {expandedShape, RankedTensorType::get (expandedStaticShape,
707700 originalType.getElementType ())};
708701}
@@ -761,13 +754,14 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
761754 [&](int64_t dim) { return rewriter.create <IndexOp>(loc, dim); });
762755 OpFoldResult newIndex =
763756 rewriter.create <IndexOp>(loc, expandedDims.front ()).getResult ();
764- for (auto it : llvm::zip (expandedDimsShape, expandedIndices)) {
757+ for (auto [expandedShape, expandedIndex] :
758+ llvm::zip (expandedDimsShape, expandedIndices)) {
765759 AffineExpr idx, acc, shape;
766760 bindDims (rewriter.getContext (), idx, acc);
767761 bindSymbols (rewriter.getContext (), shape);
768762 newIndex = affine::makeComposedFoldedAffineApply (
769763 rewriter, indexOp.getLoc (), idx + acc * shape,
770- ArrayRef<OpFoldResult>{std::get< 1 >(it) , newIndex, std::get< 0 >(it) });
764+ ArrayRef<OpFoldResult>{expandedIndex , newIndex, expandedShape });
771765 }
772766 Value newIndexVal =
773767 getValueOrCreateConstantIndexOp (rewriter, indexOp.getLoc (), newIndex);
@@ -890,6 +884,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
890884 src = expandingReshapeOp.getSrc ();
891885 } else {
892886 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
887+ if (!collapsingReshapeOp)
888+ return std::nullopt ;
889+
893890 expandedShape = tensor::getMixedSizes (
894891 rewriter, collapsingReshapeOp->getLoc (), collapsingReshapeOp.getSrc ());
895892 reassociationIndices = collapsingReshapeOp.getReassociationMaps ();
0 commit comments