Skip to content

Commit b574bcf

Browse files
authored
[mlir][TD] Support padding with poison (#152003)
Signed-off-by: James Newling <[email protected]>
1 parent 45b4f1b commit b574bcf

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Arith/Utils/Utils.h"
1616
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1717
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
18+
#include "mlir/Dialect/CommonFolders.h"
1819
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2021
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
2728
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
2829
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
2930
#include "mlir/Dialect/Transform/Utils/Utils.h"
31+
#include "mlir/Dialect/UB/IR/UBOps.h"
3032
#include "mlir/Dialect/Utils/IndexingUtils.h"
3133
#include "mlir/Dialect/Utils/StaticValueUtils.h"
3234
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,22 +1987,27 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
19851987

19861988
// Convert the padding values to attributes.
19871989
SmallVector<Attribute> paddingValues;
1988-
for (auto const &it :
1990+
for (auto const &[untypedAttr, elementOrTensorType] :
19891991
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990-
auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1992+
1993+
if (isa<ub::PoisonAttr>(untypedAttr)) {
1994+
paddingValues.push_back(untypedAttr);
1995+
continue;
1996+
}
1997+
auto attr = dyn_cast<TypedAttr>(untypedAttr);
19911998
if (!attr) {
1992-
emitOpError("expects padding values to be typed attributes");
1999+
emitOpError("expects padding values to be typed attributes or poison");
19932000
return DiagnosedSilenceableFailure::definiteFailure();
19942001
}
1995-
Type elementType = getElementTypeOrSelf(std::get<1>(it));
2002+
Type elementType = getElementTypeOrSelf(elementOrTensorType);
19962003
// Try to parse string attributes to obtain an attribute of element type.
19972004
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
19982005
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
19992006
stringAttr, getContext(), elementType,
20002007
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
20012008
if (!parsedAttr || parsedAttr.getType() != elementType) {
20022009
auto diag = this->emitOpError("expects a padding that parses to ")
2003-
<< elementType << ", got " << std::get<0>(it);
2010+
<< elementType << ", got " << untypedAttr;
20042011
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
20052012
return DiagnosedSilenceableFailure::definiteFailure();
20062013
}
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
22352242
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
22362243
auto attr = dyn_cast<TypedAttr>(untypedAttr);
22372244
Type elementType = getElementTypeOrSelf(elementOrTensorType);
2245+
2246+
if (isa<ub::PoisonAttr>(untypedAttr)) {
2247+
paddingValues.push_back(untypedAttr);
2248+
continue;
2249+
}
22382250
if (!attr) {
2239-
emitOpError("expects padding values to be typed attributes");
2251+
emitOpError("expects padding values to be typed attributes or poison");
22402252
return DiagnosedSilenceableFailure::definiteFailure();
22412253
}
22422254
// Try to parse string attributes to obtain an attribute of element type.

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1212
#include "mlir/Dialect/Complex/IR/Complex.h"
1313
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1516
#include "mlir/IR/AffineExpr.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
230231
Value paddingValue;
231232
if (auto complexTy =
232233
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
233-
auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
234-
paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
235-
complexTy, complexAttr);
236-
} else {
237-
paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
238-
cast<TypedAttr>(paddingValueAttr));
234+
if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
235+
paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
236+
complexTy, complexAttr);
237+
}
238+
} else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
239+
paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
240+
getElementTypeOrSelf(v.getType()));
241+
} else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
242+
paddingValue =
243+
arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
239244
}
245+
assert(paddingValue && "failed to create value from padding attribute");
240246

241247
// Pad the operand to the bounding box defined by `paddedShape`.
242248
SmallVector<int64_t> tensorShape;

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ module attributes {transform.with_named_sequence} {
1414
: (!transform.any_op) -> !transform.any_op
1515

1616
// Tile to 5 then pad to 8
17-
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
17+
%fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
1818
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1919

2020
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
21-
padding_values=[0.0 : f32, 0.0 : f32]
21+
padding_values= [#ub.poison, 0.0 : f32]
2222
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2323

2424
transform.yield
@@ -33,9 +33,9 @@ func.func @pad_lhs(
3333
-> tensor<24x25xf32>
3434
{
3535
// CHECK: scf.for %{{.*}} -> (tensor<24x25xf32>)
36-
// CHECK: tensor.pad %{{.*}}
36+
// CHECK: tensor.pad %{{.*}}
3737
// CHECK: : tensor<?x12xf32> to tensor<8x12xf32>
38-
// CHECK: tensor.pad %{{.*}}
38+
// CHECK: tensor.pad %{{.*}}
3939
// CHECK: : tensor<?x25xf32> to tensor<8x25xf32>
4040
// CHECK: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<8x12xf32>, tensor<12x25xf32>) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
4141
// CHECK: tensor.extract_slice %{{.*}}[0, 0] [%{{.*}}, 25] [1, 1]
@@ -92,7 +92,7 @@ module {
9292
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
9393
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
9494
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
95-
transform.yield
95+
transform.yield
9696
}
9797
}
9898
}
@@ -147,7 +147,7 @@ module {
147147
%padded, %pad = transform.structured.pad_tiling_interface %0 to padding_sizes [8, 0, 14] {
148148
padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]
149149
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
150-
transform.yield
150+
transform.yield
151151
}
152152
}
153153
}

0 commit comments

Comments
 (0)