Skip to content

Commit f2f65ed

Browse files
srcarroll and ftynse authored
[mlir][transform] Add support for transform.param pad multiples in PadOp (#90755)
This patch modifies the definition of `PadOp` to take transform params and handles for the `pad_to_multiple_of` operand. --------- Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent 4dede5e commit f2f65ed

File tree

5 files changed

+205
-66
lines changed

5 files changed

+205
-66
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -978,8 +978,8 @@ def PackTransposeOp : Op<Transform_Dialect, "structured.pack_transpose", [
978978
//===----------------------------------------------------------------------===//
979979

980980
def PadOp : Op<Transform_Dialect, "structured.pad",
981-
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
982-
DeclareOpInterfaceMethods<TransformOpInterface>,
981+
[FunctionalStyleTransformOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
982+
TransformOpInterface,
983983
ReportTrackingListenerFailuresOpTrait]> {
984984
let description = [{
985985
Pads the operations pointed to by the target handle using the options
@@ -1011,7 +1011,9 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
10111011
(ins TransformHandleTypeInterface:$target,
10121012
DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
10131013
DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
1014-
OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
1014+
Variadic<TransformAnyParamTypeOrAnyHandle>:$pad_to_multiple_of,
1015+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
1016+
$static_pad_to_multiple_of,
10151017
DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
10161018
DefaultValuedAttr<
10171019
TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
@@ -1021,8 +1023,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
10211023
TransformHandleTypeInterface:$pad,
10221024
TransformHandleTypeInterface:$copy);
10231025

1024-
let assemblyFormat =
1025-
"$target attr-dict `:` functional-type(operands, results)";
1026+
let assemblyFormat = [{
1027+
$target
1028+
(`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
1029+
attr-dict
1030+
`:` functional-type(operands, results)
1031+
}];
1032+
10261033
let hasVerifier = 1;
10271034

10281035
let builders = [
@@ -1033,7 +1040,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
10331040
// TODO: support other operations (e.g. min, max etc).
10341041
OpBuilder<(ins "Value":$target,
10351042
"ArrayRef<int64_t>":$paddingDimensions,
1036-
CArg<"ArrayRef<int64_t>", "{}">:$padToMultipleOf,
1043+
CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
1044+
CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
1045+
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
1046+
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
1047+
OpBuilder<(ins "Value":$target,
1048+
"ArrayRef<int64_t>":$paddingDimensions,
1049+
"ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
10371050
CArg<"ArrayRef<int64_t>", "{}">:$packPaddings,
10381051
CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
10391052
CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
@@ -1043,11 +1056,13 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
10431056
/// copy_back_op attribute value indicating that no copy back is desired.
10441057
static constexpr StringRef kCopyOpNone = "none";
10451058

1046-
::mlir::DiagnosedSilenceableFailure applyToOne(
1047-
::mlir::transform::TransformRewriter &rewriter,
1048-
::mlir::linalg::LinalgOp target,
1049-
::mlir::transform::ApplyToEachResultList &results,
1050-
::mlir::transform::TransformState &state);
1059+
/// Returns a mix of dynamic `pad_to_multiple_of` and static `static_pad_to_multiple_of`.
1060+
SmallVector<OpFoldResult> getMixedPadToMultipleOf();
1061+
1062+
::mlir::DiagnosedSilenceableFailure apply(
1063+
::mlir::transform::TransformRewriter &rewriter,
1064+
::mlir::transform::TransformResults &results,
1065+
::mlir::transform::TransformState &state);
10511066
}];
10521067
}
10531068

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 110 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,54 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
171171
return DiagnosedSilenceableFailure::success();
172172
}
173173

174+
/// When possible, converts each `OpFoldResult` in `mixedResult` to
175+
/// an integer if the value can be statically inferred. If a result
176+
/// is a `Value` then it must be either a `ParamType` or a handle
177+
/// to a constant-like op.
178+
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
179+
TransformState &state, TransformOpInterface &transformOp,
180+
ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
181+
for (OpFoldResult paramOrHandle : mixedResults) {
182+
if (isa<Attribute>(paramOrHandle)) {
183+
reified.push_back(
184+
cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
185+
continue;
186+
} else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
187+
ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
188+
if (params.size() != 1)
189+
return transformOp.emitSilenceableError() << "expected a single param";
190+
reified.push_back(
191+
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192+
continue;
193+
}
194+
195+
Value handle = paramOrHandle.get<Value>();
196+
if (!isa<TransformHandleTypeInterface>(handle.getType()))
197+
return transformOp.emitSilenceableError() << "unexpected value handle";
198+
auto payload = state.getPayloadOps(handle);
199+
if (!llvm::hasSingleElement(payload))
200+
return transformOp.emitSilenceableError()
201+
<< "requires param or handle that is mapped to 1 payload op";
202+
203+
Operation *paramOrHandlePayloadOp = *payload.begin();
204+
if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205+
!paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
206+
return transformOp.emitSilenceableError()
207+
<< "requires param or handle to be result of op with 1 index "
208+
"result";
209+
}
210+
211+
IntegerAttr attr;
212+
if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
213+
return transformOp.emitSilenceableError()
214+
<< "requires param or handle to be the result of a constant like "
215+
"op";
216+
217+
reified.push_back(attr.getInt());
218+
}
219+
return DiagnosedSilenceableFailure::success();
220+
}
221+
174222
//===----------------------------------------------------------------------===//
175223
// Apply...PatternsOp
176224
//===----------------------------------------------------------------------===//
@@ -1664,6 +1712,8 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
16641712
// PadOp
16651713
//===---------------------------------------------------------------------===//
16661714

1715+
static const StringLiteral kPadToMultipleOfKeyword = "pad_to_multiple_of";
1716+
16671717
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
16681718
ArrayRef<int64_t> paddingDimensions,
16691719
ArrayRef<int64_t> padToMultipleOf,
@@ -1677,18 +1727,60 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
16771727
/*target=*/target,
16781728
/*paddingValues=*/ArrayAttr(), // let inference handle this
16791729
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1730+
/*padToMultipleOf=*/ValueRange{},
16801731
/*padToMultipleOf=*/
1681-
(padToMultipleOf.empty() ? ArrayAttr()
1682-
: b.getI64ArrayAttr(padToMultipleOf)),
1732+
(padToMultipleOf.empty()
1733+
? DenseI64ArrayAttr()
1734+
: b.getDenseI64ArrayAttr(padToMultipleOf)),
1735+
/*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1736+
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
1737+
/*copyBackOp=*/b.getStringAttr(copyBackOp));
1738+
}
1739+
1740+
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1741+
ArrayRef<int64_t> paddingDimensions,
1742+
ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1743+
ArrayRef<int64_t> packPaddings,
1744+
ArrayRef<Attribute> transposePaddings,
1745+
StringRef copyBackOp) {
1746+
auto resultType = transform::AnyOpType::get(b.getContext());
1747+
SmallVector<int64_t> staticPadToMultipleOf;
1748+
SmallVector<Value> dynamicPadToMultipleOf;
1749+
dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1750+
staticPadToMultipleOf);
1751+
return build(/*builder=*/b,
1752+
/*result=*/result,
1753+
/*types=*/TypeRange{resultType, resultType},
1754+
/*target=*/target,
1755+
/*paddingValues=*/ArrayAttr(), // let inference handle this
1756+
/*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1757+
/*padToMultipleOf=*/dynamicPadToMultipleOf,
1758+
/*padToMultipleOf=*/staticPadToMultipleOf,
16831759
/*packPaddings=*/b.getI64ArrayAttr(packPaddings),
16841760
/*transposePaddings=*/b.getArrayAttr(transposePaddings),
16851761
/*copyBackOp=*/b.getStringAttr(copyBackOp));
16861762
}
16871763

1764+
void PadOp::getEffects(
1765+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1766+
consumesHandle(getTarget(), effects);
1767+
onlyReadsHandle(getPadToMultipleOf(), effects);
1768+
producesHandle(getPadded(), effects);
1769+
producesHandle(getPad(), effects);
1770+
producesHandle(getCopy(), effects);
1771+
modifiesPayload(effects);
1772+
}
1773+
1774+
SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1775+
Builder b(getContext());
1776+
return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1777+
}
1778+
16881779
DiagnosedSilenceableFailure
16891780
transform::PadOp::apply(transform::TransformRewriter &rewriter,
16901781
transform::TransformResults &results,
16911782
transform::TransformState &state) {
1783+
auto transformOp = cast<TransformOpInterface>(getOperation());
16921784
SmallVector<Operation *> paddedOps, padOps, copyBackOps;
16931785

16941786
for (Operation *target : state.getPayloadOps(getTarget())) {
@@ -1749,10 +1841,16 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
17491841
LinalgPaddingOptions options;
17501842
options.paddingDimensions =
17511843
extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1752-
SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
1753-
if (getPadToMultipleOf().has_value())
1844+
1845+
SmallVector<int64_t> padToMultipleOf;
1846+
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
1847+
state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1848+
if (!status.succeeded())
1849+
return status;
1850+
if (padToMultipleOf.empty())
17541851
padToMultipleOf =
1755-
extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
1852+
SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1853+
17561854
options.padToMultipleOf = padToMultipleOf;
17571855
options.paddingValues = paddingValues;
17581856
options.packPaddings = packPaddings;
@@ -1819,8 +1917,8 @@ LogicalResult transform::PadOp::verify() {
18191917
"integers, found "
18201918
<< getPaddingDimensions();
18211919
}
1822-
if (getPadToMultipleOf().has_value()) {
1823-
if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
1920+
if (!getMixedPadToMultipleOf().empty()) {
1921+
if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
18241922
return emitOpError() << "expects as many multiples as padding_dimensions";
18251923
}
18261924
}
@@ -3204,49 +3302,12 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
32043302
auto targets = state.getPayloadOps(getTarget());
32053303
if (std::empty(targets))
32063304
return DiagnosedSilenceableFailure::success();
3207-
3305+
auto transformOp = cast<TransformOpInterface>(getOperation());
32083306
SmallVector<int64_t> vectorSizes;
3209-
for (OpFoldResult sz : getMixedVectorSizes()) {
3210-
if (sz.is<Attribute>()) {
3211-
auto attr = sz.get<Attribute>();
3212-
vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
3213-
continue;
3214-
} else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
3215-
ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
3216-
if (params.size() != 1)
3217-
return emitSilenceableFailure(getLoc()) << "expected a single param";
3218-
vectorSizes.push_back(
3219-
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
3220-
continue;
3221-
}
3222-
3223-
auto szPayloads = state.getPayloadOps(sz.get<Value>());
3224-
if (!llvm::hasSingleElement(szPayloads)) {
3225-
auto diag = this->emitOpError(
3226-
"requires vector size handle that is mapped to 1 payload op");
3227-
diag.attachNote(sz.get<Value>().getLoc())
3228-
<< "mapped to " << llvm::range_size(szPayloads) << " payload ops";
3229-
return DiagnosedSilenceableFailure::definiteFailure();
3230-
}
3231-
3232-
Operation *szPayloadOp = *szPayloads.begin();
3233-
if (szPayloadOp->getNumResults() != 1 ||
3234-
!szPayloadOp->getResult(0).getType().isIndex()) {
3235-
auto diag = this->emitOpError(
3236-
"requires vector size payload op with 1 index result");
3237-
diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3238-
return DiagnosedSilenceableFailure::definiteFailure();
3239-
}
3240-
3241-
IntegerAttr attr;
3242-
if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
3243-
auto diag = this->emitOpError("requires constant vector size");
3244-
diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3245-
return DiagnosedSilenceableFailure::definiteFailure();
3246-
}
3247-
3248-
vectorSizes.push_back(attr.getInt());
3249-
}
3307+
DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3308+
state, transformOp, getMixedVectorSizes(), vectorSizes);
3309+
if (!status.succeeded())
3310+
return status;
32503311

32513312
// TODO: Check that the correct number of vectorSizes was provided.
32523313
for (Operation *target : targets) {

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,9 @@ def __init__(
374374
self,
375375
target: Union[Operation, OpView, Value],
376376
*,
377+
pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
377378
padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
378379
padding_dimensions: OptionalIntList = None,
379-
pad_to_multiple_of: OptionalIntList = None,
380380
pack_paddings: OptionalIntList = None,
381381
transpose_paddings: Optional[
382382
Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
@@ -385,6 +385,16 @@ def __init__(
385385
loc=None,
386386
ip=None,
387387
):
388+
if pad_to_multiple_of is None:
389+
dynamic_pad_to_multiple_of = []
390+
static_pad_to_multiple_of = None
391+
else:
392+
(
393+
dynamic_pad_to_multiple_of,
394+
static_pad_to_multiple_of,
395+
_,
396+
) = _dispatch_dynamic_index_list(pad_to_multiple_of)
397+
388398
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
389399

390400
any_op_type = transform.AnyOpType.get()
@@ -393,9 +403,10 @@ def __init__(
393403
any_op_type,
394404
any_op_type,
395405
target,
406+
pad_to_multiple_of=dynamic_pad_to_multiple_of,
396407
padding_values=padding_values,
397408
padding_dimensions=padding_dimensions,
398-
pad_to_multiple_of=pad_to_multiple_of,
409+
static_pad_to_multiple_of=static_pad_to_multiple_of,
399410
pack_paddings=pack_paddings,
400411
transpose_paddings=transpose_paddings,
401412
copy_back_op=copy_back_op,

mlir/test/Dialect/Linalg/transform-op-pad.mlir

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,9 @@ func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
7373
module attributes {transform.with_named_sequence} {
7474
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
7575
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
76-
%padded, %pad, %copy_back = transform.structured.pad %0 {
76+
%padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [2, 2, 1] {
7777
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
7878
padding_dimensions=[0, 1, 2],
79-
pad_to_multiple_of=[2, 2, 1],
8079
pack_paddings=[1, 1, 0]
8180
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
8281
transform.yield
@@ -87,6 +86,42 @@ module attributes {transform.with_named_sequence} {
8786

8887
#map = affine_map<()[s0] -> (-s0 + 12, 7)>
8988

89+
// CHECK-LABEL: @parametrized_pad_to_multiple
90+
func.func @parametrized_pad_to_multiple(%arg0: tensor<24x12xf32>,
91+
%arg1: tensor<12x25xf32>,
92+
%arg2: tensor<24x25xf32>,
93+
%iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
94+
%0 = affine.min #map()[%iv2]
95+
%1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
96+
%2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
97+
%3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
98+
99+
// CHECK: linalg.matmul
100+
// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x7xf32>, tensor<7x6xf32>)
101+
// CHECK-SAME: outs(%{{.*}} : tensor<4x6xf32>)
102+
%4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
103+
%5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
104+
func.return %5 : tensor<24x25xf32>
105+
}
106+
107+
108+
module attributes {transform.with_named_sequence} {
109+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
110+
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
111+
%c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
112+
%padded, %pad, %copy_back = transform.structured.pad %0 pad_to_multiple_of [%c2, 2, 1] {
113+
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
114+
padding_dimensions=[0, 1, 2],
115+
pack_paddings=[1, 1, 0]
116+
} : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
117+
transform.yield
118+
}
119+
}
120+
121+
// -----
122+
123+
#map = affine_map<()[s0] -> (-s0 + 12, 7)>
124+
90125
// CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
91126
func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
92127
%arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,

0 commit comments

Comments (0)