1111#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1212#include " iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1313#include " iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
14- #include " iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1514#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
1615#include " mlir/Dialect/Tensor/IR/Tensor.h"
1716#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
1817#include " mlir/Dialect/Utils/StructuredOpsUtils.h"
1918#include " mlir/IR/MLIRContext.h"
2019#include " mlir/IR/PatternMatch.h"
21- #include " mlir/Transforms/RegionUtils.h"
2220
2321#include < cstdint>
2422#include < optional>
2523
2624namespace mlir ::iree_compiler::IREE::LinalgExt {
27- namespace {
28-
29- // / Represents the size of a dimension of some ShapedType value in the IR. This
30- // / is used instead of OpFoldResult when modifying the IR is illegal. This can
31- // / still be constructed from an OpFoldResult in cases where the value can be
32- // / obtained without IR modification.
33- class DimSize {
34- public:
35- DimSize (TypedValue<ShapedType> val, int64_t dim)
36- : ofr(nullptr ), val(val), dim(dim) {}
37- DimSize (OpFoldResult ofr) : ofr(ofr), val(nullptr ), dim(-1 ) {}
38-
39- bool isStatic () const {
40- if (ofr) {
41- return getConstantIntValue (ofr).has_value ();
42- }
43- return val.getType ().isStaticDim (dim);
44- }
45-
46- // Get an OpFoldResult by possibly inserting IR.
47- OpFoldResult materialize (OpBuilder &b) const {
48- if (ofr) {
49- return ofr;
50- }
51- return getDim (b, val.getLoc (), val, dim);
52- }
53-
54- private:
55- OpFoldResult ofr;
56- TypedValue<ShapedType> val;
57- int64_t dim;
58- };
59- } // namespace
60-
61- static SmallVector<DimSize> getDimSizes (Value v) {
62- auto shapedVal = cast<TypedValue<ShapedType>>(v);
63- int64_t rank = shapedVal.getType ().getRank ();
64- SmallVector<DimSize> sizes;
65- for (int i = 0 ; i < rank; ++i) {
66- sizes.emplace_back (shapedVal, i);
67- }
68- return sizes;
69- }
7025
7126static bool
7227isIdentityReassoc (const SmallVector<ReassociationIndices> &indices) {
@@ -78,7 +33,7 @@ isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
7833};
7934
8035static SmallVector<ReassociationIndices>
81- computeReassocFromShapeMap (ArrayRef<SmallVector<DimSize >> shapeMap) {
36+ computeReassocFromShapeMap (ArrayRef<SmallVector<int64_t >> shapeMap) {
8237 SmallVector<ReassociationIndices> reassoc;
8338 int64_t dimCount = 0 ;
8439 for (auto &shape : shapeMap) {
@@ -90,13 +45,14 @@ computeReassocFromShapeMap(ArrayRef<SmallVector<DimSize>> shapeMap) {
9045}
9146
9247namespace {
48+
9349// / Helper class that supports fusing reshapes with operands when not all of the
9450// / shape dims map to the iteration space.
9551struct ReshapeOperandInfo {
9652 static constexpr int64_t kNoMapping = -1 ;
9753
9854 // Original shape of this operand.
99- SmallVector<DimSize > originalShape;
55+ ArrayRef< int64_t > originalShape;
10056
10157 // Similar to the results of the operand's `AffineMap` except `kNoMapping` if
10258 // that dim doesn't map to the iteration space. For example, the indexed
@@ -116,7 +72,7 @@ class ExpansionInfo {
11672 SmallVector<int64_t > loopRanges,
11773 OpOperand *fusableOpOperand,
11874 ArrayRef<ReassociationIndices> operandReassoc,
119- ArrayRef<DimSize > expandedShape);
75+ ArrayRef<int64_t > expandedShape);
12076
12177 std::optional<Value> getOrCreateExpanded (Location loc, OpOperand *operand,
12278 RewriterBase &rewriter) {
@@ -125,17 +81,13 @@ class ExpansionInfo {
12581 if (isIdentityReassoc (reassoc)) {
12682 return operand->get ();
12783 }
128- SmallVector<OpFoldResult> outputShape ;
84+ SmallVector<int64_t > flattenedArray ;
12985 for (auto &shape : shapeMap) {
130- llvm::append_range (
131- outputShape, llvm::map_range (shape, [&rewriter](const DimSize &size) {
132- return size.materialize (rewriter);
133- }));
86+ flattenedArray.append (shape.begin (), shape.end ());
13487 }
135- auto [staticShape, dynamicShape] = decomposeMixedValues (outputShape);
136- (void )dynamicShape;
13788 auto oldType = cast<ShapedType>(operand->get ().getType ());
138- auto newType = RankedTensorType::get (staticShape, oldType.getElementType ());
89+ auto newType =
90+ RankedTensorType::get (flattenedArray, oldType.getElementType ());
13991 if (failed (reshapeLikeShapesAreCompatible (
14092 [&](const Twine &msg) {
14193 return rewriter.notifyMatchFailure (loc, msg);
@@ -145,18 +97,18 @@ class ExpansionInfo {
14597 return {};
14698 }
14799 return tensor::ExpandShapeOp::create (rewriter, loc, newType, operand->get (),
148- reassoc, outputShape );
100+ reassoc);
149101 };
150102
151103 // / Get the shape map for the operand.
152- SmallVector<SmallVector<DimSize >> getShapeMap (OpOperand *operand) const {
104+ SmallVector<SmallVector<int64_t >> getShapeMap (OpOperand *operand) const {
153105 auto info = reshapeInfos[operand->getOperandNumber ()];
154- SmallVector<SmallVector<DimSize >> shapeMap;
106+ SmallVector<SmallVector<int64_t >> shapeMap;
155107 for (auto [operandIdx, loopIdx] :
156108 llvm::enumerate (info.operandToIterationSpace )) {
157109 if (loopIdx == ReshapeOperandInfo::kNoMapping ) {
158110 shapeMap.push_back (
159- SmallVector<DimSize >{info.originalShape [operandIdx]});
111+ SmallVector<int64_t >{info.originalShape [operandIdx]});
160112 } else {
161113 shapeMap.push_back (loopShapeMap[loopIdx]);
162114 }
@@ -174,12 +126,17 @@ class ExpansionInfo {
174126 ReassociationIndicesRef getExpandedLoops (unsigned i) const {
175127 return loopReassoc[i];
176128 }
129+ ArrayRef<int64_t > getExpandedShapeOfLoop (unsigned i) const {
130+ return loopShapeMap[i];
131+ }
177132
178133private:
134+ // / Extent of the iteration space in the original operation.
135+ SmallVector<int64_t > loopRanges;
179136 SmallVector<ReassociationIndices> loopReassoc;
180137 // / Mapping from extent of loops in the original operation, to the extent of
181138 // / loops in the expanded operation.
182- SmallVector<SmallVector<DimSize >> loopShapeMap;
139+ SmallVector<SmallVector<int64_t >> loopShapeMap;
183140 unsigned expandedOpNumDims;
184141 // / Info about the reassociation and original shape for each operand.
185142 SmallVector<ReshapeOperandInfo> reshapeInfos;
@@ -239,7 +196,7 @@ class CollapsingInfo {
239196LogicalResult ExpansionInfo::compute (
240197 SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t > loopRanges,
241198 OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
242- ArrayRef<DimSize > expandedShape) {
199+ ArrayRef<int64_t > expandedShape) {
243200 if (operandReassoc.empty ())
244201 return failure ();
245202
@@ -249,8 +206,7 @@ LogicalResult ExpansionInfo::compute(
249206 for (auto [operandDim, iterDim] :
250207 llvm::enumerate (info.operandToIterationSpace )) {
251208 if (iterDim != ReshapeOperandInfo::kNoMapping &&
252- ShapedType::isStatic (loopRanges[iterDim]) !=
253- info.originalShape [operandDim].isStatic ()) {
209+ loopRanges[iterDim] != info.originalShape [operandDim]) {
254210 return failure ();
255211 }
256212 }
@@ -273,22 +229,12 @@ LogicalResult ExpansionInfo::compute(
273229 }
274230 }
275231
276- // Fill in the remaining elements.
277- for (const ReshapeOperandInfo &info : infos) {
278- for (auto [operandIdx, loopIdx] :
279- llvm::enumerate (info.operandToIterationSpace )) {
280- if (loopIdx == ReshapeOperandInfo::kNoMapping ||
281- !this ->loopShapeMap [loopIdx].empty ()) {
282- continue ;
283- }
284-
285- this ->loopShapeMap [loopIdx] =
286- SmallVector<DimSize>{info.originalShape [operandIdx]};
287- }
288- }
289-
232+ // Fill in the remaining elements with `loopRanges`
290233 this ->expandedOpNumDims = 0 ;
291- for (const auto &shapeMap : this ->loopShapeMap ) {
234+ for (const auto &[loopIdx, shapeMap] : llvm::enumerate (this ->loopShapeMap )) {
235+ if (shapeMap.empty ()) {
236+ this ->loopShapeMap [loopIdx] = SmallVector<int64_t >{loopRanges[loopIdx]};
237+ }
292238 this ->expandedOpNumDims += shapeMap.size ();
293239 }
294240
@@ -298,6 +244,7 @@ LogicalResult ExpansionInfo::compute(
298244 }
299245 this ->loopReassoc = computeReassocFromShapeMap (this ->loopShapeMap );
300246 this ->reshapeInfos = std::move (infos);
247+ this ->loopRanges = std::move (loopRanges);
301248 return success ();
302249}
303250
@@ -360,7 +307,7 @@ getReshapeInfo(LinalgExt::AttentionOp attentionOp) {
360307 return operandInfo;
361308 }
362309
363- operandInfo.originalShape = getDimSizes (opOperand. get () );
310+ operandInfo.originalShape = operandType. getShape ( );
364311 for (auto result :
365312 attentionOp.getMatchingIndexingMap (&opOperand).getResults ()) {
366313 operandInfo.operandToIterationSpace .push_back (
@@ -378,13 +325,13 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
378325 auto updateRank = scatterOp.getUpdateType ().getRank ();
379326
380327 ReshapeOperandInfo updateInfo;
381- updateInfo.originalShape = getDimSizes ( scatterOp.getUpdates () );
328+ updateInfo.originalShape = scatterOp.getUpdateType (). getShape ( );
382329 llvm::append_range (updateInfo.operandToIterationSpace ,
383330 llvm::seq<int64_t >(0 , updateRank));
384331 infos.push_back (std::move (updateInfo));
385332
386333 ReshapeOperandInfo indicesInfo;
387- indicesInfo.originalShape = getDimSizes ( scatterOp.getIndices () );
334+ indicesInfo.originalShape = scatterOp.getIndicesType (). getShape ( );
388335 llvm::append_range (indicesInfo.operandToIterationSpace ,
389336 llvm::seq<int64_t >(0 , scatterOp.getBatchRank ()));
390337 if (scatterOp.getBatchRank () != scatterOp.getIndicesType ().getRank ())
@@ -393,7 +340,7 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
393340 infos.push_back (std::move (indicesInfo));
394341
395342 ReshapeOperandInfo originalInfo;
396- originalInfo.originalShape = getDimSizes ( scatterOp.getOriginal () );
343+ originalInfo.originalShape = scatterOp.getOriginalType (). getShape ( );
397344 originalInfo.operandToIterationSpace .append (scatterOp.getIndexDepth (),
398345 ReshapeOperandInfo::kNoMapping );
399346 llvm::append_range (originalInfo.operandToIterationSpace ,
@@ -409,15 +356,15 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
409356 auto outputRank = gatherOp.getOutputType ().getRank ();
410357
411358 ReshapeOperandInfo sourceInfo;
412- sourceInfo.originalShape = getDimSizes ( gatherOp.getSource () );
359+ sourceInfo.originalShape = gatherOp.getSourceType (). getShape ( );
413360 sourceInfo.operandToIterationSpace .append (gatherOp.getIndexDepth (),
414361 ReshapeOperandInfo::kNoMapping );
415362 llvm::append_range (sourceInfo.operandToIterationSpace ,
416363 llvm::seq (outputRank - rankOfContiguousSlice, outputRank));
417364 infos.push_back (std::move (sourceInfo));
418365
419366 ReshapeOperandInfo indicesInfo;
420- indicesInfo.originalShape = getDimSizes ( gatherOp.getIndices () );
367+ indicesInfo.originalShape = gatherOp.getIndicesType (). getShape ( );
421368 llvm::append_range (indicesInfo.operandToIterationSpace ,
422369 llvm::seq<int64_t >(0 , gatherOp.getBatchRank ()));
423370 if (gatherOp.getBatchRank () != gatherOp.getIndicesType ().getRank ())
@@ -426,7 +373,7 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
426373 infos.push_back (std::move (indicesInfo));
427374
428375 ReshapeOperandInfo outputInfo;
429- outputInfo.originalShape = getDimSizes ( gatherOp.getOutput () );
376+ outputInfo.originalShape = gatherOp.getOutputType (). getShape ( );
430377 llvm::append_range (outputInfo.operandToIterationSpace ,
431378 llvm::seq<int64_t >(0 , outputRank));
432379 infos.push_back (std::move (outputInfo));
@@ -460,26 +407,15 @@ fuseWithReshapeByExpansion(OpTy op, Operation *reshapeOp,
460407 auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
461408 auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
462409 bool isExpanding = (expandingReshapeOp != nullptr );
463- Value expandedVal = isExpanding ? expandingReshapeOp.getResult ()
464- : collapsingReshapeOp.getSrc ();
465- SmallVector<DimSize> expandedSize;
466- if (isExpanding) {
467- // The SSA dims must dominate `op` in order to use them to create new
468- // expand_shape ops.
469- if (failed (moveValueDefinitions (rewriter,
470- expandingReshapeOp.getOutputShape (), op))) {
471- return std::nullopt ;
472- }
473- llvm::append_range (expandedSize, expandingReshapeOp.getMixedOutputShape ());
474- } else {
475- expandedSize = getDimSizes (expandedVal);
476- }
410+ RankedTensorType expandedType = isExpanding
411+ ? expandingReshapeOp.getResultType ()
412+ : collapsingReshapeOp.getSrcType ();
477413 ExpansionInfo info;
478414 if (failed (info.compute (
479415 getReshapeInfo (op), op.getStaticLoopRanges (), fusableOpOperand,
480416 isExpanding ? expandingReshapeOp.getReassociationIndices ()
481417 : collapsingReshapeOp.getReassociationIndices (),
482- expandedSize ))) {
418+ expandedType. getShape () ))) {
483419 return std::nullopt ;
484420 }
485421
0 commit comments