Skip to content

Commit 75d320c

Browse files
Address comments (round 1).
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 35bc469 commit 75d320c

File tree

7 files changed

+66
-33
lines changed

7 files changed

+66
-33
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace tensor {
3131
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
3232
OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
3333

34-
/// Method to swap an `tensor.insert_slice`s with its consumer when the
34+
/// Method to swap `tensor.insert_slice`s with their consumers when the
3535
/// consumer implements the `TilingInterface`. The size of `sliceOps` and
3636
/// `consumerOperands` is expected to be the same. Every entry in
3737
/// `consumerOperands` represents a use of the the corresponding

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
202202
InterfaceMethod<
203203
/*desc=*/[{
204204
Method to generate the tiled implementation of an operation that uses
205-
exactly tiles of the given operands.
205+
the exact tiles of the given operands.
206206

207207
This method is required to allow operations to be "tiled and fused"
208208
with an (already tiled) producer. Given tiles of the producer, this

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@
2222
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
2424
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
25+
#include "llvm/Support/Debug.h"
2526
#include <optional>
2627

28+
#define DEBUG_TYPE "linalg-tiling-interface-impl"
29+
2730
using namespace mlir;
2831
using namespace mlir::linalg;
2932

@@ -170,9 +173,11 @@ struct LinalgOpTilingInterface
170173
OpFoldResult seenOffset = it->second;
171174
OpFoldResult seenSize = mappedSizes.lookup(position);
172175
if (seenOffset != offset || seenSize != size) {
173-
return linalgOp->emitOpError(
174-
"inconsistent iteration space mapping from offsets/sizes of "
175-
"operands/results");
176+
LLVM_DEBUG({
177+
llvm::dbgs() << "inconsistent iteration space mapping from "
178+
"offsets/sizes of operands/results";
179+
});
180+
return failure();
176181
}
177182
} else {
178183
mappedOffsets[position] = offset;
@@ -874,8 +879,11 @@ struct PackOpTiling
874879
ArrayRef<SmallVector<OpFoldResult>> allSizes,
875880
SmallVectorImpl<OpFoldResult> &resultOffsets,
876881
SmallVectorImpl<OpFoldResult> &resultSizes) const {
877-
if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
878-
return op->emitOpError("unsupporeted operands for consumer fusion");
882+
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
883+
LLVM_DEBUG(
884+
{ llvm::dbgs() << "unsupported operands for consumer fusion"; });
885+
return failure();
886+
}
879887

880888
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
881889
ArrayRef<OpFoldResult> sizes(allSizes[0]);
@@ -943,8 +951,11 @@ struct PackOpTiling
943951
Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
944952
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
945953
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
946-
if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
947-
return op->emitOpError("unhandled operands for consumer fusion");
954+
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
955+
LLVM_DEBUG(
956+
{ llvm ::dbgs() << "unhandled operands for consumer fusion"; });
957+
return failure();
958+
}
948959

949960
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
950961
ArrayRef<OpFoldResult> sizes(allSizes[0]);
@@ -1228,7 +1239,8 @@ struct UnPackOpTiling
12281239
SmallVectorImpl<OpFoldResult> &resultOffsets,
12291240
SmallVectorImpl<OpFoldResult> &resultSizes) const {
12301241
if (operandNumbers.size() != 1) {
1231-
return op->emitOpError("unable to handle multiple operands");
1242+
LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
1243+
return failure();
12321244
}
12331245
auto unPackOp = cast<UnPackOp>(op);
12341246
unsigned operandNumber = operandNumbers[0];
@@ -1293,7 +1305,8 @@ struct UnPackOpTiling
12931305
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
12941306
ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
12951307
if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1296-
return op->emitOpError("unhandled operands for consumer fusion");
1308+
LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
1309+
return failure();
12971310
}
12981311
auto unPackOp = cast<UnPackOp>(op);
12991312
ArrayRef<OpFoldResult> offsets(allOffsets[0]);

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2060,15 +2060,16 @@ static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
20602060
[&](auto op) {
20612061
return getUntiledConsumerFromSlice(rewriter, op, loops);
20622062
})
2063-
.Default([](Operation *op) {
2064-
return op->emitOpError("unhandled slice type");
2063+
.Default([&](Operation *op) {
2064+
return rewriter.notifyMatchFailure(op, "unhandled slice type");
20652065
});
20662066
if (failed(fusedOperand)) {
20672067
return failure();
20682068
}
20692069
if (!fusedOperands.empty() &&
20702070
fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2071-
return fusedOperands.front()->getOwner()->emitOpError(
2071+
return rewriter.notifyMatchFailure(
2072+
fusedOperand.value()->getOwner(),
20722073
"all candidate slices must be to the same consumer");
20732074
}
20742075
fusedOperands.push_back(fusedOperand.value());
@@ -2125,23 +2126,23 @@ mlir::scf::tileAndFuseConsumerOfSlices(
21252126
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
21262127
MutableArrayRef<LoopLikeOpInterface> loops) {
21272128
if (candidateSlices.empty()) {
2128-
return emitError(rewriter.getUnknownLoc(),
2129-
"no candidate slices provided for consumer fusion");
2129+
return rewriter.notifyMatchFailure(
2130+
rewriter.getUnknownLoc(),
2131+
"no candidate slices provided for consumer fusion");
21302132
}
21312133
// Return if `loops` is empty, return an error for now. Caller is expected
21322134
// to handle this case.
21332135
if (loops.empty()) {
2134-
return candidateSlices.front()->emitOpError(
2136+
return rewriter.notifyMatchFailure(
2137+
candidateSlices.front(),
21352138
"cannot call tile and fuse consumer with an empty loop nest");
21362139
}
21372140

2138-
if (!(llvm::all_of(
2139-
candidateSlices,
2140-
[](Operation *op) { return isa<tensor::InsertSliceOp>(op); }) ||
2141-
llvm::all_of(candidateSlices, [](Operation *op) {
2142-
return isa<tensor::ParallelInsertSliceOp>(op);
2143-
}))) {
2144-
return candidateSlices.front()->emitOpError(
2141+
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2142+
llvm::all_of(candidateSlices,
2143+
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2144+
return rewriter.notifyMatchFailure(
2145+
candidateSlices.front(),
21452146
"candidates slices need to be all `tensor.extract_slice`s or "
21462147
"`tensor.parallel_insert_slice`s");
21472148
}
@@ -2261,8 +2262,14 @@ mlir::scf::tileAndFuseConsumerOfSlices(
22612262
for (auto candidateSliceOp : clonedInsertSlices) {
22622263
SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
22632264
SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
2265+
SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
22642266

22652267
// 9. Check all insert stride is 1.
2268+
if (!llvm::all_of(strides, isOneInteger)) {
2269+
return rewriter.notifyMatchFailure(
2270+
candidateSliceOp, "containingOp's result yield with stride");
2271+
}
2272+
22662273
allOffsets.emplace_back(std::move(offsets));
22672274
allSizes.emplace_back(std::move(sizes));
22682275
}

mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1818
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1919
#include "mlir/Interfaces/TilingInterface.h"
20+
#include "llvm/Support/Debug.h"
21+
22+
#define DEBUG_TYPE "tensor-swap-slices"
2023

2124
using namespace mlir;
2225

@@ -43,21 +46,28 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
4346
OpBuilder &builder, ArrayRef<tensor::InsertSliceOp> sliceOps,
4447
ArrayRef<OpOperand *> consumerOperands) {
4548
if (sliceOps.empty()) {
46-
return emitError(builder.getUnknownLoc(),
47-
"expected candidate slices list to be non-empty");
49+
LLVM_DEBUG(
50+
{ llvm::dbgs() << "expected candidate slices list to be non-empty"; });
51+
return failure();
4852
}
4953
if (sliceOps.size() != consumerOperands.size()) {
50-
return sliceOps.front()->emitOpError(
51-
"expected as many operands as the number of slices passed");
54+
LLVM_DEBUG({
55+
llvm::dbgs()
56+
<< "expected as many operands as the number of slices passed";
57+
});
58+
return failure();
5259
}
5360
auto consumerOp =
5461
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
5562
if (!consumerOp)
5663
return failure();
5764
for (auto opOperand : consumerOperands.drop_front()) {
5865
if (opOperand->getOwner() != consumerOp) {
59-
return consumerOp->emitOpError(
60-
"expected all consumer operands to be from the same operation");
66+
LLVM_DEBUG({
67+
llvm::dbgs()
68+
<< "expected all consumer operands to be from the same operation";
69+
});
70+
return failure();
6171
}
6272
}
6373

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ module attributes {transform.with_named_sequence} {
688688

689689
// -----
690690

691-
// Check that when the given operand tiles are incosistent, tiling fails.
691+
// Check that when the given operand tiles are inconsistent, tiling fails.
692692

693693
func.func @multi_slice_fusion2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>, %arg3 : index) -> tensor<?xf32> {
694694
%c0 = arith.constant 0 : index
@@ -881,14 +881,14 @@ func.func @multi_slice_fusion_invalid(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<
881881
linalg.yield %0: f32
882882
} -> tensor<?x?xf32>
883883
scf.forall.in_parallel {
884+
// expected-error @below {{failed to fuse consumer of slice}}
884885
tensor.parallel_insert_slice %generic0 into %init0[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
885886
: tensor<?x?xf32> into tensor<?x?xf32>
886887
tensor.parallel_insert_slice %generic1 into %init1[%iv0, %iv1] [%tilesize0, %tilesize1] [1, 1]
887888
: tensor<?x?xf32> into tensor<?x?xf32>
888889
}
889890
}
890891
%empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
891-
// expected-error @below {{inconsistent iteration space mapping from offsets/sizes of operands/results}}
892892
%result = linalg.generic {
893893
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
894894
iterator_types = ["parallel", "parallel"]}

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include "mlir/IR/Dominance.h"
2222
#include "mlir/IR/OpImplementation.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
24+
#include "llvm/Support/Debug.h"
25+
26+
#define DEBUG_TYPE "test-tiling-interface"
2427

2528
#define GET_OP_CLASSES
2629
#include "TestTilingInterfaceTransformOps.h.inc"
@@ -182,7 +185,7 @@ static LogicalResult applyFuseConsumer(
182185
scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops);
183186

184187
if (failed(fuseConsumerResults))
185-
return failure();
188+
return slices.front()->emitOpError("failed to fuse consumer of slice");
186189

187190
// Report back the relevant handles to the transform op.
188191
for (OpOperand *origConsumerOperand :

0 commit comments

Comments
 (0)