
Commit c4486cf

Author: MaheshRavishankar
[mlir][Linalg] Fix reshape fusion to reshape the outs instead of creating new tensors.
When fusing tensor_reshape ops with generic/indexed_generic ops, new linalg.init_tensor operations were created for the `outs` of the fused op. While technically correct, it is better to reshape the original `outs` operands instead and rely on the init_tensor -> tensor_reshape canonicalization to achieve the same effect.

Differential Revision: https://reviews.llvm.org/D93774
Parent: 3d693bd
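The canonicalization the message relies on can be sketched as follows (hypothetical shapes and value names, not taken from this commit's tests): a tensor_reshape of a linalg.init_tensor carries no data, so it folds into an init_tensor of the reshaped type.

    %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
    %1 = linalg.tensor_reshape %0
        [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
        : tensor<?x?xf32> into tensor<?x4x?xf32>

    // Canonicalizes to an init_tensor of the expanded type, with %e0 and %e1
    // the expanded dynamic sizes derived from %d0 and %d1.
    %1 = linalg.init_tensor [%e0, 4, %e1] : tensor<?x4x?xf32>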

File tree

4 files changed (+179 −239 lines)

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 8 additions & 0 deletions
@@ -183,6 +183,14 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
       return llvm::to_vector<4>(llvm::map_range(reassociation(), [
       ](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
     }
+    SmallVector<ReassociationExprs, 4> getReassociationExprs() {
+      return
+        llvm::to_vector<4>(llvm::map_range(reassociation(),
+          [](Attribute a) {
+            return llvm::to_vector<2>(
+              a.cast<AffineMapAttr>().getValue().getResults());
+          }));
+    }
   }];
   let assemblyFormat = [{
     $src $reassociation attr-dict `:` type($src) `into` type(results)
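For reference, the reassociation attribute both accessors unpack is the inline list of affine maps on the op. In a hypothetical reshape such as the one below, getReassociationMaps() returns the two maps as AffineMaps, while the new getReassociationExprs() returns just their result expressions, (d0, d1) and (d2), as vectors of AffineExpr.

    %0 = linalg.tensor_reshape %arg0
        [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
        : tensor<?x4x?xf32> into tensor<?x?xf32>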

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 13 additions & 86 deletions
@@ -566,45 +566,6 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
   return RankedTensorType::get(expandedShape, originalType.getElementType());
 }
 
-/// Get the value to use for the output in the expanded operation given the
-/// `indexingMap` for the output in the original op. Creates a
-/// `linalg.init_tensor` operation to materialize the tensor that carries the
-/// shape information. This is only used when the tensor_reshape is expanding
-/// and is a consumer. In such cases, the tensor_reshape op semantics guarantees
-/// that the shape of the output is computable from the shape of the input since
-/// at most one of the expanded dims can be dynamic.
-static Value getOutputValueForExpandedOp(OpBuilder &builder, Location loc,
-                                         AffineMap indexingMap, Value result,
-                                         const ExpansionInfo &expansionInfo) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  ShapedType resultType = result.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = resultType.getShape();
-  for (AffineExpr expr : indexingMap.getResults()) {
-    unsigned origDimPos = expr.cast<AffineDimExpr>().getPosition();
-    bool foundDynamic = false;
-    int64_t linearizedShape = 1;
-    for (int64_t extent : expansionInfo.getExpandedShapeOfDim(origDimPos)) {
-      if (ShapedType::isDynamic(extent)) {
-        assert(!foundDynamic &&
-               "Expanded dimensions of reshape can have only one dynamic dim");
-        staticDims.push_back(ShapedType::kDynamicSize);
-        foundDynamic = true;
-        continue;
-      }
-      staticDims.push_back(extent);
-      linearizedShape *= extent;
-    }
-    if (ShapedType::isDynamic(origShape[origDimPos])) {
-      Value origDim = builder.create<DimOp>(loc, result, origDimPos);
-      dynamicDims.push_back(builder.create<UnsignedDivIOp>(
-          loc, origDim, builder.create<ConstantIndexOp>(loc, linearizedShape)));
-    }
-  }
-  return builder.create<linalg::InitTensorOp>(loc, dynamicDims, staticDims,
-                                              resultType.getElementType());
-}
-
 /// Returns the reassociation maps to use in the `linalg.tensor_reshape`
 /// operation to convert the operands of the original operation to operands of
 /// the expanded operation. The same method is used to compute the
@@ -734,8 +695,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   SmallVector<Value, 1> outputs;
   for (auto result : llvm::enumerate(linalgOp.getOutputs())) {
     AffineMap indexingMap = linalgOp.getOutputIndexingMap(result.index());
-    outputs.push_back(getOutputValueForExpandedOp(
-        rewriter, loc, indexingMap, result.value(), expansionInfo));
+    RankedTensorType expandedOutputType =
+        getExpandedType(result.value().getType().cast<RankedTensorType>(),
+                        indexingMap, expansionInfo);
+    if (expandedOutputType != result.value().getType()) {
+      SmallVector<ReassociationIndices, 4> reassociation =
+          getReassociationForExpansion(indexingMap, expansionInfo);
+      outputs.push_back(rewriter.create<TensorReshapeOp>(
+          linalgOp.getLoc(), expandedOutputType, result.value(),
+          reassociation));
+    }
   }
 
   // The iterator types of the expanded op are all parallel.
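Sketched in IR, the effect of this hunk on the expanded op's outs (hypothetical shapes; the reshape is only emitted when the expanded type differs from the original):

    // Previously: a fresh tensor materialized the expanded output shape.
    %new_out = linalg.init_tensor [%d0, 4, %d1] : tensor<?x4x?xf32>

    // Now: the original output operand is reshaped to the expanded type.
    %new_out = linalg.tensor_reshape %orig_out
        [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
        : tensor<?x?xf32> into tensor<?x4x?xf32>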
@@ -779,47 +748,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
   return resultVals;
 }
 
-static Value
-getOutputValueForLinearization(OpBuilder &builder, Location loc,
-                               Value origOutput,
-                               ArrayRef<AffineMap> reassociationMaps) {
-  SmallVector<Value, 4> dynamicDims;
-  SmallVector<int64_t, 4> staticDims;
-  auto shapedType = origOutput.getType().cast<ShapedType>();
-  ArrayRef<int64_t> origShape = shapedType.getShape();
-  for (auto map : reassociationMaps) {
-    Optional<Value> dynamicDim;
-    int64_t staticLinearizedShape = 1;
-    for (AffineDimExpr expr :
-         llvm::map_range(map.getResults(), [](AffineExpr e) {
-           return e.cast<AffineDimExpr>();
-         })) {
-      unsigned pos = expr.getPosition();
-      if (ShapedType::isDynamic(origShape[pos])) {
-        Value dim = builder.create<DimOp>(loc, origOutput, pos);
-        if (dynamicDim) {
-          dynamicDim = builder.create<MulIOp>(loc, dynamicDim.getValue(), dim);
-        } else {
-          dynamicDim = dim;
-        }
-      } else {
-        staticLinearizedShape *= origShape[pos];
-      }
-    }
-    if (dynamicDim) {
-      dynamicDim = builder.create<MulIOp>(
-          loc, dynamicDim.getValue(),
-          builder.create<ConstantIndexOp>(loc, staticLinearizedShape));
-      dynamicDims.push_back(dynamicDim.getValue());
-      staticDims.push_back(ShapedType::kDynamicSize);
-    } else {
-      staticDims.push_back(staticLinearizedShape);
-    }
-  }
-  return builder.create<InitTensorOp>(loc, dynamicDims, staticDims,
-                                      shapedType.getElementType());
-}
-
 namespace {
 
 /// Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -973,7 +901,7 @@ struct FoldConsumerReshapeOpByLinearization
                                      reshapeOp.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
-        return reshapeOp.emitRemark("fused op indexing map is not affine");
+        return producer.emitRemark("fused op indexing map is not affine");
     }
     fusedIndexMaps.back() = modifiedMap;
 
@@ -983,9 +911,8 @@ struct FoldConsumerReshapeOpByLinearization
       return reshapeOp.emitRemark("fused op loop bound computation failed");
 
     Location loc = producer.getLoc();
-    Value output =
-        getOutputValueForLinearization(rewriter, loc, producer.getOutputs()[0],
-                                       reshapeOp.getReassociationMaps());
+    Value output = rewriter.create<TensorReshapeOp>(
+        loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs());
     LinalgOp fusedOp = createLinalgOpOfSameType(
         producer, rewriter, loc, reshapeOp.getResultType(),
         /*inputs=*/producer.getInputs(),
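Here the reshape runs in the collapsing direction: instead of hand-computing linearized sizes for a fresh init_tensor, the producer's output is collapsed with a tensor_reshape built from the new getReassociationExprs() accessor. Schematically, with hypothetical shapes:

    %output = linalg.tensor_reshape %producer_out
        [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
        : tensor<?x4x?xf32> into tensor<?x?xf32>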
