Skip to content

Commit 9c0dc0b

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Fold init_tensor -> linalg.tensor_reshape.
Reshaping an init_tensor can be folded into an init_tensor op of the final type. Differential Revision: https://reviews.llvm.org/D93773
1 parent 4214ca9 commit 9c0dc0b

File tree

2 files changed

+156
-6
lines changed

2 files changed

+156
-6
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -718,9 +718,123 @@ struct ReplaceDimOfInitTensorOp : public OpRewritePattern<DimOp> {
718718
};
719719
} // namespace
720720

721+
/// Creates an InitTensorOp with the collapsed result shape of `reshapeOp`.
///
/// For each reassociation group, the collapsed dimension extent is the product
/// of the grouped source dimensions: static extents are multiplied at compile
/// time, while each dynamic extent is read back with a DimOp and folded in
/// with MulIOp at runtime. A group containing any dynamic dimension yields a
/// dynamic result dimension (kDynamicSize) whose SSA extent is pushed into
/// `dynamicShapes`.
static Value getCollapsedInitTensor(OpBuilder &builder,
                                    TensorReshapeOp reshapeOp) {
  Location loc = reshapeOp.getLoc();
  SmallVector<Value, 4> dynamicShapes;
  SmallVector<int64_t, 4> staticShapes;
  auto reassociation = reshapeOp.getReassociationMaps();
  Value src = reshapeOp.src();
  RankedTensorType srcType = reshapeOp.getSrcType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  for (auto map : reassociation) {
    // Accumulate the product of the source extents in this group; the
    // dynamic part is built as SSA multiplications, the static part as a
    // plain integer product.
    Value linearizedDynamicDim = nullptr;
    int64_t linearizedStaticDim = 1;
    for (unsigned i : llvm::map_range(map.getResults(), [](AffineExpr e) {
           return e.cast<AffineDimExpr>().getPosition();
         })) {
      if (ShapedType::isDynamic(srcShape[i])) {
        Value shapeVal = builder.create<DimOp>(loc, src, i);
        if (linearizedDynamicDim) {
          linearizedDynamicDim =
              builder.create<MulIOp>(loc, linearizedDynamicDim, shapeVal);
        } else {
          linearizedDynamicDim = shapeVal;
        }
      } else {
        linearizedStaticDim *= srcShape[i];
      }
    }
    if (linearizedDynamicDim) {
      // Fold the static factor into the dynamic extent, if any.
      if (linearizedStaticDim != 1) {
        linearizedDynamicDim = builder.create<MulIOp>(
            loc, linearizedDynamicDim,
            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
      }
      dynamicShapes.push_back(linearizedDynamicDim);
      staticShapes.push_back(ShapedType::kDynamicSize);
    } else {
      staticShapes.push_back(linearizedStaticDim);
    }
  }
  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
                                      srcType.getElementType());
}
763+
764+
/// Creates an InitTensorOp with the expanded result shape of `reshapeOp`.
///
/// For each reassociation group, the static result extents are taken directly
/// from the destination shape. A dynamic result extent is recovered by
/// dividing the (necessarily dynamic) source extent of the group by the
/// product of the group's static extents.
///
/// Returns a null Value (and creates no ops usable by the caller) when the
/// expansion is not representable:
///  - more than one dynamic dimension in a single reassociation group, or
///  - a group with a dynamic result dimension whose source dimension is
///    static.
/// NOTE(review): on the bail-out paths, DimOp/arith ops created for earlier
/// groups are left behind for DCE to clean up — presumably acceptable under
/// the canonicalizer; verify if reused elsewhere.
static Value getExpandedInitTensor(OpBuilder &builder,
                                   TensorReshapeOp reshapeOp) {
  SmallVector<Value, 4> dynamicShapes;
  SmallVector<int64_t, 4> staticShapes;
  auto reassociation = reshapeOp.getReassociationMaps();
  Value src = reshapeOp.src();
  RankedTensorType srcType = reshapeOp.getSrcType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> dstShape = reshapeOp.getResultType().getShape();
  Location loc = reshapeOp.getLoc();
  for (auto map : enumerate(reassociation)) {
    int64_t linearizedStaticDim = 1;
    bool hasDynamic = false;
    for (unsigned i :
         llvm::map_range(map.value().getResults(), [](AffineExpr e) {
           return e.cast<AffineDimExpr>().getPosition();
         })) {
      if (ShapedType::isDynamic(dstShape[i])) {
        // Only one of the dimensions of the expanded shape should be dynamic.
        if (hasDynamic)
          return nullptr;
        hasDynamic = true;
        staticShapes.push_back(ShapedType::kDynamicSize);
        continue;
      }
      staticShapes.push_back(dstShape[i]);
      linearizedStaticDim *= dstShape[i];
    }
    if (hasDynamic) {
      // If the expanded dimensions has a dynamic shape, the src shape must be
      // dynamic as well.
      if (!ShapedType::isDynamic(srcShape[map.index()]))
        return nullptr;
      Value dynamicDim = builder.create<DimOp>(loc, src, map.index());
      if (linearizedStaticDim != 1) {
        dynamicDim = builder.create<UnsignedDivIOp>(
            loc, dynamicDim,
            builder.create<ConstantIndexOp>(loc, linearizedStaticDim));
      }
      dynamicShapes.push_back(dynamicDim);
    }
  }
  return builder.create<InitTensorOp>(loc, dynamicShapes, staticShapes,
                                      srcType.getElementType());
}
809+
810+
namespace {
811+
struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
812+
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
813+
814+
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
815+
PatternRewriter &rewriter) const override {
816+
if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
817+
return failure();
818+
RankedTensorType collapsedType = reshapeOp.getSrcType();
819+
RankedTensorType expandedType = reshapeOp.getResultType();
820+
bool isCollapsed = expandedType.getRank() < collapsedType.getRank();
821+
if (isCollapsed)
822+
std::swap(collapsedType, expandedType);
823+
Value initTensorOp = isCollapsed
824+
? getCollapsedInitTensor(rewriter, reshapeOp)
825+
: getExpandedInitTensor(rewriter, reshapeOp);
826+
if (!initTensorOp)
827+
return failure();
828+
rewriter.replaceOp(reshapeOp, initTensorOp);
829+
return success();
830+
}
831+
};
832+
} // namespace
833+
721834
void InitTensorOp::getCanonicalizationPatterns(
722835
OwningRewritePatternList &results, MLIRContext *context) {
723-
results.insert<ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
836+
results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
837+
ReplaceStaticShapeDims>(context);
724838
}
725839

