#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+ #include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
+ #include "mlir/Transforms/RegionUtils.h"

#include <cstdint>
#include <optional>

namespace mlir::iree_compiler::IREE::LinalgExt {
+ namespace {
+
+ /// Represents the size of a dimension of some ShapedType value in the IR. This
+ /// is used instead of OpFoldResult when modifying the IR is illegal. This can
+ /// still be constructed from an OpFoldResult in cases where the value can be
+ /// obtained without IR modification.
+ class DimSize {
+ public:
+   DimSize(TypedValue<ShapedType> val, int64_t dim)
+       : ofr(nullptr), val(val), dim(dim) {}
+   DimSize(OpFoldResult ofr) : ofr(ofr), val(nullptr), dim(-1) {}
+
+   bool isStatic() const {
+     if (ofr) {
+       return getConstantIntValue(ofr).has_value();
+     }
+     return val.getType().isStaticDim(dim);
+   }
+
+   // Get an OpFoldResult by possibly inserting IR.
+   OpFoldResult materialize(OpBuilder &b) const {
+     if (ofr) {
+       return ofr;
+     }
+     return getDim(b, val.getLoc(), val, dim);
+   }
+
+ private:
+   OpFoldResult ofr;
+   TypedValue<ShapedType> val;
+   int64_t dim;
+ };
+ } // namespace
+
+ static SmallVector<DimSize> getDimSizes(Value v) {
+   auto shapedVal = cast<TypedValue<ShapedType>>(v);
+   int64_t rank = shapedVal.getType().getRank();
+   SmallVector<DimSize> sizes;
+   for (int i = 0; i < rank; ++i) {
+     sizes.emplace_back(shapedVal, i);
+   }
+   return sizes;
+ }

static bool
isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
@@ -33,7 +78,7 @@ isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
};

static SmallVector<ReassociationIndices>
- computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+ computeReassocFromShapeMap(ArrayRef<SmallVector<DimSize>> shapeMap) {
  SmallVector<ReassociationIndices> reassoc;
  int64_t dimCount = 0;
  for (auto &shape : shapeMap) {
@@ -45,14 +90,13 @@ computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
}

namespace {
-
/// Helper class that supports fusing reshapes with operands when not all of the
/// shape dims map to the iteration space.
struct ReshapeOperandInfo {
  static constexpr int64_t kNoMapping = -1;

  // Original shape of this operand.
-   ArrayRef<int64_t> originalShape;
+   SmallVector<DimSize> originalShape;

  // Similar to the results of the operand's `AffineMap` except `kNoMapping` if
  // that dim doesn't map to the iteration space. For example, the indexed
@@ -72,7 +116,7 @@ class ExpansionInfo {
                        SmallVector<int64_t> loopRanges,
                        OpOperand *fusableOpOperand,
                        ArrayRef<ReassociationIndices> operandReassoc,
-                       ArrayRef<int64_t> expandedShape);
+                       ArrayRef<DimSize> expandedShape);

  std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
                                           RewriterBase &rewriter) {
@@ -81,13 +125,17 @@ class ExpansionInfo {
    if (isIdentityReassoc(reassoc)) {
      return operand->get();
    }
-     SmallVector<int64_t> flattenedArray;
+     SmallVector<OpFoldResult> outputShape;
    for (auto &shape : shapeMap) {
-       flattenedArray.append(shape.begin(), shape.end());
+       llvm::append_range(
+           outputShape, llvm::map_range(shape, [&rewriter](const DimSize &size) {
+             return size.materialize(rewriter);
+           }));
    }
+     auto [staticShape, dynamicShape] = decomposeMixedValues(outputShape);
+     (void)dynamicShape;
    auto oldType = cast<ShapedType>(operand->get().getType());
-     auto newType =
-         RankedTensorType::get(flattenedArray, oldType.getElementType());
+     auto newType = RankedTensorType::get(staticShape, oldType.getElementType());
    if (failed(reshapeLikeShapesAreCompatible(
            [&](const Twine &msg) {
              return rewriter.notifyMatchFailure(loc, msg);
@@ -97,18 +145,18 @@ class ExpansionInfo {
      return {};
    }
    return tensor::ExpandShapeOp::create(rewriter, loc, newType, operand->get(),
-                                          reassoc);
+                                          reassoc, outputShape);
  };

  /// Get the shape map for the operand.
-   SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+   SmallVector<SmallVector<DimSize>> getShapeMap(OpOperand *operand) const {
    auto info = reshapeInfos[operand->getOperandNumber()];
-     SmallVector<SmallVector<int64_t>> shapeMap;
+     SmallVector<SmallVector<DimSize>> shapeMap;
    for (auto [operandIdx, loopIdx] :
         llvm::enumerate(info.operandToIterationSpace)) {
      if (loopIdx == ReshapeOperandInfo::kNoMapping) {
        shapeMap.push_back(
-             SmallVector<int64_t>{info.originalShape[operandIdx]});
+             SmallVector<DimSize>{info.originalShape[operandIdx]});
      } else {
        shapeMap.push_back(loopShapeMap[loopIdx]);
      }
@@ -126,17 +174,12 @@ class ExpansionInfo {
  ReassociationIndicesRef getExpandedLoops(unsigned i) const {
    return loopReassoc[i];
  }
-   ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
-     return loopShapeMap[i];
-   }

private:
-   /// Extent of the iteration space in the original operation.
-   SmallVector<int64_t> loopRanges;
  SmallVector<ReassociationIndices> loopReassoc;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
-   SmallVector<SmallVector<int64_t>> loopShapeMap;
+   SmallVector<SmallVector<DimSize>> loopShapeMap;
  unsigned expandedOpNumDims;
  /// Info about the reassociation and original shape for each operand.
  SmallVector<ReshapeOperandInfo> reshapeInfos;
@@ -196,7 +239,7 @@ class CollapsingInfo {
LogicalResult ExpansionInfo::compute(
    SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
    OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
-     ArrayRef<int64_t> expandedShape) {
+     ArrayRef<DimSize> expandedShape) {
  if (operandReassoc.empty())
    return failure();

@@ -206,7 +249,8 @@ LogicalResult ExpansionInfo::compute(
    for (auto [operandDim, iterDim] :
         llvm::enumerate(info.operandToIterationSpace)) {
      if (iterDim != ReshapeOperandInfo::kNoMapping &&
-           loopRanges[iterDim] != info.originalShape[operandDim]) {
+           ShapedType::isStatic(loopRanges[iterDim]) !=
+               info.originalShape[operandDim].isStatic()) {
        return failure();
      }
    }
@@ -229,12 +273,22 @@ LogicalResult ExpansionInfo::compute(
    }
  }

-   // Fill in the remaining elements with `loopRanges`
-   this->expandedOpNumDims = 0;
-   for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
-     if (shapeMap.empty()) {
-       this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
+   // Fill in the remaining elements.
+   for (const ReshapeOperandInfo &info : infos) {
+     for (auto [operandIdx, loopIdx] :
+          llvm::enumerate(info.operandToIterationSpace)) {
+       if (loopIdx == ReshapeOperandInfo::kNoMapping ||
+           !this->loopShapeMap[loopIdx].empty()) {
+         continue;
+       }
+
+       this->loopShapeMap[loopIdx] =
+           SmallVector<DimSize>{info.originalShape[operandIdx]};
    }
+   }
+
+   this->expandedOpNumDims = 0;
+   for (const auto &shapeMap : this->loopShapeMap) {
    this->expandedOpNumDims += shapeMap.size();
  }

@@ -244,7 +298,6 @@ LogicalResult ExpansionInfo::compute(
  }
  this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
  this->reshapeInfos = std::move(infos);
-   this->loopRanges = std::move(loopRanges);
  return success();
}

@@ -307,7 +360,7 @@ getReshapeInfo(LinalgExt::AttentionOp attentionOp) {
      return operandInfo;
    }

-     operandInfo.originalShape = operandType.getShape();
+     operandInfo.originalShape = getDimSizes(opOperand.get());
    for (auto result :
         attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
      operandInfo.operandToIterationSpace.push_back(
@@ -325,13 +378,13 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
  auto updateRank = scatterOp.getUpdateType().getRank();

  ReshapeOperandInfo updateInfo;
-   updateInfo.originalShape = scatterOp.getUpdateType().getShape();
+   updateInfo.originalShape = getDimSizes(scatterOp.getUpdates());
  llvm::append_range(updateInfo.operandToIterationSpace,
                     llvm::seq<int64_t>(0, updateRank));
  infos.push_back(std::move(updateInfo));

  ReshapeOperandInfo indicesInfo;
-   indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
+   indicesInfo.originalShape = getDimSizes(scatterOp.getIndices());
  llvm::append_range(indicesInfo.operandToIterationSpace,
                     llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
  if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
@@ -340,7 +393,7 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
  infos.push_back(std::move(indicesInfo));

  ReshapeOperandInfo originalInfo;
-   originalInfo.originalShape = scatterOp.getOriginalType().getShape();
+   originalInfo.originalShape = getDimSizes(scatterOp.getOriginal());
  originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
                                              ReshapeOperandInfo::kNoMapping);
  llvm::append_range(originalInfo.operandToIterationSpace,
@@ -356,15 +409,15 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
  auto outputRank = gatherOp.getOutputType().getRank();

  ReshapeOperandInfo sourceInfo;
-   sourceInfo.originalShape = gatherOp.getSourceType().getShape();
+   sourceInfo.originalShape = getDimSizes(gatherOp.getSource());
  sourceInfo.operandToIterationSpace.append(gatherOp.getIndexDepth(),
                                            ReshapeOperandInfo::kNoMapping);
  llvm::append_range(sourceInfo.operandToIterationSpace,
                     llvm::seq(outputRank - rankOfContiguousSlice, outputRank));
  infos.push_back(std::move(sourceInfo));

  ReshapeOperandInfo indicesInfo;
-   indicesInfo.originalShape = gatherOp.getIndicesType().getShape();
+   indicesInfo.originalShape = getDimSizes(gatherOp.getIndices());
  llvm::append_range(indicesInfo.operandToIterationSpace,
                     llvm::seq<int64_t>(0, gatherOp.getBatchRank()));
  if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank())
@@ -373,7 +426,7 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
  infos.push_back(std::move(indicesInfo));

  ReshapeOperandInfo outputInfo;
-   outputInfo.originalShape = gatherOp.getOutputType().getShape();
+   outputInfo.originalShape = getDimSizes(gatherOp.getOutput());
  llvm::append_range(outputInfo.operandToIterationSpace,
                     llvm::seq<int64_t>(0, outputRank));
  infos.push_back(std::move(outputInfo));
@@ -407,15 +460,26 @@ fuseWithReshapeByExpansion(OpTy op, Operation *reshapeOp,
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
-   RankedTensorType expandedType = isExpanding
-                                       ? expandingReshapeOp.getResultType()
-                                       : collapsingReshapeOp.getSrcType();
+   Value expandedVal = isExpanding ? expandingReshapeOp.getResult()
+                                   : collapsingReshapeOp.getSrc();
+   SmallVector<DimSize> expandedSize;
+   if (isExpanding) {
+     // The SSA dims must dominate `op` in order to use them to create new
+     // expand_shape ops.
+     if (failed(moveValueDefinitions(rewriter,
+                                     expandingReshapeOp.getOutputShape(), op))) {
+       return std::nullopt;
+     }
+     llvm::append_range(expandedSize, expandingReshapeOp.getMixedOutputShape());
+   } else {
+     expandedSize = getDimSizes(expandedVal);
+   }
  ExpansionInfo info;
  if (failed(info.compute(
          getReshapeInfo(op), op.getStaticLoopRanges(), fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationIndices()
                      : collapsingReshapeOp.getReassociationIndices(),
-           expandedType.getShape()))) {
+           expandedSize))) {
    return std::nullopt;
  }

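Illustrative note (not part of the diff): a minimal sketch, under my reading of the patch, of how the new DimSize helper is meant to be used; dimension sizes are recorded during matching without touching the IR, and are only turned into OpFoldResults (possibly creating tensor.dim ops) once a rewrite is committed. The `operand` and `rewriter` names below are hypothetical placeholders, not identifiers from this change.

  // Collect per-dimension sizes without modifying the IR; safe to do while
  // only checking preconditions for the fusion.
  SmallVector<DimSize> sizes = getDimSizes(operand->get());
  bool allStatic =
      llvm::all_of(sizes, [](const DimSize &s) { return s.isStatic(); });
  // Once rewriting is allowed, materialize each DimSize as an OpFoldResult;
  // dynamic dims may insert IR (e.g. tensor.dim) at this point.
  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
      sizes, [&](const DimSize &s) { return s.materialize(rewriter); });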