Skip to content

Commit 1742a95

Browse files
Rename and add a test
1 parent e1cde0d commit 1742a95

File tree

5 files changed

+63
-26
lines changed

5 files changed

+63
-26
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,7 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222222
];
223223
}
224224

225-
def DomainAndOperandsAffineMapTransferInterface
226-
: OpInterface<"DomainAndOperandsAffineMapTransferInterface"> {
225+
def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
227226
let description = [{
228227
Interface for operations that connect an iteration domain to operands via
229228
affine maps. Provides methods to access indexing maps between iteration
@@ -274,7 +273,7 @@ def DomainAndOperandsAffineMapTransferInterface
274273
def LinalgStructuredInterface
275274
: OpInterface<"LinalgOp", [
276275
DestinationStyleOpInterface,
277-
DomainAndOperandsAffineMapTransferInterface
276+
IndexingMapOpInterface
278277
]> {
279278
let cppNamespace = "::mlir::linalg";
280279
let methods = [

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -612,8 +612,7 @@ using PadSizeComputationFunction =
612612
const PadTilingInterfaceOptions &)>;
613613

614614
/// Specific helper for Linalg ops.
615-
FailureOr<SmallVector<OpFoldResult>>
616-
computeDomainAndOperandsAffineMapTransferInterfacePaddedShape(
615+
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
617616
RewriterBase &rewriter, OpOperand &operandToPad,
618617
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
619618

@@ -627,12 +626,12 @@ computeDomainAndOperandsAffineMapTransferInterfacePaddedShape(
627626
// tensor::PadOp.
628627
/// * The tensor::PadOp is returned on success.
629628

630-
FailureOr<TilingInterface> rewriteAsPaddedOp(
631-
RewriterBase &rewriter, TilingInterface opToPad,
632-
const PadTilingInterfaceOptions &constOptions,
633-
SmallVector<tensor::PadOp> &padOps,
634-
PadSizeComputationFunction computePaddingSizeFun =
635-
&computeDomainAndOperandsAffineMapTransferInterfacePaddedShape);
629+
FailureOr<TilingInterface>
630+
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
631+
const PadTilingInterfaceOptions &constOptions,
632+
SmallVector<tensor::PadOp> &padOps,
633+
PadSizeComputationFunction computePaddingSizeFun =
634+
&computeIndexingMapOpInterfacePaddedShape);
636635

637636
namespace detail {
638637

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,14 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22292229
return diag;
22302230
}
22312231

2232-
// Only DomainAndOperandsAffineMapTransferInterface ops for now, until
2233-
// TilingInterface exposes a loopsToOperand map / C++ APIs to compute the
2234-
// effect of padding on operands.
2235-
if (!isa<DomainAndOperandsAffineMapTransferInterface>(
2236-
targetOp.getOperation())) {
2237-
auto diag = emitSilenceableError()
2238-
<< "only DomainAndOperandsAffineMapTransferInterface ops "
2239-
"supported atm";
2232+
// Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2233+
// loopsToOperand map / C++ APIs to compute the effect of padding on
2234+
// operands.
2235+
if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2236+
auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2237+
"supported atm";
22402238
diag.attachNote(target->getLoc()) << "target op";
22412239
return diag;
22422240
}

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
156156
}
157157

158158
FailureOr<SmallVector<OpFoldResult>>
159-
linalg::computeDomainAndOperandsAffineMapTransferInterfacePaddedShape(
159+
linalg::computeIndexingMapOpInterfacePaddedShape(
160160
RewriterBase &rewriter, OpOperand &operandToPad,
161161
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
162-
auto transferOp = llvm::dyn_cast<DomainAndOperandsAffineMapTransferInterface>(
163-
operandToPad.getOwner());
162+
auto transferOp =
163+
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
164164
if (!transferOp)
165165
return failure();
166166

@@ -257,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
257257
SmallVector<Value> newOperands;
258258
newOperands.reserve(opToPad->getNumOperands());
259259
for (OpOperand &opOperand : opToPad->getOpOperands()) {
260-
LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
260+
Value operand = opOperand.get();
261+
LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
262+
263+
// 2.a. Skip scalar-like operands.
264+
Type operandType = operand.getType();
265+
if (!isa<RankedTensorType>(operandType)) {
266+
assert((!isa<ShapedType>(operandType) ||
267+
isa<VectorType>(operandType)) &&
268+
"Unexpected non-vector ShapedType");
269+
newOperands.push_back(operand);
270+
continue;
271+
}
261272
// 2.a. Compute padded shape.
262273
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
263274
computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
@@ -268,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
268279
// 2.b. Expect proper `paddingValues`.
269280
// TODO: we may want to allow garbage padding in the future, in which case
270281
// we would just not assert.
271-
assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
272-
"--no padding value specified");
282+
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
283+
return rewriter.notifyMatchFailure(opToPad,
284+
"--no padding value specified");
285+
}
273286
Attribute paddingValueAttr =
274287
options.paddingValues[opOperand.getOperandNumber()];
275288

276289
// 2.c. Perform actual padding.
277290
Value paddedOperand = padOperand(
278-
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
291+
rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
279292
*maybePaddedShape, paddingValueAttr);
280293
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
281294

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
22

3+
// CHECK-LABEL: pad_fill
4+
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
5+
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
6+
{
7+
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
8+
func.return %0 : tensor<24x25xf32>
9+
}
10+
11+
module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
13+
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1
14+
: (!transform.any_op) -> !transform.any_op
15+
16+
// Tile to 5 then pad to 8
17+
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
18+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
19+
20+
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
21+
padding_values=[0.0 : f32, 0.0 : f32],
22+
padding_dimensions=[0]
23+
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
24+
25+
transform.yield
26+
}
27+
}
28+
29+
// -----
30+
331
// CHECK-LABEL: pad_lhs
432
func.func @pad_lhs(
533
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)

0 commit comments

Comments
 (0)