Skip to content

Commit 515f292

Browse files
authored
Revert "[LinalgExt] Fix FoldWithProducerReshapeByExpansion for >1 dyn dim" (iree-org#21947)
Reverts iree-org#21894 Breaks iree-org#21889 Fixes iree-org#21941
1 parent af5f023 commit 515f292

File tree

2 files changed

+70
-128
lines changed

2 files changed

+70
-128
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp

Lines changed: 37 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -11,62 +11,17 @@
1111
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1212
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1313
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
14-
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1514
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1615
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1716
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
1817
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
1918
#include "mlir/IR/MLIRContext.h"
2019
#include "mlir/IR/PatternMatch.h"
21-
#include "mlir/Transforms/RegionUtils.h"
2220

2321
#include <cstdint>
2422
#include <optional>
2523

2624
namespace mlir::iree_compiler::IREE::LinalgExt {
27-
namespace {
28-
29-
/// Represents the size of a dimension of some ShapedType value in the IR. This
30-
/// is used instead of OpFoldResult when modifying the IR is illegal. This can
31-
/// still be constructed from an OpFoldResult in cases where the value can be
32-
/// obtained without IR modification.
33-
class DimSize {
34-
public:
35-
DimSize(TypedValue<ShapedType> val, int64_t dim)
36-
: ofr(nullptr), val(val), dim(dim) {}
37-
DimSize(OpFoldResult ofr) : ofr(ofr), val(nullptr), dim(-1) {}
38-
39-
bool isStatic() const {
40-
if (ofr) {
41-
return getConstantIntValue(ofr).has_value();
42-
}
43-
return val.getType().isStaticDim(dim);
44-
}
45-
46-
// Get an OpFoldResult by possibly inserting IR.
47-
OpFoldResult materialize(OpBuilder &b) const {
48-
if (ofr) {
49-
return ofr;
50-
}
51-
return getDim(b, val.getLoc(), val, dim);
52-
}
53-
54-
private:
55-
OpFoldResult ofr;
56-
TypedValue<ShapedType> val;
57-
int64_t dim;
58-
};
59-
} // namespace
60-
61-
static SmallVector<DimSize> getDimSizes(Value v) {
62-
auto shapedVal = cast<TypedValue<ShapedType>>(v);
63-
int64_t rank = shapedVal.getType().getRank();
64-
SmallVector<DimSize> sizes;
65-
for (int i = 0; i < rank; ++i) {
66-
sizes.emplace_back(shapedVal, i);
67-
}
68-
return sizes;
69-
}
7025

7126
static bool
7227
isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
@@ -78,7 +33,7 @@ isIdentityReassoc(const SmallVector<ReassociationIndices> &indices) {
7833
};
7934

8035
static SmallVector<ReassociationIndices>
81-
computeReassocFromShapeMap(ArrayRef<SmallVector<DimSize>> shapeMap) {
36+
computeReassocFromShapeMap(ArrayRef<SmallVector<int64_t>> shapeMap) {
8237
SmallVector<ReassociationIndices> reassoc;
8338
int64_t dimCount = 0;
8439
for (auto &shape : shapeMap) {
@@ -90,13 +45,14 @@ computeReassocFromShapeMap(ArrayRef<SmallVector<DimSize>> shapeMap) {
9045
}
9146

9247
namespace {
48+
9349
/// Helper class that supports fusing reshapes with operands when not all of the
9450
/// shape dims map to the iteration space.
9551
struct ReshapeOperandInfo {
9652
static constexpr int64_t kNoMapping = -1;
9753

9854
// Original shape of this operand.
99-
SmallVector<DimSize> originalShape;
55+
ArrayRef<int64_t> originalShape;
10056

10157
// Similar to the results of the operand's `AffineMap` except `kNoMapping` if
10258
// that dim doesn't map to the iteration space. For example, the indexed
@@ -116,7 +72,7 @@ class ExpansionInfo {
11672
SmallVector<int64_t> loopRanges,
11773
OpOperand *fusableOpOperand,
11874
ArrayRef<ReassociationIndices> operandReassoc,
119-
ArrayRef<DimSize> expandedShape);
75+
ArrayRef<int64_t> expandedShape);
12076

12177
std::optional<Value> getOrCreateExpanded(Location loc, OpOperand *operand,
12278
RewriterBase &rewriter) {
@@ -125,17 +81,13 @@ class ExpansionInfo {
12581
if (isIdentityReassoc(reassoc)) {
12682
return operand->get();
12783
}
128-
SmallVector<OpFoldResult> outputShape;
84+
SmallVector<int64_t> flattenedArray;
12985
for (auto &shape : shapeMap) {
130-
llvm::append_range(
131-
outputShape, llvm::map_range(shape, [&rewriter](const DimSize &size) {
132-
return size.materialize(rewriter);
133-
}));
86+
flattenedArray.append(shape.begin(), shape.end());
13487
}
135-
auto [staticShape, dynamicShape] = decomposeMixedValues(outputShape);
136-
(void)dynamicShape;
13788
auto oldType = cast<ShapedType>(operand->get().getType());
138-
auto newType = RankedTensorType::get(staticShape, oldType.getElementType());
89+
auto newType =
90+
RankedTensorType::get(flattenedArray, oldType.getElementType());
13991
if (failed(reshapeLikeShapesAreCompatible(
14092
[&](const Twine &msg) {
14193
return rewriter.notifyMatchFailure(loc, msg);
@@ -145,18 +97,18 @@ class ExpansionInfo {
14597
return {};
14698
}
14799
return tensor::ExpandShapeOp::create(rewriter, loc, newType, operand->get(),
148-
reassoc, outputShape);
100+
reassoc);
149101
};
150102

151103
/// Get the shape map for the operand.
152-
SmallVector<SmallVector<DimSize>> getShapeMap(OpOperand *operand) const {
104+
SmallVector<SmallVector<int64_t>> getShapeMap(OpOperand *operand) const {
153105
auto info = reshapeInfos[operand->getOperandNumber()];
154-
SmallVector<SmallVector<DimSize>> shapeMap;
106+
SmallVector<SmallVector<int64_t>> shapeMap;
155107
for (auto [operandIdx, loopIdx] :
156108
llvm::enumerate(info.operandToIterationSpace)) {
157109
if (loopIdx == ReshapeOperandInfo::kNoMapping) {
158110
shapeMap.push_back(
159-
SmallVector<DimSize>{info.originalShape[operandIdx]});
111+
SmallVector<int64_t>{info.originalShape[operandIdx]});
160112
} else {
161113
shapeMap.push_back(loopShapeMap[loopIdx]);
162114
}
@@ -174,12 +126,17 @@ class ExpansionInfo {
174126
ReassociationIndicesRef getExpandedLoops(unsigned i) const {
175127
return loopReassoc[i];
176128
}
129+
ArrayRef<int64_t> getExpandedShapeOfLoop(unsigned i) const {
130+
return loopShapeMap[i];
131+
}
177132

178133
private:
134+
/// Extent of the iteration space in the original operation.
135+
SmallVector<int64_t> loopRanges;
179136
SmallVector<ReassociationIndices> loopReassoc;
180137
/// Mapping from extent of loops in the original operation, to the extent of
181138
/// loops in the expanded operation.
182-
SmallVector<SmallVector<DimSize>> loopShapeMap;
139+
SmallVector<SmallVector<int64_t>> loopShapeMap;
183140
unsigned expandedOpNumDims;
184141
/// Info about the reassociation and original shape for each operand.
185142
SmallVector<ReshapeOperandInfo> reshapeInfos;
@@ -239,7 +196,7 @@ class CollapsingInfo {
239196
LogicalResult ExpansionInfo::compute(
240197
SmallVector<ReshapeOperandInfo> infos, SmallVector<int64_t> loopRanges,
241198
OpOperand *fusableOpOperand, ArrayRef<ReassociationIndices> operandReassoc,
242-
ArrayRef<DimSize> expandedShape) {
199+
ArrayRef<int64_t> expandedShape) {
243200
if (operandReassoc.empty())
244201
return failure();
245202

@@ -249,8 +206,7 @@ LogicalResult ExpansionInfo::compute(
249206
for (auto [operandDim, iterDim] :
250207
llvm::enumerate(info.operandToIterationSpace)) {
251208
if (iterDim != ReshapeOperandInfo::kNoMapping &&
252-
ShapedType::isStatic(loopRanges[iterDim]) !=
253-
info.originalShape[operandDim].isStatic()) {
209+
loopRanges[iterDim] != info.originalShape[operandDim]) {
254210
return failure();
255211
}
256212
}
@@ -273,22 +229,12 @@ LogicalResult ExpansionInfo::compute(
273229
}
274230
}
275231

276-
// Fill in the remaining elements.
277-
for (const ReshapeOperandInfo &info : infos) {
278-
for (auto [operandIdx, loopIdx] :
279-
llvm::enumerate(info.operandToIterationSpace)) {
280-
if (loopIdx == ReshapeOperandInfo::kNoMapping ||
281-
!this->loopShapeMap[loopIdx].empty()) {
282-
continue;
283-
}
284-
285-
this->loopShapeMap[loopIdx] =
286-
SmallVector<DimSize>{info.originalShape[operandIdx]};
287-
}
288-
}
289-
232+
// Fill in the remaining elements with `loopRanges`
290233
this->expandedOpNumDims = 0;
291-
for (const auto &shapeMap : this->loopShapeMap) {
234+
for (const auto &[loopIdx, shapeMap] : llvm::enumerate(this->loopShapeMap)) {
235+
if (shapeMap.empty()) {
236+
this->loopShapeMap[loopIdx] = SmallVector<int64_t>{loopRanges[loopIdx]};
237+
}
292238
this->expandedOpNumDims += shapeMap.size();
293239
}
294240

@@ -298,6 +244,7 @@ LogicalResult ExpansionInfo::compute(
298244
}
299245
this->loopReassoc = computeReassocFromShapeMap(this->loopShapeMap);
300246
this->reshapeInfos = std::move(infos);
247+
this->loopRanges = std::move(loopRanges);
301248
return success();
302249
}
303250

@@ -360,7 +307,7 @@ getReshapeInfo(LinalgExt::AttentionOp attentionOp) {
360307
return operandInfo;
361308
}
362309

363-
operandInfo.originalShape = getDimSizes(opOperand.get());
310+
operandInfo.originalShape = operandType.getShape();
364311
for (auto result :
365312
attentionOp.getMatchingIndexingMap(&opOperand).getResults()) {
366313
operandInfo.operandToIterationSpace.push_back(
@@ -378,13 +325,13 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
378325
auto updateRank = scatterOp.getUpdateType().getRank();
379326

380327
ReshapeOperandInfo updateInfo;
381-
updateInfo.originalShape = getDimSizes(scatterOp.getUpdates());
328+
updateInfo.originalShape = scatterOp.getUpdateType().getShape();
382329
llvm::append_range(updateInfo.operandToIterationSpace,
383330
llvm::seq<int64_t>(0, updateRank));
384331
infos.push_back(std::move(updateInfo));
385332

386333
ReshapeOperandInfo indicesInfo;
387-
indicesInfo.originalShape = getDimSizes(scatterOp.getIndices());
334+
indicesInfo.originalShape = scatterOp.getIndicesType().getShape();
388335
llvm::append_range(indicesInfo.operandToIterationSpace,
389336
llvm::seq<int64_t>(0, scatterOp.getBatchRank()));
390337
if (scatterOp.getBatchRank() != scatterOp.getIndicesType().getRank())
@@ -393,7 +340,7 @@ getReshapeInfo(LinalgExt::ScatterOp scatterOp) {
393340
infos.push_back(std::move(indicesInfo));
394341

395342
ReshapeOperandInfo originalInfo;
396-
originalInfo.originalShape = getDimSizes(scatterOp.getOriginal());
343+
originalInfo.originalShape = scatterOp.getOriginalType().getShape();
397344
originalInfo.operandToIterationSpace.append(scatterOp.getIndexDepth(),
398345
ReshapeOperandInfo::kNoMapping);
399346
llvm::append_range(originalInfo.operandToIterationSpace,
@@ -409,15 +356,15 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
409356
auto outputRank = gatherOp.getOutputType().getRank();
410357

411358
ReshapeOperandInfo sourceInfo;
412-
sourceInfo.originalShape = getDimSizes(gatherOp.getSource());
359+
sourceInfo.originalShape = gatherOp.getSourceType().getShape();
413360
sourceInfo.operandToIterationSpace.append(gatherOp.getIndexDepth(),
414361
ReshapeOperandInfo::kNoMapping);
415362
llvm::append_range(sourceInfo.operandToIterationSpace,
416363
llvm::seq(outputRank - rankOfContiguousSlice, outputRank));
417364
infos.push_back(std::move(sourceInfo));
418365

419366
ReshapeOperandInfo indicesInfo;
420-
indicesInfo.originalShape = getDimSizes(gatherOp.getIndices());
367+
indicesInfo.originalShape = gatherOp.getIndicesType().getShape();
421368
llvm::append_range(indicesInfo.operandToIterationSpace,
422369
llvm::seq<int64_t>(0, gatherOp.getBatchRank()));
423370
if (gatherOp.getBatchRank() != gatherOp.getIndicesType().getRank())
@@ -426,7 +373,7 @@ getReshapeInfo(LinalgExt::GatherOp gatherOp) {
426373
infos.push_back(std::move(indicesInfo));
427374

428375
ReshapeOperandInfo outputInfo;
429-
outputInfo.originalShape = getDimSizes(gatherOp.getOutput());
376+
outputInfo.originalShape = gatherOp.getOutputType().getShape();
430377
llvm::append_range(outputInfo.operandToIterationSpace,
431378
llvm::seq<int64_t>(0, outputRank));
432379
infos.push_back(std::move(outputInfo));
@@ -460,26 +407,15 @@ fuseWithReshapeByExpansion(OpTy op, Operation *reshapeOp,
460407
auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
461408
auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
462409
bool isExpanding = (expandingReshapeOp != nullptr);
463-
Value expandedVal = isExpanding ? expandingReshapeOp.getResult()
464-
: collapsingReshapeOp.getSrc();
465-
SmallVector<DimSize> expandedSize;
466-
if (isExpanding) {
467-
// The SSA dims must dominate `op` in order to use them to create new
468-
// expand_shape ops.
469-
if (failed(moveValueDefinitions(rewriter,
470-
expandingReshapeOp.getOutputShape(), op))) {
471-
return std::nullopt;
472-
}
473-
llvm::append_range(expandedSize, expandingReshapeOp.getMixedOutputShape());
474-
} else {
475-
expandedSize = getDimSizes(expandedVal);
476-
}
410+
RankedTensorType expandedType = isExpanding
411+
? expandingReshapeOp.getResultType()
412+
: collapsingReshapeOp.getSrcType();
477413
ExpansionInfo info;
478414
if (failed(info.compute(
479415
getReshapeInfo(op), op.getStaticLoopRanges(), fusableOpOperand,
480416
isExpanding ? expandingReshapeOp.getReassociationIndices()
481417
: collapsingReshapeOp.getReassociationIndices(),
482-
expandedSize))) {
418+
expandedType.getShape()))) {
483419
return std::nullopt;
484420
}
485421

compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/reshape_fusion.mlir

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,19 @@ util.func public @attention_dynamic(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x
196196
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
197197
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
198198
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
199+
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
199200
// CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
201+
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
200202
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
203+
// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
201204
// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
202-
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D6]], %[[D2]]]
203-
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C2]]
204-
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D6]], %[[D9]]]
205-
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D9]]) : tensor<2x?x?x?xf16>
205+
// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
206+
// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
207+
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
208+
// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
209+
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
210+
// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
211+
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
206212
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
207213
// CHECK-SAME: indexing_maps =
208214
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
@@ -250,13 +256,29 @@ util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tens
250256
// CHECK-SAME: %[[ARG3:.+]]: f16
251257
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?x?xf16>)
252258
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
259+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
253260
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
254-
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]]
255-
// CHECK-DAG: %[[SPLIT:.+]] = arith.divsi %[[DIM]], %[[C2]]
256-
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
257-
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
258-
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
259-
// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT]],
261+
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
262+
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
263+
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
264+
// CHECK-DAG: %[[D4:.+]] = tensor.dim %[[ARG2]], %[[C2]]
265+
// CHECK-DAG: %[[SPLIT0:.+]] = arith.divsi %[[D0]]
266+
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SPLIT0]], %[[D1]], %[[D4]]) : tensor<2x?x?x?xf16>
267+
// CHECK-DAG: %[[QUERY:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT0]], %[[D1]], %[[D2]]]
268+
// CHECK-DAG: %[[D5:.+]] = tensor.dim %[[ARG1]], %[[C0]]
269+
// CHECK-DAG: %[[D6:.+]] = tensor.dim %[[ARG1]], %[[C1]]
270+
// CHECK-DAG: %[[D7:.+]] = tensor.dim %[[ARG1]], %[[C2]]
271+
// CHECK-DAG: %[[SPLIT1:.+]] = arith.divsi %[[D5]], %[[C2]]
272+
// CHECK-DAG: %[[KEY:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT1]], %[[D6]], %[[D7]]]
273+
// CHECK-DAG: %[[D8:.+]] = tensor.dim %[[ARG2]], %[[C0]]
274+
// CHECK-DAG: %[[D9:.+]] = tensor.dim %[[ARG2]], %[[C1]]
275+
// CHECK-DAG: %[[SPLIT2:.+]] = arith.divsi %[[D8]], %[[C2]]
276+
// CHECK-DAG: %[[CACHE:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT2]], %[[D9]], %[[D4]]]
277+
// CHECK-DAG: %[[D10:.+]] = tensor.dim %[[ARG4]], %[[C0]]
278+
// CHECK-DAG: %[[D11:.+]] = tensor.dim %[[ARG4]], %[[C1]]
279+
// CHECK-DAG: %[[D12:.+]] = tensor.dim %[[ARG4]], %[[C2]]
280+
// CHECK-DAG: %[[SPLIT3:.+]] = arith.divsi %[[D10]], %[[C2]]
281+
// CHECK-DAG: %[[MASK:.+]] = tensor.expand_shape %[[ARG4]] {{\[}}[0, 1], [2], [3]{{\]}} output_shape [2, %[[SPLIT3]], %[[D11]], %[[D12]]]
260282
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
261283
// CHECK-SAME: indexing_maps =
262284
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
@@ -266,6 +288,7 @@ util.func public @attention_dynamic_masked(%arg0: tensor<?x?x?xf16>, %arg1: tens
266288
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
267289
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
268290
// CHECK-SAME: ins(%[[QUERY]], %[[KEY]], %[[CACHE]], %[[ARG3]], %[[MASK]] :
291+
// CHECK-SAME: outs(%[[EMPTY]] :
269292
// CHECK: util.return %[[ATTENTION]]
270293

271294
// -----
@@ -687,23 +710,6 @@ util.func public @scatter_collapse_noop(%arg0: tensor<10xf16>, %arg1: tensor<10x
687710

688711
// -----
689712

690-
util.func public @scatter_collapse_multiple_dynamic(%arg0 : tensor<?x?x4x32x32xf16>, %arg1 : tensor<?xi64>, %arg2 : tensor<?x4x32x32xf16>) -> (tensor<?x4x32x32xf16>){
691-
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4]] : tensor<?x?x4x32x32xf16> into tensor<?x4x32x32xf16>
692-
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x4x32x32xf16>, tensor<?xi64>) outs(%arg2 : tensor<?x4x32x32xf16>) {
693-
^bb0(%arg7: f16, %arg8: f16):
694-
iree_linalg_ext.yield %arg7 : f16
695-
} -> tensor<?x4x32x32xf16>
696-
util.return %0 : tensor<?x4x32x32xf16>
697-
}
698-
// CHECK-LABEL: util.func public @scatter_collapse_multiple_dynamic
699-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]:
700-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]:
701-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]:
702-
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
703-
// CHECK-SAME: ins({{.*}} : tensor<?x?x4x32x32xf16>, tensor<?x?xi64>
704-
705-
// -----
706-
707713
util.func public @gather_expand(%arg0: tensor<100x128xf16>, %arg1: tensor<10xi32>, %arg2: tensor<10x128xf16>) -> tensor<2x5x4x32xf16> {
708714
%c0 = arith.constant 0 : index
709715
%0 = iree_linalg_ext.gather dimension_map = [0] ins(%arg0, %arg1 : tensor<100x128xf16>, tensor<10xi32>) outs(%arg2 : tensor<10x128xf16>) -> tensor<10x128xf16>

0 commit comments

Comments
 (0)