Skip to content

Commit e1cde0d

Browse files
[mlir][transform] NFC - extract a minimal DomainAndOperandsAffineMapTransferInterface out of LinalgStructuredInterface and use that for PadTilingInterface
1 parent 6edf2eb commit e1cde0d

File tree

4 files changed

+75
-54
lines changed

4 files changed

+75
-54
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,60 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
222222
];
223223
}
224224

225+
def DomainAndOperandsAffineMapTransferInterface
226+
: OpInterface<"DomainAndOperandsAffineMapTransferInterface"> {
227+
let description = [{
228+
Interface for operations that connect an iteration domain to operands via
229+
affine maps. Provides methods to access indexing maps between iteration
230+
domain and operand index spaces.
231+
}];
232+
let cppNamespace = "::mlir::linalg";
233+
let methods = [
234+
InterfaceMethod<
235+
/*desc=*/[{
236+
Return the indexing maps attribute within the current operation.
237+
}],
238+
/*retTy=*/"ArrayAttr",
239+
/*methodName=*/"getIndexingMaps"
240+
>,
241+
InterfaceMethod<
242+
/*desc=*/[{
243+
Return the indexing maps within the current operation.
244+
}],
245+
/*retTy=*/"SmallVector<AffineMap>",
246+
/*methodName=*/"getIndexingMapsArray",
247+
/*args=*/(ins),
248+
/*methodBody=*/"",
249+
/*defaultImplementation=*/[{
250+
auto range = $_op.getIndexingMaps()
251+
.template getAsValueRange<AffineMapAttr>();
252+
return {range.begin(), range.end()};
253+
}]
254+
>,
255+
InterfaceMethod<
256+
/*desc=*/[{
257+
Return the input or output indexing map for `opOperand`.
258+
}],
259+
/*retTy=*/"AffineMap",
260+
/*methodName=*/"getMatchingIndexingMap",
261+
/*args=*/(ins "OpOperand*":$opOperand),
262+
/*methodBody=*/"",
263+
/*defaultImplementation=*/[{
264+
assert(opOperand->getOwner() == this->getOperation());
265+
auto indexingMaps =
266+
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
267+
return *(indexingMaps.begin() + opOperand->getOperandNumber());
268+
}]
269+
>,
270+
];
271+
}
272+
225273
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
226274
def LinalgStructuredInterface
227-
: OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
275+
: OpInterface<"LinalgOp", [
276+
DestinationStyleOpInterface,
277+
DomainAndOperandsAffineMapTransferInterface
278+
]> {
228279
let cppNamespace = "::mlir::linalg";
229280
let methods = [
230281
//===------------------------------------------------------------------===//
@@ -465,21 +516,6 @@ def LinalgStructuredInterface
465516
blockArgument.getArgNumber());
466517
}]
467518
>,
468-
InterfaceMethod<
469-
/*desc=*/[{
470-
Return the input or output indexing map for `opOperand`.
471-
}],
472-
/*retTy=*/"AffineMap",
473-
/*methodName=*/"getMatchingIndexingMap",
474-
/*args=*/(ins "OpOperand*":$opOperand),
475-
/*methodBody=*/"",
476-
/*defaultImplementation=*/[{
477-
assert(opOperand->getOwner() == this->getOperation());
478-
auto indexingMaps =
479-
$_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
480-
return *(indexingMaps.begin() + opOperand->getOperandNumber());
481-
}]
482-
>,
483519
InterfaceMethod<
484520
/*desc=*/[{
485521
Return the indexing map for a `result`.
@@ -576,27 +612,6 @@ def LinalgStructuredInterface
576612
/*methodBody=*/"",
577613
/*defaultImplementation=*/[{ return success(); }]
578614
>,
579-
InterfaceMethod<
580-
/*desc=*/[{
581-
Return the indexing maps attribute within the current operation.
582-
}],
583-
/*retTy=*/"ArrayAttr",
584-
/*methodName=*/"getIndexingMaps"
585-
>,
586-
InterfaceMethod<
587-
/*desc=*/[{
588-
Return the indexing maps within the current operation.
589-
}],
590-
/*retTy=*/"SmallVector<AffineMap>",
591-
/*methodName=*/"getIndexingMapsArray",
592-
/*args=*/(ins),
593-
/*methodBody=*/"",
594-
/*defaultImplementation=*/[{
595-
auto range = $_op.getIndexingMaps()
596-
.template getAsValueRange<AffineMapAttr>();
597-
return {range.begin(), range.end()};
598-
}]
599-
>,
600615
InterfaceMethod<
601616
/*desc=*/[{
602617
Return true if any of the operands has a dynamic shape.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,9 @@ using PadSizeComputationFunction =
613613

614614
/// Specific helper for Linalg ops.
615615
FailureOr<SmallVector<OpFoldResult>>
616-
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
617-
ArrayRef<Range> iterationDomain,
618-
const PadTilingInterfaceOptions &options);
616+
computeDomainAndOperandsAffineMapTransferInterfacePaddedShape(
617+
RewriterBase &rewriter, OpOperand &operandToPad,
618+
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
619619

620620
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
621621
///
@@ -627,12 +627,12 @@ computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
627627
// tensor::PadOp.
628628
/// * The tensor::PadOp is returned on success.
629629

630-
FailureOr<TilingInterface>
631-
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
632-
const PadTilingInterfaceOptions &constOptions,
633-
SmallVector<tensor::PadOp> &padOps,
634-
PadSizeComputationFunction computePaddingSizeFun =
635-
&computeLinalgPaddedShape);
630+
FailureOr<TilingInterface> rewriteAsPaddedOp(
631+
RewriterBase &rewriter, TilingInterface opToPad,
632+
const PadTilingInterfaceOptions &constOptions,
633+
SmallVector<tensor::PadOp> &padOps,
634+
PadSizeComputationFunction computePaddingSizeFun =
635+
&computeDomainAndOperandsAffineMapTransferInterfacePaddedShape);
636636

637637
namespace detail {
638638

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,10 +2229,14 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22292229
return diag;
22302230
}
22312231

2232-
// Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
2233-
// map / C++ APIs to compute the effect of padding on operands.
2234-
if (!isa<LinalgOp>(targetOp.getOperation())) {
2235-
auto diag = emitSilenceableError() << "only LinalgOp supported atm";
2232+
// Only DomainAndOperandsAffineMapTransferInterface ops for now, until
2233+
// TilingInterface exposes a loopsToOperand map / C++ APIs to compute the
2234+
// effect of padding on operands.
2235+
if (!isa<DomainAndOperandsAffineMapTransferInterface>(
2236+
targetOp.getOperation())) {
2237+
auto diag = emitSilenceableError()
2238+
<< "only DomainAndOperandsAffineMapTransferInterface ops "
2239+
"supported atm";
22362240
diag.attachNote(target->getLoc()) << "target op";
22372241
return diag;
22382242
}

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
155155
return paddedShape;
156156
}
157157

158-
FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
158+
FailureOr<SmallVector<OpFoldResult>>
159+
linalg::computeDomainAndOperandsAffineMapTransferInterfacePaddedShape(
159160
RewriterBase &rewriter, OpOperand &operandToPad,
160161
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
161-
auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
162-
if (!linalgOp)
162+
auto transferOp = llvm::dyn_cast<DomainAndOperandsAffineMapTransferInterface>(
163+
operandToPad.getOwner());
164+
if (!transferOp)
163165
return failure();
164166

165167
// clang-format off
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
173175
for (const Range &range : iterationDomain)
174176
loopUpperBounds.push_back(range.size);
175177

176-
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
178+
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
177179
return computePaddedShape(
178180
rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
179181
indexingMap, loopUpperBounds, options);

0 commit comments

Comments
 (0)