
Commit dbaa97a

Address comments.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent: 1b675e9


mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 11 additions & 14 deletions
@@ -681,28 +681,21 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           builder.getContext());
 }
 
-/// Return the type of the operand/result to use in the expanded op given the
-/// type in the original op.
+/// Return the shape and type of the operand/result to use in the expanded op
+/// given the type in the original op.
 static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
 getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
                         const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedStaticShape;
   SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
     ArrayRef<OpFoldResult> dimExpansion =
         expansionInfo.getExpandedShapeOfDim(dim);
-    llvm::append_range(expandedStaticShape,
-                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
-                         std::optional<int64_t> staticShape =
-                             getConstantIntValue(ofr);
-                         if (staticShape) {
-                           return staticShape.value();
-                         }
-                         return ShapedType::kDynamic;
-                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
+  SmallVector<int64_t> expandedStaticShape;
+  std::tie(expandedStaticShape, std::ignore) =
+      decomposeMixedValues(expandedShape);
   return {expandedShape, RankedTensorType::get(expandedStaticShape,
                                                originalType.getElementType())};
 }
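
For context on the first hunk: the deleted lambda was hand-rolling the static-shape extraction that the MLIR utility decomposeMixedValues (from StaticValueUtils) already provides. A minimal sketch of the equivalence, assuming that utility returns a pair whose first element holds the static sizes with ShapedType::kDynamic marking dynamic entries; the two helper functions below are illustrative only and not part of the patch:

// Sketch only (not committed code). Assumes decomposeMixedValues splits a
// list of OpFoldResults into {static sizes, dynamic Values}.
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>
#include <tuple>

using namespace mlir;

// What the removed loop computed by hand: one static size per entry, or
// ShapedType::kDynamic when the OpFoldResult is not a constant.
static SmallVector<int64_t>
staticShapeByHand(const SmallVector<OpFoldResult> &mixed) {
  SmallVector<int64_t> staticShape;
  for (OpFoldResult ofr : mixed) {
    std::optional<int64_t> cst = getConstantIntValue(ofr);
    staticShape.push_back(cst ? *cst : ShapedType::kDynamic);
  }
  return staticShape;
}

// The equivalent after this commit: let the utility do the split and keep
// only the static side, as getExpandedShapeAndType now does.
static SmallVector<int64_t>
staticShapeViaHelper(const SmallVector<OpFoldResult> &mixed) {
  SmallVector<int64_t> staticShape;
  std::tie(staticShape, std::ignore) = decomposeMixedValues(mixed);
  return staticShape;
}

Both helpers produce the same vector for the same mixed sizes, which is why the hand-written lambda could be dropped.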
@@ -761,13 +754,14 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
     OpFoldResult newIndex =
         rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
-    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
+    for (auto [expandedShape, expandedIndex] :
+         llvm::zip(expandedDimsShape, expandedIndices)) {
       AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
       bindSymbols(rewriter.getContext(), shape);
       newIndex = affine::makeComposedFoldedAffineApply(
           rewriter, indexOp.getLoc(), idx + acc * shape,
-          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
+          ArrayRef<OpFoldResult>{expandedIndex, newIndex, expandedShape});
     }
     Value newIndexVal =
         getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
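
The second hunk is a pure readability change: structured bindings over llvm::zip give the tuple elements names instead of std::get<0>(it) and std::get<1>(it). A self-contained sketch of the idiom, with illustrative data not taken from the patch:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdio>

int main() {
  llvm::SmallVector<int> shapes = {2, 3, 4};
  llvm::SmallVector<int> indices = {0, 1, 2};
  // Each zip element is a tuple; structured bindings name its parts directly.
  for (auto [shape, index] : llvm::zip(shapes, indices))
    std::printf("shape=%d index=%d\n", shape, index);
  return 0;
}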
@@ -890,6 +884,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     src = expandingReshapeOp.getSrc();
   } else {
     auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    if (!collapsingReshapeOp)
+      return std::nullopt;
+
     expandedShape = tensor::getMixedSizes(
         rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
     reassociationIndices = collapsingReshapeOp.getReassociationMaps();
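
A note on the third hunk: dyn_cast returns a null handle when reshapeOp is not a tensor.collapse_shape, and the previous code immediately called methods on that handle; the added guard bails out with std::nullopt instead (the surrounding function returns an optional, as the diff shows). A minimal, hypothetical sketch of the same guard idiom, not taken from the patch:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <optional>

// Hypothetical classifier: report which kind of tensor reshape the op is and
// give up when it is neither, instead of assuming a cast succeeded.
static std::optional<bool> isExpandingReshape(mlir::Operation *reshapeOp) {
  if (llvm::isa<mlir::tensor::ExpandShapeOp>(reshapeOp))
    return true;          // tensor.expand_shape
  if (llvm::isa<mlir::tensor::CollapseShapeOp>(reshapeOp))
    return false;         // tensor.collapse_shape
  return std::nullopt;    // neither: bail out, as the added check now does
}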
