Skip to content

Commit d7d67e9

Browse files
[mlir][Transforms] Add a PadTilingInterface transformation and hook it up to the transform dialect
This revision revisits the padding transformation from first principles and prepares it to work more generally with TilingInterface. Compared to structured.transform.pad it has the following differences: - no support for nofold, copy-back, transpose and hoisting: these have been carried by the padding op in the very early days of StructuredOps and have since then been separated out as independent transformations that compose. - no conflated static bounding box derivation attempts: pad_tiling_interface composes more naturally with or without tiling. - properly derives padding size on outputs where multiple dimensions contribute: this is not supported in structured.transform.pad - geared towards supporting TilingInterface once the proper control mechanisms are supported through a PadSizeComputationFunction (supports LinalgOp by default) This will gradually replace structured.transform.pad as it is fleshed out and tested more comprehensively.
1 parent 8631b4f commit d7d67e9

File tree

6 files changed

+709
-3
lines changed

6 files changed

+709
-3
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,82 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
11861186
}];
11871187
}
11881188

1189+
//===----------------------------------------------------------------------===//
1190+
// PadTilingInterfaceOp
1191+
//===----------------------------------------------------------------------===//
1192+
1193+
def PadTilingInterfaceOp : Op<Transform_Dialect, "structured.pad_tiling_interface",
1194+
[FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1195+
TransformOpInterface,
1196+
ReportTrackingListenerFailuresOpTrait]> {
1197+
let description = [{
1198+
Pads the operations pointed to by the target handle using the options
1199+
provided as operation attributes. The operation returns a handle to the
1200+
padded operation and to the padding operation ("tensor.pad").
1201+
1202+
#### Return modes
1203+
1204+
This operation ignores non-Linalg ops and drops them in the return.
1205+
In the future, this operation will support all TilingInterfaceOps.
1206+
1207+
This operation may produce a definite failure if the padding fails for any
1208+
reason.
1209+
1210+
If all the operations referred to by the `target` handle pad properly, the
1211+
transform succeeds. Otherwise the transform produces a silenceable failure.
1212+
The return handle points to only the subset of successfully produced
1213+
padded operations, which can be empty.
1214+
}];
1215+
1216+
let arguments =
1217+
(ins TransformHandleTypeInterface:$target,
1218+
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
1219+
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
1220+
Variadic<TransformAnyParamTypeOrAnyHandle>:$padding_sizes,
1221+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
1222+
$static_padding_sizes,
1223+
DefaultValuedAttr<UnitAttr, "false">:$pad_to_multiple_of);
1224+
let results = (outs TransformHandleTypeInterface:$padded,
1225+
TransformHandleTypeInterface:$pad);
1226+
1227+
let assemblyFormat = [{
1228+
$target
1229+
`to`
1230+
(`padding_sizes` custom<DynamicIndexList>($padding_sizes, $static_padding_sizes)^)?
1231+
(`pad_to_multiple_of` $pad_to_multiple_of^)?
1232+
attr-dict
1233+
`:` functional-type(operands, results)
1234+
}];
1235+
1236+
let hasVerifier = 1;
1237+
1238+
let builders = [
1239+
// Builder for a transform::PadTilingInterfaceOp with automatic inference of padding
1240+
// value. Warning: this will set the value 0 for the inferred elemental
1241+
// type without taking the op into account and thus only work for the
1242+
// add/mul ring at the moment.
1243+
// TODO: support other operations (e.g. min, max etc).
1244+
OpBuilder<(ins "Value":$target,
1245+
"ArrayRef<int64_t>":$paddingDimensions,
1246+
CArg<"ArrayRef<int64_t>", "{}">:$staticPaddingSizes,
1247+
CArg<"bool", "false">:$padToMultipleOf)>,
1248+
OpBuilder<(ins "Value":$target,
1249+
"ArrayRef<int64_t>":$paddingDimensions,
1250+
"ArrayRef<OpFoldResult>":$mixedPadPaddingSizes,
1251+
CArg<"bool", "false">:$usePrescribedTensorShapes)>
1252+
];
1253+
1254+
let extraClassDeclaration = [{
1255+
/// Returns a mix of dynamic `padding_sizes` and static `static_padding_sizes`.
1256+
SmallVector<OpFoldResult> getMixedPaddingSizes();
1257+
1258+
::mlir::DiagnosedSilenceableFailure apply(
1259+
::mlir::transform::TransformRewriter &rewriter,
1260+
::mlir::transform::TransformResults &results,
1261+
::mlir::transform::TransformState &state);
1262+
}];
1263+
}
1264+
11891265
//===----------------------------------------------------------------------===//
11901266
// HoistPadOp
11911267
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2121
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2222
#include "mlir/Dialect/X86Vector/Transforms.h"
23+
#include "mlir/IR/OpDefinition.h"
2324
#include "mlir/IR/PatternMatch.h"
2425
#include "mlir/Interfaces/TilingInterface.h"
2526
#include "mlir/Transforms/DialectConversion.h"
@@ -347,6 +348,34 @@ struct LinalgPaddingOptions {
347348
}
348349
};
349350

