2626#include " mlir/IR/AffineMap.h"
2727#include " mlir/IR/Builders.h"
2828#include " mlir/IR/BuiltinAttributes.h"
29- #include " mlir/IR/BuiltinOps.h"
3029#include " mlir/IR/BuiltinTypes.h"
3130#include " mlir/IR/DialectImplementation.h"
3231#include " mlir/IR/IRMapping.h"
4241#include " llvm/ADT/SmallVector.h"
4342#include " llvm/ADT/StringSet.h"
4443#include " llvm/ADT/TypeSwitch.h"
45- #include " llvm/ADT/bit.h"
4644
4745#include < cassert>
4846#include < cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
26962694 if (!v1Attr || !v2Attr)
26972695 return {};
26982696
2697+ // Fold shuffle poison, poison -> poison.
2698+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2699+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2700+ if (isV1Poison && isV2Poison)
2701+ return ub::PoisonAttr::get (getContext ());
2702+
26992703 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
27002704 // manipulation.
27012705 if (v1Type.getRank () != 1 )
27022706 return {};
27032707
2704- int64_t v1Size = v1Type.getDimSize (0 );
2708+ // Poison input attributes need special handling as they are not
2709+ // DenseElementsAttr. If an index is poison, we select the first element of
2710+ // the first non-poison input.
2711+ SmallVector<Attribute> v1Elements, v2Elements;
2712+ Attribute poisonElement;
2713+ if (!isV2Poison) {
2714+ v2Elements =
2715+ to_vector (cast<DenseElementsAttr>(v2Attr).getValues <Attribute>());
2716+ poisonElement = v2Elements[0 ];
2717+ }
2718+ if (!isV1Poison) {
2719+ v1Elements =
2720+ to_vector (cast<DenseElementsAttr>(v1Attr).getValues <Attribute>());
2721+ poisonElement = v1Elements[0 ];
2722+ }
27052723
27062724 SmallVector<Attribute> results;
2707- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues <Attribute>();
2708- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues <Attribute>();
2725+ int64_t v1Size = v1Type.getDimSize (0 );
27092726 for (int64_t maskIdx : mask) {
27102727 Attribute indexedElm;
2711- // Select v1[0] for poison indices.
27122728 // TODO: Return a partial poison vector when supported by the UB dialect.
27132729 if (maskIdx == ShuffleOp::kPoisonIndex ) {
2714- indexedElm = v1Elements[ 0 ] ;
2730+ indexedElm = poisonElement ;
27152731 } else {
2716- indexedElm =
2717- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2732+ if (maskIdx < v1Size)
2733+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2734+ else
2735+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
27182736 }
27192737
27202738 results.push_back (indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
33323350 !destVector.hasOneUse ())
33333351 return failure ();
33343352
3335- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3336-
33373353 TypedValue<VectorType> sourceValue = op.getSource ();
33383354 Attribute sourceCst;
33393355 if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
33403356 return failure ();
33413357
3358+ // TODO: Support poison.
3359+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3360+ return failure ();
3361+
33423362 // TODO: Handle non-unit strides when they become available.
33433363 if (op.hasNonUnitStrides ())
33443364 return failure ();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
33553375 // increasing linearized position indices.
33563376 // Because the destination may have higher dimensionality then the slice,
33573377 // we keep track of two overlapping sets of positions and offsets.
3378+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
33583379 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
33593380 auto sliceValuesIt = denseSlice.value_begin <Attribute>();
33603381 auto newValues = llvm::to_vector (denseDest.getValues <Attribute>());
0 commit comments