Commit be510b6

[LinalgExt] Fix FoldWithProducerReshapeByExpansion for >1 dyn dim (iree-org#21894)

The builder for `tensor.expand_shape` cannot infer the output shape when there is more than one dynamic dimension, so `ExpansionInfo` needs to track the SSA values of the expanded shape in order to create reshapes. This introduces two problems:

1. Creating `tensor.dim` ops early (i.e. before knowing whether the pattern can be applied successfully) causes the greedy pattern rewrite driver to loop forever. This is fixed by using the `DimSize` class to delay IR modifications until the pattern is known not to fail.

2. In the `iree_linalg_ext` -> `tensor.expand_shape` case, the output shape SSA values must be moved so that they dominate the op. The same is done upstream: https://github.com/llvm/llvm-project/blob/879f40ab041b31fa73b9b25e4ec9e06e810bc767/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L875-L894

Closes iree-org#21889

---------

Signed-off-by: Ian Wood <[email protected]>
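To see why the expanded shape must be passed explicitly, here is a minimal sketch of the builder call this patch adopts. The variable names (`src`, `splitSize`, `batchSize`, `elementType`) are illustrative, not from the patch: when a reassociation group yields more than one dynamic result dimension, the result type alone cannot determine how the dynamic source extent splits, so explicit output-shape operands are required.

    // Sketch only: `rewriter`, `loc`, and the named values are assumed to
    // exist in the surrounding pattern.
    SmallVector<ReassociationIndices> reassoc = {{0, 1}, {2}};
    // One OpFoldResult per result dim: dynamic sizes are SSA index values
    // (which must dominate the insertion point); static sizes are attributes.
    SmallVector<OpFoldResult> outputShape = {splitSize, batchSize,
                                             rewriter.getIndexAttr(32)};
    auto expandedType = RankedTensorType::get(
        {ShapedType::kDynamic, ShapedType::kDynamic, 32}, elementType);
    Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, expandedType,
                                                   src, reassoc, outputShape);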
1 parent b18e45c commit be510b6

2 files changed: +128, -70 lines

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 101 additions & 37 deletions
@@ -11,17 +11,62 @@
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 #include <cstdint>
 #include <optional>
 
 namespace mlir::iree_compiler::IREE::LinalgExt {
+namespace {
+
+/// Represents the size of a dimension of some ShapedType value in the IR. This
+/// is used instead of OpFoldResult when modifying the IR is illegal. This can
+/// still be constructed from an OpFoldResult in cases where the value can be
+/// obtained without IR modification.
+class DimSize {
+public:
+  DimSize(TypedValue<ShapedType> val, int64_t dim)
+      : ofr(nullptr), val(val), dim(dim) {}
+  DimSize(OpFoldResult ofr) : ofr(ofr), val(nullptr), dim(-1) {}
+
+  bool isStatic() const {
+    if (ofr) {
+      return getConstantIntValue(ofr).has_value();
+    }
+    return val.getType().isStaticDim(dim);
+  }
+
+  // Get an OpFoldResult by possibly inserting IR.
+  OpFoldResult materialize(OpBuilder &b) const {
+    if (ofr) {
+      return ofr;
+    }
+    return getDim(b, val.getLoc(), val, dim);
+  }
+
+private:
+  OpFoldResult ofr;
+  TypedValue<ShapedType> val;
+  int64_t dim;
+};
+} // namespace
+
+static SmallVector<DimSize> getDimSizes(Value v) {
+  auto shapedVal = cast<TypedValue<ShapedType>>(v);
+  int64_t rank = shapedVal.getType().getRank();
+  SmallVector<DimSize> sizes;
+  for (int i = 0; i < rank; ++i) {
+    sizes.emplace_back(shapedVal, i);
+  }
+  return sizes;
+}
 
 static bool
 isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
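The doc comment on `DimSize` above is the crux of the infinite-loop fix: static/dynamic queries are side-effect free during matching, and only `materialize` may create `tensor.dim` ops. A rough usage sketch, assuming a `Value operandVal` and a `RewriterBase &rewriter` from the surrounding pattern:

    // Query sizes with no IR changes while the pattern may still fail.
    SmallVector<DimSize> sizes = getDimSizes(operandVal);
    if (llvm::all_of(sizes, [](const DimSize &s) { return s.isStatic(); })) {
      // Purely static case: no SSA values will ever be needed.
    }
    // Only once every match-failure check has passed is it safe to
    // materialize, which may insert tensor.dim ops:
    SmallVector<OpFoldResult> shape;
    for (const DimSize &s : sizes)
      shape.push_back(s.materialize(rewriter));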
@@ -33,7 +78,7 @@ isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
 };
 
 static SmallVector<ReassociationIndices>
-computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
+computeReassocFromShapeMap(ArrayRef<SmallVector<DimSize>> shapeMap) {
   SmallVector<ReassociationIndices> reassoc;
   int64_t dimCount = 0;
   for (auto &shape : shapeMap) {
@@ -45,14 +90,13 @@ computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
 }
 
 namespace {
-
 /// Helper class that supports fusing reshapes with operands when not all of the
 /// shape dims map to the iteration space.
 struct ReshapeOperandInfo {
   static constexpr int64_t kNoMapping = -1;
 
   // Original shape of this operand.
-  ArrayRef<int64_t> originalShape;
+  SmallVector<DimSize> originalShape;
 
   // Similar to the results of the operand's `AffineMap` except `kNoMapping` if
   // that dim doesn't map to the iteration space. For example, the indexed
@@ -72,7 +116,7 @@ class ExpansionInfo {
                        SmallVector<int64_t> loopRanges,
                        OpOperand *fusableOpOperand,
                        ArrayRef<ReassociationIndices> operandReassoc,
-                       ArrayRef<int64_t> expandedShape);
+                       ArrayRef<DimSize> expandedShape);
 
   std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
                                            RewriterBase &rewriter) {
@@ -81,13 +125,17 @@
     if (isIdentityReassoc(reassoc)) {
       return operand->get();
     }
-    SmallVector<int64_t> flattenedArray;
+    SmallVector<OpFoldResult> outputShape;
     for (auto &shape : shapeMap) {
-      flattenedArray.append(shape.begin(), shape.end());
+      llvm::append_range(
+          outputShape, llvm::map_range(shape, [&rewriter](const DimSize &size) {
+            return size.materialize(rewriter);
+          }));
     }
+    auto [staticShape, dynamicShape] = decomposeMixedValues(outputShape);
+    (void)dynamicShape;
     auto oldType = cast<ShapedType>(operand->get().getType());
-    auto newType =
-        RankedTensorType::get(flattenedArray, oldType.getElementType());
+    auto newType = RankedTensorType::get(staticShape, oldType.getElementType());
     if (failed(reshapeLikeShapesAreCompatible(
             [&](const Twine &msg) {
               return rewriter.notifyMatchFailure(loc, msg);
@@ -97,18 +145,18 @@
       return {};
    }
    return tensor::ExpandShapeOp::create(rewriter, loc, newType, operand->get(),
-                                        reassoc);
+                                        reassoc, outputShape);
  };
 
  /// Get the shape map for the operand.
- SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
+ SmallVector<SmallVector<DimSize>> getShapeMap(OpOperand *operand) const {
    auto info = reshapeInfos[operand->getOperandNumber()];
-   SmallVector<SmallVector<int64_t>> shapeMap;
+   SmallVector<SmallVector<DimSize>> shapeMap;
    for (auto [operandIdx, loopIdx] :
         llvm::enumerate(info.operandToIterationSpace)) {
      if (loopIdx == ReshapeOperandInfo::kNoMapping) {
        shapeMap.push_back(
-           SmallVector<int64_t>{info.originalShape[operandIdx]});
+           SmallVector<DimSize>{info.originalShape[operandIdx]});
      } else {
        shapeMap.push_back(loopShapeMap[loopIdx]);
      }
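The `decomposeMixedValues` call in the hunk above comes from MLIR's static value utilities (StaticValueUtils.h); it splits the mixed sizes into the static extents used to build the result type and the dynamic SSA values. A small sketch with illustrative values:

    // Given outputShape = {IndexAttr(2), %split, %d1, IndexAttr(32)}:
    //   staticShape  == {2, ShapedType::kDynamic, ShapedType::kDynamic, 32}
    //   dynamicShape == {%split, %d1}
    auto [staticShape, dynamicShape] = decomposeMixedValues(outputShape);
    auto newType =
        RankedTensorType::get(staticShape, oldType.getElementType());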
@@ -126,17 +174,12 @@
   ReassociationIndicesRef getExpandedLoops(unsigned i) const {
     return loopReassoc[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
-    return loopShapeMap[i];
-  }
 
 private:
-  /// Extent of the iteration space in the original operation.
-  SmallVector<int64_t> loopRanges;
   SmallVector<ReassociationIndices> loopReassoc;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> loopShapeMap;
+  SmallVector<SmallVector<DimSize>> loopShapeMap;
   unsigned expandedOpNumDims;
   /// Info about the reassociation and original shape for each operand.
   SmallVector<ReshapeOperandInfo> reshapeInfos;
@@ -196,7 +239,7 @@ class CollapsingInfo {
 LogicalResult ExpansionInfo::compute(
     SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
     OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
-    ArrayRef<int64_t> expandedShape) {
+    ArrayRef<DimSize> expandedShape) {
   if (operandReassoc.empty())
     return failure();
 
@@ -206,7 +249,8 @@ LogicalResult ExpansionInfo::compute(
     for (auto [operandDim, iterDim] :
          llvm::enumerate(info.operandToIterationSpace)) {
       if (iterDim != ReshapeOperandInfo::kNoMapping &&
-          loopRanges[iterDim] != info.originalShape[operandDim]) {
+          ShapedType::isStatic(loopRanges[iterDim]) !=
+              info.originalShape[operandDim].isStatic()) {
         return failure();
       }
     }
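A note on the changed check above: a dynamic `DimSize` has no compile-time extent to compare against `loopRanges`, so the old equality test is replaced by a check that the operand dim and the loop range agree on being static or dynamic. Illustrative outcomes, using the same variables as the hunk:

    // loopRanges[iterDim] == 16 (static), operand dim dynamic -> failure()
    // loopRanges[iterDim] == kDynamic,    operand dim dynamic -> compatible
    // loopRanges[iterDim] == 16 (static), operand dim static  -> compatible
    if (iterDim != ReshapeOperandInfo::kNoMapping &&
        ShapedType::isStatic(loopRanges[iterDim]) !=
            info.originalShape[operandDim].isStatic())
      return failure();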
@@ -229,12 +273,22 @@
     }
   }
 
-  // Fill in the remaining elements with `loopRanges`
-  this->expandedOpNumDims = 0;
-  for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
-    if (shapeMap.empty()) {
-      this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
+  // Fill in the remaining elements.
+  for (const ReshapeOperandInfo &info : infos) {
+    for (auto [operandIdx, loopIdx] :
+         llvm::enumerate(info.operandToIterationSpace)) {
+      if (loopIdx == ReshapeOperandInfo::kNoMapping ||
+          !this->loopShapeMap[loopIdx].empty()) {
+        continue;
+      }
+
+      this->loopShapeMap[loopIdx] =
+          SmallVector<DimSize>{info.originalShape[operandIdx]};
     }
+  }
+
+  this->expandedOpNumDims = 0;
+  for (const auto &shapeMap : this->loopShapeMap) {
     this->expandedOpNumDims += shapeMap.size();
   }
 
@@ -244,7 +298,6 @@ LogicalResult ExpansionInfo::compute(
   }
   this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
   this->reshapeInfos = std::move(infos);
-  this->loopRanges = std::move(loopRanges);
   return success();
 }
 
@@ -307,7 +360,7 @@ getReshapeInfo(LinalgExt::AttentionOp attentionOp) {
       return operandInfo;
     }
 
-    operandInfo.originalShape = operandType.getShape();
+    operandInfo.originalShape = getDimSizes(opOperand.get());
     for (auto result :
          attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
       operandInfo.operandToIterationSpace.push_back(
@@ -325,13 +378,13 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
   auto updateRank = scatterOp.getUpdateType().getRank();
 
   ReshapeOperandInfo updateInfo;
-  updateInfo.originalShape = scatterOp.getUpdateType().getShape();
+  updateInfo.originalShape = getDimSizes(scatterOp.getUpdates());
   llvm::append_range(updateInfo.operandToIterationSpace,
                      llvm::seq<int64_t>(0, updateRank));
   infos.push_back(std::move(updateInfo));
 
   ReshapeOperandInfo indicesInfo;
-  indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
+  indicesInfo.originalShape = getDimSizes(scatterOp.getIndices());
   llvm::append_range(indicesInfo.operandToIterationSpace,
                      llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
   if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
@@ -340,7 +393,7 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
   infos.push_back(std::move(indicesInfo));
 
   ReshapeOperandInfo originalInfo;
-  originalInfo.originalShape = scatterOp.getOriginalType().getShape();
+  originalInfo.originalShape = getDimSizes(scatterOp.getOriginal());
   originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
                                               ReshapeOperandInfo::kNoMapping);
   llvm::append_range(originalInfo.operandToIterationSpace,
@@ -356,15 +409,15 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
   auto outputRank = gatherOp.getOutputType().getRank();
 
   ReshapeOperandInfo sourceInfo;
-  sourceInfo.originalShape = gatherOp.getSourceType().getShape();
+  sourceInfo.originalShape = getDimSizes(gatherOp.getSource());
   sourceInfo.operandToIterationSpace.append(gatherOp.getIndexDepth(),
                                             ReshapeOperandInfo::kNoMapping);
   llvm::append_range(sourceInfo.operandToIterationSpace,
                      llvm::seq(outputRank - rankOfContiguousSlice, outputRank));
   infos.push_back(std::move(sourceInfo));
 
   ReshapeOperandInfo indicesInfo;
-  indicesInfo.originalShape = gatherOp.getIndicesType().getShape();
+  indicesInfo.originalShape = getDimSizes(gatherOp.getIndices());
   llvm::append_range(indicesInfo.operandToIterationSpace,
                      llvm::seq<int64_t>(0, gatherOp.getBatchRank()));
   if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank())
@@ -373,7 +426,7 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
   infos.push_back(std::move(indicesInfo));
 
   ReshapeOperandInfo outputInfo;
-  outputInfo.originalShape = gatherOp.getOutputType().getShape();
+  outputInfo.originalShape = getDimSizes(gatherOp.getOutput());
   llvm::append_range(outputInfo.operandToIterationSpace,
                      llvm::seq<int64_t>(0, outputRank));
   infos.push_back(std::move(outputInfo));
@@ -407,15 +460,26 @@ fuseWithReshapeByExpansion(OpTy op, Operation *reshapeOp,
   auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
   auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
   bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
+  Value expandedVal = isExpanding ? expandingReshapeOp.getResult()
+                                  : collapsingReshapeOp.getSrc();
+  SmallVector<DimSize> expandedSize;
+  if (isExpanding) {
+    // The SSA dims must dominate `op` in order to use them to create new
+    // expand_shape ops.
+    if (failed(moveValueDefinitions(rewriter,
+                                    expandingReshapeOp.getOutputShape(), op))) {
+      return std::nullopt;
+    }
+    llvm::append_range(expandedSize, expandingReshapeOp.getMixedOutputShape());
+  } else {
+    expandedSize = getDimSizes(expandedVal);
+  }
   ExpansionInfo info;
   if (failed(info.compute(
           getReshapeInfo(op), op.getStaticLoopRanges(), fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationIndices()
                       : collapsingReshapeOp.getReassociationIndices(),
-          expandedType.getShape()))) {
+          expandedSize))) {
    return std::nullopt;
  }
 
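Why the `moveValueDefinitions` call in the hunk above is needed: when the expanding reshape is the consumer (an `iree_linalg_ext` op feeding `tensor.expand_shape`), its `output_shape` operands may be defined between the op and the reshape, so they would not dominate the new `expand_shape` ops created for the op's operands. A sketch of the situation, with illustrative IR in the comments and the variables taken from the hunk above:

    // Before fusion, the shape operands are defined after the scatter:
    //   %0  = iree_linalg_ext.scatter ... -> tensor<?x4x32x32xf16>
    //   %d0 = tensor.dim %arg0, %c0 : tensor<?x?x4x32x32xf16>
    //   %d1 = tensor.dim %arg0, %c1 : tensor<?x?x4x32x32xf16>
    //   %1  = tensor.expand_shape %0 [[0, 1], [2], [3], [4]]
    //           output_shape [%d0, %d1, 4, 32, 32] ...
    // Reusing %d0/%d1 to expand the scatter's operands requires hoisting
    // their definitions above the scatter; moveValueDefinitions (from
    // mlir/Transforms/RegionUtils.h) performs the move and fails if it
    // would be illegal.
    if (failed(moveValueDefinitions(
            rewriter, expandingReshapeOp.getOutputShape(), op))) {
      return std::nullopt;
    }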
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/reshape_fusion.mlir

Lines changed: 27 additions & 33 deletions
@@ -196,19 +196,13 @@ util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x
 // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
 // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
 // CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
 // CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
 // CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
-// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
 // CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
-// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
-// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
-// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
-// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D6]], %[[D2]]]
+// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D6]], %[[D9]]]
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D9]]) : tensor<2x?x?x?xf16>
 // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
 // CHECK-SAME: indexing_maps =
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
@@ -256,29 +250,13 @@ util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tens
 // CHECK-SAME: %[[ARG3:.+]]: f16
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>)
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
-// CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
-// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
-// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
-// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
-// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
-// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
-// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
-// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
-// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
-// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
-// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
-// CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]]
-// CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]]
-// CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]]
-// CHECK-DAG: %[[SPLIT3:.+]] = arith.divsi %[[D10]], %[[C2]]
-// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]]
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[SPLIT:.+]] = arith.divsi %[[DIM]], %[[C2]]
+// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
+// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
+// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
+// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
 // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
 // CHECK-SAME: indexing_maps =
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
@@ -288,7 +266,6 @@ util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tens
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
 // CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 // CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
-// CHECK-SAME: outs(%[[EMPTY]] :
 // CHECK: util.return %[[ATTENTION]]
 
 // -----
@@ -710,6 +687,23 @@ util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x
 
 // -----
 
+util.func public @scatter_collapse_multiple_dynamic(%arg0 : tensor<?x?x4x32x32xf16>, %arg1 : tensor<?xi64>, %arg2 : tensor<?x4x32x32xf16>) -> (tensor<?x4x32x32xf16>){
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4]] : tensor<?x?x4x32x32xf16> into tensor<?x4x32x32xf16>
+  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x32x32xf16>, tensor<?xi64>) outs(%arg2 : tensor<?x4x32x32xf16>) {
+  ^bb0(%arg7: f16, %arg8: f16):
+    iree_linalg_ext.yield %arg7 : f16
+  } -> tensor<?x4x32x32xf16>
+  util.return %0 : tensor<?x4x32x32xf16>
+}
+// CHECK-LABEL: util.func public @scatter_collapse_multiple_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
+// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
+// CHECK-SAME: ins({{.*}} : tensor<?x?x4x32x32xf16>, tensor<?x?xi64>
+
+// -----
+
 util.func public @gather_expand(%arg0: tensor<100x128xf16>, %arg1: tensor<10xi32>, %arg2: tensor<10x128xf16>) -> tensor<2x5x4x32xf16> {
   %c0 = arith.constant 0 : index
   %0 = iree_linalg_ext.gather dimension_map = [0] ins(%arg0, %arg1 : tensor<100x128xf16>, tensor<10xi32>) outs(%arg2 : tensor<10x128xf16>) -> tensor<10x128xf16>
