@@ -163,6 +163,7 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder,
163163
164164 if (numElems < 8 ) {
165165 SmallVector<mlir::Attribute, 4 > indices;
166+ indices.reserve (numElems);
166167 mlir::Type i32Ty = builder.getSInt32Ty ();
167168 for (auto i : llvm::seq<unsigned >(0 , numElems))
168169 indices.push_back (cir::IntAttr::get (i32Ty, i));
@@ -172,43 +173,11 @@ static mlir::Value getBoolMaskVecValue(CIRGenBuilderTy &builder,
172173 return maskVec;
173174}
174175
175- // Helper function mirroring OG's bool Constant::isAllOnesValue()
176- static bool isAllOnesValue (mlir::Value value) {
177- auto constOp = mlir::dyn_cast_or_null<cir::ConstantOp>(value.getDefiningOp ());
178- if (!constOp)
179- return false ;
180-
181- // Check for -1 integers
182- if (auto intAttr = constOp.getValueAttr <cir::IntAttr>()) {
183- return intAttr.getValue ().isAllOnes ();
184- }
185-
186- // Check for FP which are bitcasted from -1 integers
187- if (auto fpAttr = constOp.getValueAttr <cir::FPAttr>()) {
188- return fpAttr.getValue ().bitcastToAPInt ().isAllOnes ();
189- }
190-
191- // Check for constant vectors with splat values
192- if (cir::VectorType v = dyn_cast<cir::VectorType>(constOp.getType ())) {
193- if (auto vecAttr = constOp.getValueAttr <mlir::DenseElementsAttr>()) {
194- if (vecAttr.isSplat ()) {
195- auto splatAttr = vecAttr.getSplatValue <mlir::Attribute>();
196- if (auto splatInt = mlir::dyn_cast<cir::IntAttr>(splatAttr)) {
197- return splatInt.getValue ().isAllOnes ();
198- }
199- }
200- }
201- }
202-
203- return false ;
204- }
205-
206176static mlir::Value emitX86Select (CIRGenBuilderTy &builder, mlir::Location loc,
207177 mlir::Value mask, mlir::Value op0,
208178 mlir::Value op1) {
209-
210179 // If the mask is all ones just return first argument.
211- if (isAllOnesValue (mask))
180+ if (cir::ConstantOp:: isAllOnesValue (mask))
212181 return op0;
213182
214183 mask = getBoolMaskVecValue (builder, loc, mask,
@@ -958,22 +927,21 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
958927 unsigned numElts = dstTy.getSize ();
959928 unsigned srcNumElts = cast<cir::VectorType>(ops[0 ].getType ()).getSize ();
960929 unsigned subVectors = srcNumElts / numElts;
930+ assert (llvm::isPowerOf2_32 (subVectors) && " Expected power of 2 subvectors" );
961931 unsigned index =
962932 ops[1 ].getDefiningOp <cir::ConstantOp>().getIntValue ().getZExtValue ();
963933
964934 index &= subVectors - 1 ; // Remove any extra bits.
965935 index *= numElts;
966936
967937 int64_t indices[16 ];
968- for (unsigned i = 0 ; i != numElts; ++i)
969- indices[i] = i + index;
938+ std::iota (indices, indices + numElts, index);
970939
971940 mlir::Value zero = builder.getNullValue (ops[0 ].getType (), loc);
972941 mlir::Value res =
973942 builder.createVecShuffle (loc, ops[0 ], zero, ArrayRef (indices, numElts));
974- if (ops.size () == 4 ) {
943+ if (ops.size () == 4 )
975944 res = emitX86Select (builder, loc, ops[3 ], res, ops[2 ]);
976- }
977945
978946 return res;
979947 }
0 commit comments