@@ -2673,43 +2673,51 @@ static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
26732673}
26742674
26752675OpFoldResult vector::ShuffleOp::fold (FoldAdaptor adaptor) {
2676- VectorType v1Type = getV1VectorType ();
2676+ auto v1Type = getV1VectorType ();
2677+ auto v2Type = getV2VectorType ();
2678+
2679+ assert (!v1Type.isScalable () && !v2Type.isScalable () &&
2680+ " Vector shuffle does not support scalable vectors" );
2681+
26772682 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
26782683 // but must be a canonicalization into a vector.broadcast.
26792684 if (v1Type.getRank () == 0 )
26802685 return {};
26812686
2682- // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2683- if (!v1Type. isScalable () &&
2684- isStepIndexArray (getMask () , 0 , v1Type.getDimSize (0 )))
2687+ // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
2688+ auto mask = getMask ();
2689+ if ( isStepIndexArray (mask , 0 , v1Type.getDimSize (0 )))
26852690 return getV1 ();
2686- // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2687- if (!getV1VectorType ().isScalable () && !getV2VectorType ().isScalable () &&
2688- isStepIndexArray (getMask (), getV1VectorType ().getDimSize (0 ),
2689- getV2VectorType ().getDimSize (0 )))
2691+ // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
2692+ if (isStepIndexArray (mask, v1Type.getDimSize (0 ), v2Type.getDimSize (0 )))
26902693 return getV2 ();
26912694
2692- Attribute lhs = adaptor.getV1 (), rhs = adaptor.getV2 ();
2693- if (!lhs || !rhs )
2695+ Attribute v1Attr = adaptor.getV1 (), v2Attr = adaptor.getV2 ();
2696+ if (!v1Attr || !v2Attr )
26942697 return {};
26952698
2696- auto lhsType =
2697- llvm::cast<VectorType>(llvm::cast<DenseElementsAttr>(lhs).getType ());
26982699 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
26992700 // manipulation.
2700- if (lhsType .getRank () != 1 )
2701+ if (v1Type .getRank () != 1 )
27012702 return {};
2702- int64_t lhsSize = lhsType.getDimSize (0 );
2703+
2704+ int64_t v1Size = v1Type.getDimSize (0 );
27032705
27042706 SmallVector<Attribute> results;
2705- auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues <Attribute>();
2706- auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues <Attribute>();
2707- for (int64_t i : this ->getMask ()) {
2708- if (i >= lhsSize) {
2709- results.push_back (rhsElements[i - lhsSize]);
2707+ auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues <Attribute>();
2708+ auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues <Attribute>();
2709+ for (int64_t maskIdx : mask) {
2710+ Attribute indexedElm;
2711+ // Select v1[0] for poison indices.
2712+ // TODO: Return a partial poison vector when supported by the UB dialect.
2713+ if (maskIdx == ShuffleOp::kPoisonIndex ) {
2714+ indexedElm = v1Elements[0 ];
27102715 } else {
2711- results.push_back (lhsElements[i]);
2716+ indexedElm =
2717+ maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
27122718 }
2719+
2720+ results.push_back (indexedElm);
27132721 }
27142722
27152723 return DenseElementsAttr::get (getResultVectorType (), results);
0 commit comments