351+
struct PadTilingInterfaceOptions {
352+
/// A padding value for every operand.
353+
SmallVector<Attribute> paddingValues;
354+
PadTilingInterfaceOptions &setPaddingValues(ArrayRef<Attribute> pv) {
355+
paddingValues.assign(pv.begin(), pv.end());
356+
return *this;
357+
}
358+
/// A list of iterator dimensions to pad.
359+
SmallVector<int64_t> paddingDimensions;
360+
PadTilingInterfaceOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
361+
paddingDimensions.assign(pd.begin(), pd.end());
362+
return *this;
363+
}
364+
/// A list of iterator dimensions sizes to pad to.
365+
SmallVector<OpFoldResult> paddingSizes;
366+
PadTilingInterfaceOptions &setPaddingSizes(ArrayRef<OpFoldResult> m) {
367+
paddingSizes.assign(m.begin(), m.end());
368+
return *this;
369+
}
370+
/// Pad iterator `paddingDimension[i]` to the next multiple of `paddingSizes[i]`
371+
/// if true. Otherwise pad to `paddingSizes[i]`.
372+
bool padToMultipleOf;
373+
PadTilingInterfaceOptions &setPadToMultipleOf(bool b) {
374+
padToMultipleOf = b;
375+
return *this;
376+
}
377+
};
378+
350379
/// Callback function type used to perform the allocation for the promoted
351380
/// `subView`. In `boundingSubViewsize` a best attempt is made to find the
352381
/// smallest constant value for the size of the buffer needed for each
@@ -542,9 +571,9 @@ SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op);
542571
/// where relevant.
543572
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
544573

545-
/// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
546-
/// to a static bounding box. The original `opToPad` is cloned and operates on
547-
/// the padded tensors.
574+
/// Pad the iterator dimensions `options.paddingDimensions` of all `opToPad`
575+
/// operands to a static bounding box. The original `opToPad` is cloned and
576+
/// operates on the padded tensors.
548577
///
549578
/// * "options.padToMultipleOf" indicates that each padding dimension should be
550579
/// padded to the specified multiple.
@@ -561,6 +590,50 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
561590
SmallVector<Value> &replacements,
562591
SmallVector<tensor::PadOp> &padOps);
563592

