Commit db250e2

add use_prescribed_tensor_shapes option
1 parent 26fd8bd commit db250e2

3 files changed: 70 additions, 7 deletions
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 3 deletions

@@ -1134,14 +1134,16 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                        DefaultValuedAttr<
                          TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
                          "{}">:$transpose_paddings,
-                       DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op);
+                       DefaultValuedAttr<StrAttr, "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copy_back_op,
+                       DefaultValuedAttr<UnitAttr, "false">:$use_prescribed_tensor_shapes);
   let results = (outs TransformHandleTypeInterface:$padded,
                       TransformHandleTypeInterface:$pad,
                       TransformHandleTypeInterface:$copy);
 
   let assemblyFormat = [{
     $target
     (`pad_to_multiple_of` custom<DynamicIndexList>($pad_to_multiple_of, $static_pad_to_multiple_of)^)?
+    (`use_prescribed_tensor_shapes` $use_prescribed_tensor_shapes^)?
     attr-dict
     `:` functional-type(operands, results)
   }];
@@ -1159,13 +1161,15 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
                    CArg<"ArrayRef<int64_t>", "{}">:$staticPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>,
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$paddingDimensions,
                    "ArrayRef<OpFoldResult>":$mixedPadToMultipleOf,
                    CArg<"ArrayRef<int64_t>", "{}">:$nofoldFlags,
                    CArg<"ArrayRef<Attribute>", "{}">:$transposePaddings,
-                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp)>
+                   CArg<"StringRef", "::mlir::bufferization::MaterializeInDestinationOp::getOperationName()">:$copyBackOp,
+                   CArg<"bool", "false">:$usePrescribedTensorShapes)>
   ];
 
   let extraClassDeclaration = [{
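
For orientation, a minimal C++ sketch (not part of the commit) of how the extended first builder overload could be driven from C++; `b`, `loc`, and the `target` handle value are assumed to be in scope, and the leading `paddingDimensions` parameter comes from the unchanged portion of the signature.

// Sketch only: pad dimension 1 to a multiple of 7 and derive the padded
// size from the operand's runtime shape via the new flag.
auto padOp = b.create<transform::PadOp>(
    loc, target,
    /*paddingDimensions=*/ArrayRef<int64_t>{1},
    /*staticPadToMultipleOf=*/ArrayRef<int64_t>{7},
    /*nofoldFlags=*/ArrayRef<int64_t>{},
    /*transposePaddings=*/ArrayRef<Attribute>{},
    /*copyBackOp=*/
    bufferization::MaterializeInDestinationOp::getOperationName(),
    /*usePrescribedTensorShapes=*/true);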

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 30 additions & 4 deletions

@@ -1907,7 +1907,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> padToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   return build(/*builder=*/b,
                /*result=*/result,
@@ -1922,15 +1923,18 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
               /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-              /*copyBackOp=*/b.getStringAttr(copyBackOp));
+              /*copyBackOp=*/b.getStringAttr(copyBackOp),
+              /*usePrescribedTensorShapes=*/
+              usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
 }
 
 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                              ArrayRef<int64_t> paddingDimensions,
                              ArrayRef<OpFoldResult> mixedPadToMultipleOf,
                              ArrayRef<int64_t> nofoldFlags,
                              ArrayRef<Attribute> transposePaddings,
-                             StringRef copyBackOp) {
+                             StringRef copyBackOp,
+                             bool usePrescribedTensorShapes) {
   auto resultType = transform::AnyOpType::get(b.getContext());
   SmallVector<int64_t> staticPadToMultipleOf;
   SmallVector<Value> dynamicPadToMultipleOf;
@@ -1946,7 +1950,8 @@ void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                /*padToMultipleOf=*/staticPadToMultipleOf,
                /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
                /*transposePaddings=*/b.getArrayAttr(transposePaddings),
-               /*copyBackOp=*/b.getStringAttr(copyBackOp));
+               /*copyBackOp=*/copyBackOp,
+               /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
 }
 
 void PadOp::getEffects(
@@ -2051,11 +2056,32 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
   } else {
     llvm_unreachable("unsupported copy_back op");
   }
+  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
+  bool irChanged = false;
+  if (getUsePrescribedTensorShapes() &&
+      linalgTarget.hasPureTensorSemantics()) {
+    for (OpOperand &operand : linalgTarget->getOpOperands()) {
+      for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
+        if (!ShapedType::isDynamic(dim))
+          continue;
+        options.setSizeToPadTo(operand.getOperandNumber(), i,
+                               tensor::getMixedSize(rewriter,
+                                                    operand.get().getLoc(),
+                                                    operand.get(), i));
+        irChanged = true;
+      }
+    }
+  }
 
   SmallVector<Value> replacements;
   SmallVector<tensor::PadOp> newPadOps;
   if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                replacements, newPadOps))) {
+    if (irChanged) {
+      auto diag = emitDefiniteFailure() << "failed to pad op";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
     auto diag = emitSilenceableError() << "failed to pad op";
     diag.attachNote(target->getLoc()) << "target op";
     return diag;
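
A note on the helper driving the new code path: `tensor::getMixedSize` from the tensor dialect utilities folds to an `IntegerAttr` for static dimensions but materializes a `tensor.dim` op for dynamic ones, which is why the transform tracks `irChanged` and upgrades a later padding failure to a definite failure. A small illustrative sketch follows; the function name and its arguments are hypothetical, only `tensor::getMixedSize` itself is the upstream utility.

#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Illustrative only: assumes `operand` has a ranked type such as
// tensor<12x?xf32>, with dim 0 static and dim 1 dynamic.
static void prescribePadSizes(RewriterBase &rewriter, Location loc,
                              Value operand) {
  // Static dimension: returns an IntegerAttr, no new IR is created.
  OpFoldResult size0 = tensor::getMixedSize(rewriter, loc, operand, /*dim=*/0);
  // Dynamic dimension: creates `tensor.dim %operand, %c1` and returns the
  // resulting Value, i.e. the IR changes before padding is even attempted.
  OpFoldResult size1 = tensor::getMixedSize(rewriter, loc, operand, /*dim=*/1);
  (void)size0;
  (void)size1;
}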

mlir/test/Dialect/Linalg/transform-op-pad.mlir

Lines changed: 33 additions & 0 deletions

@@ -313,6 +313,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Test dynamic padding using `use_prescribed_tensor_shapes`
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (-s1 + (s0 ceildiv 7) * 7)>
+// CHECK: @use_prescribed_tensor_shapes
+// CHECK: (%[[ARG0:.*]]: tensor<?x12xf32>, %[[ARG1:.*]]: tensor<12x?xf32>
+func.func @use_prescribed_tensor_shapes(%arg0: tensor<?x12xf32>,
+                                        %arg1: tensor<12x?xf32>,
+                                        %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[C1_0:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG1]], %[[C1_0]] : tensor<12x?xf32>
+  // CHECK: %[[C1_1:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG1]], %[[C1_1]] : tensor<12x?xf32>
+  // CHECK: %[[PADDING:.*]] = affine.apply #[[MAP]]()[%[[DIM_0]], %[[DIM_1]]]
+  // CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG1]] low[0, 0] high[0, %[[PADDING]]] {
+  // CHECK: linalg.matmul ins(%[[ARG0]], %[[PADDED]] : tensor<?x12xf32>, tensor<12x?xf32>)
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x12xf32>, tensor<12x?xf32>) outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  func.return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %padded, %pad, %copy_back = transform.structured.pad %0
+        pad_to_multiple_of [7] use_prescribed_tensor_shapes {
+      padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+      padding_dimensions=[1]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Check that the padding can be applied even when the output argument of the
 // linalg op is not produced by an empty op or an extract_slice op.