2626#include " mlir/IR/AffineMap.h"
2727#include " mlir/IR/Builders.h"
2828#include " mlir/IR/BuiltinAttributes.h"
29- #include " mlir/IR/BuiltinOps.h"
3029#include " mlir/IR/BuiltinTypes.h"
3130#include " mlir/IR/DialectImplementation.h"
3231#include " mlir/IR/IRMapping.h"
4241#include " llvm/ADT/SmallVector.h"
4342#include " llvm/ADT/StringSet.h"
4443#include " llvm/ADT/TypeSwitch.h"
45- #include " llvm/ADT/bit.h"
4644
4745#include < cassert>
4846#include < cstdint>
@@ -2684,25 +2682,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
26842682 if (!v1Attr || !v2Attr)
26852683 return {};
26862684
2685+ // Fold shuffle poison, poison -> poison.
2686+ bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
2687+ bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
2688+ if (isV1Poison && isV2Poison)
2689+ return ub::PoisonAttr::get (getContext ());
2690+
26872691 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
26882692 // manipulation.
26892693 if (v1Type.getRank () != 1 )
26902694 return {};
26912695
2692- int64_t v1Size = v1Type.getDimSize (0 );
2696+ // Poison input attributes need special handling as they are not
2697+ // DenseElementsAttr. If an index is poison, we select the first element of
2698+ // the first non-poison input.
2699+ SmallVector<Attribute> v1Elements, v2Elements;
2700+ Attribute poisonElement;
2701+ if (!isV2Poison) {
2702+ v2Elements =
2703+ to_vector (cast<DenseElementsAttr>(v2Attr).getValues <Attribute>());
2704+ poisonElement = v2Elements[0 ];
2705+ }
2706+ if (!isV1Poison) {
2707+ v1Elements =
2708+ to_vector (cast<DenseElementsAttr>(v1Attr).getValues <Attribute>());
2709+ poisonElement = v1Elements[0 ];
2710+ }
26932711
26942712 SmallVector<Attribute> results;
2695- auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues <Attribute>();
2696- auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues <Attribute>();
2713+ int64_t v1Size = v1Type.getDimSize (0 );
26972714 for (int64_t maskIdx : mask) {
26982715 Attribute indexedElm;
2699- // Select v1[0] for poison indices.
27002716 // TODO: Return a partial poison vector when supported by the UB dialect.
27012717 if (maskIdx == ShuffleOp::kPoisonIndex ) {
2702- indexedElm = v1Elements[ 0 ] ;
2718+ indexedElm = poisonElement ;
27032719 } else {
2704- indexedElm =
2705- maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
2720+ if (maskIdx < v1Size)
2721+ indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
2722+ else
2723+ indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
27062724 }
27072725
27082726 results.push_back (indexedElm);
@@ -3319,13 +3337,15 @@ class InsertStridedSliceConstantFolder final
33193337 !destVector.hasOneUse ())
33203338 return failure ();
33213339
3322- auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3323-
33243340 TypedValue<VectorType> sourceValue = op.getSource ();
33253341 Attribute sourceCst;
33263342 if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
33273343 return failure ();
33283344
3345+ // TODO: Support poison.
3346+ if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3347+ return failure ();
3348+
33293349 // TODO: Handle non-unit strides when they become available.
33303350 if (op.hasNonUnitStrides ())
33313351 return failure ();
@@ -3342,6 +3362,7 @@ class InsertStridedSliceConstantFolder final
33423362 // increasing linearized position indices.
33433363 // Because the destination may have higher dimensionality then the slice,
33443364 // we keep track of two overlapping sets of positions and offsets.
3365+ auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
33453366 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
33463367 auto sliceValuesIt = denseSlice.value_begin <Attribute>();
33473368 auto newValues = llvm::to_vector (denseDest.getValues <Attribute>());
0 commit comments