|
15 | 15 | #include "mlir/Dialect/Arith/Utils/Utils.h"
|
16 | 16 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
17 | 17 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
| 18 | +#include "mlir/Dialect/CommonFolders.h" |
18 | 19 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
19 | 20 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
20 | 21 | #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
|
|
27 | 28 | #include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
28 | 29 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
29 | 30 | #include "mlir/Dialect/Transform/Utils/Utils.h"
|
| 31 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
30 | 32 | #include "mlir/Dialect/Utils/IndexingUtils.h"
|
31 | 33 | #include "mlir/Dialect/Utils/StaticValueUtils.h"
|
32 | 34 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
@@ -1985,22 +1987,27 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
|
1985 | 1987 |
|
1986 | 1988 | // Convert the padding values to attributes.
|
1987 | 1989 | SmallVector<Attribute> paddingValues;
|
1988 |
| - for (auto const &it : |
| 1990 | + for (auto const &[untypedAttr, elementOrTensorType] : |
1989 | 1991 | llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
|
1990 |
| - auto attr = dyn_cast<TypedAttr>(std::get<0>(it)); |
| 1992 | + |
| 1993 | + if (isa<ub::PoisonAttr>(untypedAttr)) { |
| 1994 | + paddingValues.push_back(untypedAttr); |
| 1995 | + continue; |
| 1996 | + } |
| 1997 | + auto attr = dyn_cast<TypedAttr>(untypedAttr); |
1991 | 1998 | if (!attr) {
|
1992 |
| - emitOpError("expects padding values to be typed attributes"); |
| 1999 | + emitOpError("expects padding values to be typed attributes or poison"); |
1993 | 2000 | return DiagnosedSilenceableFailure::definiteFailure();
|
1994 | 2001 | }
|
1995 |
| - Type elementType = getElementTypeOrSelf(std::get<1>(it)); |
| 2002 | + Type elementType = getElementTypeOrSelf(elementOrTensorType); |
1996 | 2003 | // Try to parse string attributes to obtain an attribute of element type.
|
1997 | 2004 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
|
1998 | 2005 | auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
|
1999 | 2006 | stringAttr, getContext(), elementType,
|
2000 | 2007 | /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
|
2001 | 2008 | if (!parsedAttr || parsedAttr.getType() != elementType) {
|
2002 | 2009 | auto diag = this->emitOpError("expects a padding that parses to ")
|
2003 |
| - << elementType << ", got " << std::get<0>(it); |
| 2010 | + << elementType << ", got " << untypedAttr; |
2004 | 2011 | diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
|
2005 | 2012 | return DiagnosedSilenceableFailure::definiteFailure();
|
2006 | 2013 | }
|
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
|
2235 | 2242 | llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
|
2236 | 2243 | auto attr = dyn_cast<TypedAttr>(untypedAttr);
|
2237 | 2244 | Type elementType = getElementTypeOrSelf(elementOrTensorType);
|
| 2245 | + |
| 2246 | + if (isa<ub::PoisonAttr>(untypedAttr)) { |
| 2247 | + paddingValues.push_back(untypedAttr); |
| 2248 | + continue; |
| 2249 | + } |
2238 | 2250 | if (!attr) {
|
2239 |
| - emitOpError("expects padding values to be typed attributes"); |
| 2251 | + emitOpError("expects padding values to be typed attributes or poison"); |
2240 | 2252 | return DiagnosedSilenceableFailure::definiteFailure();
|
2241 | 2253 | }
|
2242 | 2254 | // Try to parse string attributes to obtain an attribute of element type.
|
|
0 commit comments