726840
//===----------------------------------------------------------------------===//
@@ -1043,23 +1157,23 @@ static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
10431157
ArrayRef<int64_t> expandedShape = expandedType.getShape();
10441158
unsigned expandedDimStart = 0;
10451159
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
1046-
Optional<int64_t> dynamicDims;
1160+
Optional<int64_t> dynamicShape;
10471161
int64_t linearizedStaticShape = 1;
10481162
for (auto dim : llvm::enumerate(expandedShape.slice(
10491163
expandedDimStart, map.value().getNumResults()))) {
10501164
if (ShapedType::isDynamic(dim.value())) {
1051-
if (isExpandingReshape && dynamicDims) {
1165+
if (isExpandingReshape && dynamicShape) {
10521166
return op->emitOpError("invalid to have a single dimension (")
10531167
<< map.index() << ") expanded into multiple dynamic dims ("
1054-
<< expandedDimStart + dynamicDims.getValue() << ","
1168+
<< expandedDimStart + dynamicShape.getValue() << ","
10551169
<< expandedDimStart + dim.index() << ")";
10561170
}
1057-
dynamicDims = dim.index();
1171+
dynamicShape = dim.index();
10581172
} else {
10591173
linearizedStaticShape *= dim.value();
10601174
}
10611175
}
1062-
if (dynamicDims) {
1176+
if (dynamicShape) {
10631177
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
10641178
return op->emitOpError("expected dimension ")
10651179
<< map.index()

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,39 @@ func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
413413
// CHECK-SAME: [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
414414
// CHECK: dim [[ARG_0]]
415415
// CHECK: dim [[ARG_1]]
416+
417+
// -----
418+
419+
// Expanding reshape of an init_tensor folds to an init_tensor of the expanded
// type; the dynamic extent is recovered as %arg0 / 28 (28 = 4 * 7).
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
  %1 = linalg.tensor_reshape %0
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
      tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[C28:.+]] = constant 28 : index
// CHECK: %[[T0:.+]] = divi_unsigned %[[ARG0]], %[[C28]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
// CHECK: return %[[T1]]
434+
435+
// -----
436+
437+
// Collapsing reshape of an init_tensor folds to an init_tensor of the
// collapsed type; the dynamic extent is linearized as %arg0 * 28 (28 = 4 * 7).
func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
  %0 = linalg.init_tensor [2, 3, 5, 4, %arg0, 7] : tensor<2x3x5x4x?x7xf32>
  %1 = linalg.tensor_reshape %0
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>] :
      tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  return %1 : tensor<6x5x?xf32>
}
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[C28:.+]] = constant 28 : index
// CHECK: %[[T0:.+]] = muli %[[ARG0]], %[[C28]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[T1]]

0 commit comments

Comments
 (0)