593+
/// Helper function to compute the padded shape of the given value `v` of
594+
/// `RankedTensorType` given:
595+
/// - the `indexingSizes` as a list of OpFoldResult.
596+
/// - an `indexingMap` that encodes how the padded shape varies with
597+
/// increases in `indexingSizes`.
598+
/// The implementation iteratively combines increases from contributing dimensions using
599+
/// affine.apply operations.
600+
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps and
601+
/// provides a gentle portability path for Linalg-like ops with affine maps.
602+
/// In the future, more general interfaces can be devised to encode similar
603+
/// shape evolutions and map between an op and its operands.
604+
SmallVector<OpFoldResult>
605+
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
606+
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
607+
const PadTilingInterfaceOptions &options);
608+
609+
using PadSizeComputationFunction =
610+
std::function<FailureOr<SmallVector<OpFoldResult>>(
611+
RewriterBase &, OpOperand &, ArrayRef<Range>,
612+
const PadTilingInterfaceOptions &)>;
613+
614+
/// Specific helper for Linalg ops.
615+
FailureOr<SmallVector<OpFoldResult>>
616+
computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
617+
ArrayRef<Range> iterationDomain,
618+
const PadTilingInterfaceOptions &options);
619+
620+
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
621+
///
622+
/// * "options.paddingSizes" indicates that each padding dimension should be
623+
/// padded to the specified padding size.
624+
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
625+
/// interpreted as the bounding box (dynamic) value to pad to.
626+
/// * Use "options.paddingValues" to set the padding value of the created
627+
/// tensor::PadOp.
628+
/// * The tensor::PadOp is returned on success.
629+
630+
FailureOr<TilingInterface>
631+
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
632+
const PadTilingInterfaceOptions &constOptions,
633+
SmallVector<tensor::PadOp> &padOps,
634+
PadSizeComputationFunction computePaddingSizeFun =
635+
&computeLinalgPaddedShape);
636+
564637
namespace detail {
565638

566639
/// Helper struct to hold the results of building a packing loop nest.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/ADT/ScopeExit.h"
4646
#include "llvm/ADT/TypeSwitch.h"
4747
#include "llvm/Support/Debug.h"
48+
#include "llvm/Support/LogicalResult.h"
4849
#include <type_traits>
4950

5051
using namespace mlir;
@@ -2155,6 +2156,166 @@ LogicalResult transform::PadOp::verify() {
21552156
return success();
21562157
}
21572158

2159+
//===---------------------------------------------------------------------===//
2160+
// PadTilingInterfaceOp
2161+
//===---------------------------------------------------------------------===//
2162+
2163+
void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2164+
OperationState &result,
2165+
Value target,
2166+
ArrayRef<int64_t> paddingDimensions,
2167+
ArrayRef<int64_t> paddingSizes,
2168+
bool padToMultipleOf) {
2169+
auto resultType = transform::AnyOpType::get(b.getContext());
2170+
return build(/*builder=*/b,
2171+
/*result=*/result,
2172+
/*types=*/TypeRange{resultType, resultType},
2173+
/*target=*/target,
2174+
/*paddingValues=*/ArrayAttr(), // let inference handle this
2175+
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2176+
/*paddingSizes=*/ValueRange{},
2177+
/*paddingSizes=*/
2178+
(paddingSizes.empty() ? DenseI64ArrayAttr()
2179+
: b.getDenseI64ArrayAttr(paddingSizes)),
2180+
/*padToMultipleOf=*/
2181+
padToMultipleOf ? b.getUnitAttr() : nullptr);
2182+
}
2183+
2184+
void transform::PadTilingInterfaceOp::build(
2185+
OpBuilder &b, OperationState &result, Value target,
2186+
ArrayRef<int64_t> paddingDimensions,
2187+
ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2188+
auto resultType = transform::AnyOpType::get(b.getContext());
2189+
SmallVector<int64_t> staticPaddingSizes;
2190+
SmallVector<Value> dynamicPaddingSizes;
2191+
dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2192+
staticPaddingSizes);
2193+
return build(/*builder=*/b,
2194+
/*result=*/result,
2195+
/*types=*/TypeRange{resultType, resultType},
2196+
/*target=*/target,
2197+
/*paddingValues=*/ArrayAttr(), // let inference handle this
2198+
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2199+
/*paddingSizes=*/dynamicPaddingSizes,
2200+
/*paddingSizes=*/staticPaddingSizes,
2201+
/*padToMultipleOf=*/padToMultipleOf);
2202+
}
2203+
2204+
void transform::PadTilingInterfaceOp::getEffects(
2205+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2206+
consumesHandle(getTargetMutable(), effects);
2207+
onlyReadsHandle(getPaddingSizesMutable(), effects);
2208+
producesHandle(getOperation()->getOpResults(), effects);
2209+
modifiesPayload(effects);
2210+
}
2211+
2212+
SmallVector<OpFoldResult>
2213+
transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2214+
Builder b(getContext());
2215+
return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2216+
}
2217+
2218+
DiagnosedSilenceableFailure
2219+
transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2220+
transform::TransformResults &results,
2221+
transform::TransformState &state) {
2222+
SmallVector<Operation *> paddedOps, padOps;
2223+
2224+
for (Operation *target : state.getPayloadOps(getTarget())) {
2225+
auto targetOp = dyn_cast<TilingInterface>(target);
2226+
if (!targetOp) {
2227+
auto diag = emitSilenceableError() << "expected TilingInterface target";
2228+
diag.attachNote(target->getLoc()) << "target op";
2229+
return diag;
2230+
}
2231+
2232+
// Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
2233+
// map / C++ APIs to compute the effect of padding on operands.
2234+
if (!isa<LinalgOp>(targetOp.getOperation())) {
2235+
auto diag = emitSilenceableError() << "only LinalgOp supported atm";
2236+
diag.attachNote(target->getLoc()) << "target op";
2237+
return diag;
2238+
}
2239+
2240+
// Convert the padding values to attributes.
2241+
SmallVector<Attribute> paddingValues;
2242+
for (auto const &it :
2243+
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2244+
auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
2245+
if (!attr) {
2246+
emitOpError("expects padding values to be typed attributes");
2247+
return DiagnosedSilenceableFailure::definiteFailure();
2248+
}
2249+
Type elementType = getElementTypeOrSelf(std::get<1>(it));
2250+
// Try to parse string attributes to obtain an attribute of element type.
2251+
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2252+
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2253+
stringAttr, getContext(), elementType,
2254+
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2255+
if (!parsedAttr || parsedAttr.getType() != elementType) {
2256+
auto diag = this->emitOpError("expects a padding that parses to ")
2257+
<< elementType << ", got " << std::get<0>(it);
2258+
diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2259+
return DiagnosedSilenceableFailure::definiteFailure();
2260+
}
2261+
paddingValues.push_back(parsedAttr);
2262+
continue;
2263+
}
2264+
// Otherwise, add the attribute directly.
2265+
if (attr.getType() != elementType) {
2266+
auto diag = this->emitOpError("expects a padding value of type ")
2267+
<< elementType << ", got " << attr;
2268+
diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2269+
return DiagnosedSilenceableFailure::definiteFailure();
2270+
}
2271+
paddingValues.push_back(attr);
2272+
}
2273+
2274+
// Set options.
2275+
TilingInterface paddedOp;
2276+
PadTilingInterfaceOptions options;
2277+
options.setPaddingValues(paddingValues)
2278+
.setPaddingDimensions(
2279+
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
2280+
.setPaddingSizes(getMixedPaddingSizes())
2281+
.setPadToMultipleOf(getPadToMultipleOf());
2282+
2283+
// Apply padding.
2284+
SmallVector<tensor::PadOp> newPadOps;
2285+
FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2286+
rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2287+
newPadOps);
2288+
if (failed(maybePaddedOp)) {
2289+
auto diag = emitSilenceableError() << "failed to pad op";
2290+
diag.attachNote(target->getLoc()) << "target op";
2291+
return diag;
2292+
}
2293+
2294+
// Set transform results.
2295+
paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2296+
padOps.append(newPadOps.begin(), newPadOps.end());
2297+
}
2298+
2299+
results.set(cast<OpResult>(getPadded()), paddedOps);
2300+
results.set(cast<OpResult>(getPad()), padOps);
2301+
return DiagnosedSilenceableFailure::success();
2302+
}
2303+
2304+
LogicalResult transform::PadTilingInterfaceOp::verify() {
2305+
SmallVector<int64_t> paddingDimensions =
2306+
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2307+
if (any_of(paddingDimensions,
2308+
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
2309+
return emitOpError() << "expects padding_dimensions to contain non-negative "
2310+
"integers, found "
2311+
<< getPaddingDimensions();
2312+
}
2313+
if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
2314+
return emitOpError() << "expects as many multiples as padding_dimensions";
2315+
}
2316+
return success();
2317+
}
2318+
21582319
//===---------------------------------------------------------------------===//
21592320
// HoistPadOp
21602321
//===---------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2929
BlockPackMatmul.cpp
3030
PackAndUnpackPatterns.cpp
3131
Padding.cpp
32+
PadTilingInterface.cpp
3233
Promotion.cpp
3334
RuntimeOpVerification.cpp
3435
Specialize.cpp

0 commit comments

Comments
 (0)