|
15 | 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
16 | 16 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
17 | 17 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
| 18 | +#include "mlir/Dialect/CommonFolders.h" |
18 | 19 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
19 | 20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
20 | 21 | #include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h" |
|
27 | 28 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
28 | 29 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
29 | 30 | #include "mlir/Dialect/Transform/Utils/Utils.h" |
| 31 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
30 | 32 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
31 | 33 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
32 | 34 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
@@ -1985,22 +1987,27 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter, |
1985 | 1987 |
|
1986 | 1988 | // Convert the padding values to attributes. |
1987 | 1989 | SmallVector<Attribute> paddingValues; |
1988 | | - for (auto const &it : |
| 1990 | + for (auto const &[untypedAttr, elementOrTensorType] : |
1989 | 1991 | llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) { |
1990 | | - auto attr = dyn_cast<TypedAttr>(std::get<0>(it)); |
| 1992 | + |
| 1993 | + if (isa<ub::PoisonAttr>(untypedAttr)) { |
| 1994 | + paddingValues.push_back(untypedAttr); |
| 1995 | + continue; |
| 1996 | + } |
| 1997 | + auto attr = dyn_cast<TypedAttr>(untypedAttr); |
1991 | 1998 | if (!attr) { |
1992 | | - emitOpError("expects padding values to be typed attributes"); |
| 1999 | + emitOpError("expects padding values to be typed attributes or poison"); |
1993 | 2000 | return DiagnosedSilenceableFailure::definiteFailure(); |
1994 | 2001 | } |
1995 | | - Type elementType = getElementTypeOrSelf(std::get<1>(it)); |
| 2002 | + Type elementType = getElementTypeOrSelf(elementOrTensorType); |
1996 | 2003 | // Try to parse string attributes to obtain an attribute of element type. |
1997 | 2004 | if (auto stringAttr = dyn_cast<StringAttr>(attr)) { |
1998 | 2005 | auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute( |
1999 | 2006 | stringAttr, getContext(), elementType, |
2000 | 2007 | /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); |
2001 | 2008 | if (!parsedAttr || parsedAttr.getType() != elementType) { |
2002 | 2009 | auto diag = this->emitOpError("expects a padding that parses to ") |
2003 | | - << elementType << ", got " << std::get<0>(it); |
| 2010 | + << elementType << ", got " << untypedAttr; |
2004 | 2011 | diag.attachNote(linalgTarget.getLoc()) << "when applied to this op"; |
2005 | 2012 | return DiagnosedSilenceableFailure::definiteFailure(); |
2006 | 2013 | } |
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, |
2235 | 2242 | llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) { |
2236 | 2243 | auto attr = dyn_cast<TypedAttr>(untypedAttr); |
2237 | 2244 | Type elementType = getElementTypeOrSelf(elementOrTensorType); |
| 2245 | + |
| 2246 | + if (isa<ub::PoisonAttr>(untypedAttr)) { |
| 2247 | + paddingValues.push_back(untypedAttr); |
| 2248 | + continue; |
| 2249 | + } |
2238 | 2250 | if (!attr) { |
2239 | | - emitOpError("expects padding values to be typed attributes"); |
| 2251 | + emitOpError("expects padding values to be typed attributes or poison"); |
2240 | 2252 | return DiagnosedSilenceableFailure::definiteFailure(); |
2241 | 2253 | } |
2242 | 2254 | // Try to parse string attributes to obtain an attribute of element type. |
|
0 commit comments