|
11 | 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | 12 | #include "mlir/Dialect/Complex/IR/Complex.h" |
13 | 13 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 14 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
14 | 15 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
15 | 16 | #include "mlir/IR/AffineExpr.h" |
16 | 17 | #include "mlir/IR/BuiltinAttributes.h" |
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, |
230 | 231 | Value paddingValue; |
231 | 232 | if (auto complexTy = |
232 | 233 | dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) { |
233 | | - auto complexAttr = cast<ArrayAttr>(paddingValueAttr); |
234 | | - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), |
235 | | - complexTy, complexAttr); |
236 | | - } else { |
237 | | - paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), |
238 | | - cast<TypedAttr>(paddingValueAttr)); |
| 234 | + if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) { |
| 235 | + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), |
| 236 | + complexTy, complexAttr); |
| 237 | + } |
| 238 | + } else if (isa<ub::PoisonAttr>(paddingValueAttr)) { |
| 239 | + paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), |
| 240 | + getElementTypeOrSelf(v.getType())); |
| 241 | + } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) { |
| 242 | + paddingValue = |
| 243 | + arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); |
239 | 244 | } |
| 245 | + assert(paddingValue && "failed to create value from padding attribute"); |
240 | 246 |
|
241 | 247 | // Pad the operand to the bounding box defined by `paddedShape`. |
242 | 248 | SmallVector<int64_t> tensorShape; |
|
0 commit comments