@@ -381,3 +381,96 @@ LogicalResult LLVM::detail::oneToOneRewrite(
381381 rewriter.replaceOp (op, results);
382382 return success ();
383383}
384+
385+ static unsigned getBitWidth (Type type) {
386+ if (type.isIntOrFloat ())
387+ return type.getIntOrFloatBitWidth ();
388+
389+ auto vec = cast<VectorType>(type);
390+ assert (!vec.isScalable () && " scalable vectors are not supported" );
391+ return vec.getNumElements () * getBitWidth (vec.getElementType ());
392+ }
393+
394+ static Value createI32Constant (OpBuilder &builder, Location loc,
395+ int32_t value) {
396+ Type i32 = builder.getI32Type ();
397+ return builder.create <LLVM::ConstantOp>(loc, i32 , value);
398+ }
399+
400+ SmallVector<Value> mlir::LLVM::decomposeValue (OpBuilder &builder, Location loc,
401+ Value src, Type dstType) {
402+ Type srcType = src.getType ();
403+ if (srcType == dstType)
404+ return {src};
405+
406+ unsigned srcBitWidth = getBitWidth (srcType);
407+ unsigned dstBitWidth = getBitWidth (dstType);
408+ if (srcBitWidth == dstBitWidth) {
409+ Value cast = builder.create <LLVM::BitcastOp>(loc, dstType, src);
410+ return {cast};
411+ }
412+
413+ if (dstBitWidth > srcBitWidth) {
414+ auto smallerInt = builder.getIntegerType (srcBitWidth);
415+ if (srcType != smallerInt)
416+ src = builder.create <LLVM::BitcastOp>(loc, smallerInt, src);
417+
418+ auto largerInt = builder.getIntegerType (dstBitWidth);
419+ Value res = builder.create <LLVM::ZExtOp>(loc, largerInt, src);
420+ return {res};
421+ }
422+ assert (srcBitWidth % dstBitWidth == 0 &&
423+ " src bit width must be a multiple of dst bit width" );
424+ int64_t numElements = srcBitWidth / dstBitWidth;
425+ auto vecType = VectorType::get (numElements, dstType);
426+
427+ src = builder.create <LLVM::BitcastOp>(loc, vecType, src);
428+
429+ SmallVector<Value> res;
430+ for (auto i : llvm::seq (numElements)) {
431+ Value idx = createI32Constant (builder, loc, i);
432+ Value elem = builder.create <LLVM::ExtractElementOp>(loc, src, idx);
433+ res.emplace_back (elem);
434+ }
435+
436+ return res;
437+ }
438+
439+ Value mlir::LLVM::composeValue (OpBuilder &builder, Location loc, ValueRange src,
440+ Type dstType) {
441+ assert (!src.empty () && " src range must not be empty" );
442+ if (src.size () == 1 ) {
443+ Value res = src.front ();
444+ if (res.getType () == dstType)
445+ return res;
446+
447+ unsigned srcBitWidth = getBitWidth (res.getType ());
448+ unsigned dstBitWidth = getBitWidth (dstType);
449+ if (dstBitWidth < srcBitWidth) {
450+ auto largerInt = builder.getIntegerType (srcBitWidth);
451+ if (res.getType () != largerInt)
452+ res = builder.create <LLVM::BitcastOp>(loc, largerInt, res);
453+
454+ auto smallerInt = builder.getIntegerType (dstBitWidth);
455+ res = builder.create <LLVM::TruncOp>(loc, smallerInt, res);
456+ }
457+
458+ if (res.getType () != dstType)
459+ res = builder.create <LLVM::BitcastOp>(loc, dstType, res);
460+
461+ return res;
462+ }
463+
464+ int64_t numElements = src.size ();
465+ auto srcType = VectorType::get (numElements, src.front ().getType ());
466+ Value res = builder.create <LLVM::PoisonOp>(loc, srcType);
467+ for (auto &&[i, elem] : llvm::enumerate (src)) {
468+ Value idx = createI32Constant (builder, loc, i);
469+ res = builder.create <LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
470+ }
471+
472+ if (res.getType () != dstType)
473+ res = builder.create <LLVM::BitcastOp>(loc, dstType, res);
474+
475+ return res;
476+ }
0 commit comments