Skip to content

Commit 16293d2

Browse files
committed
[CIR][X86] Implement lowering for AVX512 mask builtins (kadd, kand, kandn, kor, kxor, knot, kmov)
This patch adds CIR codegen support for AVX512 mask operations on X86, including kadd, kand, kandn, kor, kxor, knot, and kmov in all supported mask widths. Each builtin now lowers to the expected vector<i1> form and bitcast representations in CIR, matching the semantics of the corresponding LLVM intrinsics. The patch also adds comprehensive CIR/LLVM/OGCG tests for AVX512F, AVX512DQ, and AVX512BW to validate the lowering behavior.
1 parent 8baa5bf commit 16293d2

File tree

4 files changed

+737
-3
lines changed

4 files changed

+737
-3
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,45 @@ static mlir::Value emitVectorFCmp(CIRGenBuilderTy &builder,
6868
return bitCast;
6969
}
7070

71+
// Convert the mask from an integer type to a vector of i1.
72+
static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
73+
mlir::Value mask, unsigned numElems) {
74+
auto &builder = cgf.getBuilder();
75+
76+
cir::VectorType maskTy =
77+
cir::VectorType::get(cgf.getBuilder().getSIntNTy(1),
78+
cast<cir::IntType>(mask.getType()).getWidth());
79+
mlir::Value maskVec = builder.createBitcast(mask, maskTy);
80+
81+
// If we have less than 8 elements, then the starting mask was an i8 and
82+
// we need to extract down to the right number of elements.
83+
if (numElems < 8) {
84+
SmallVector<mlir::Attribute, 4> indices;
85+
mlir::Type i32Ty = builder.getI32Type();
86+
for (auto i : llvm::seq<unsigned>(0, numElems))
87+
indices.push_back(cir::IntAttr::get(i32Ty, i));
88+
maskVec = builder.createVecShuffle(cgf.getLoc(expr->getExprLoc()), maskVec,
89+
maskVec, indices);
90+
}
91+
return maskVec;
92+
}
93+
94+
// Emit an AVX512 mask logic builtin (kand/kandn/kor/kxor/kxnor) as a CIR
// binary op over vXi1 operands.
//
// Both integer mask operands in ops[0]/ops[1] are converted to vectors of
// i1, the binary op 'opc' is applied, and the result is bitcast back to the
// original integer mask type. When invertLHS is set the left operand is
// bitwise-negated first, which implements kandn (~a & b) and kxnor
// (~a ^ b == ~(a ^ b)).
static mlir::Value emitX86MaskLogic(CIRGenFunction &cgf, const CallExpr *expr,
                                    cir::BinOpKind opc,
                                    SmallVectorImpl<mlir::Value> &ops,
                                    bool invertLHS = false) {
  CIRGenBuilderTy &builder = cgf.getBuilder();
  // The mask width (number of i1 elements) equals the integer bit width.
  unsigned numElts = cast<cir::IntType>(ops[0].getType()).getWidth();
  mlir::Value lhs = getMaskVecValue(cgf, expr, ops[0], numElts);
  mlir::Value rhs = getMaskVecValue(cgf, expr, ops[1], numElts);

  if (invertLHS)
    lhs = builder.createNot(lhs);
  // Perform the logic op on the vXi1 values, then bitcast back to the
  // caller's integer mask type.
  return builder.createBitcast(
      builder.createBinop(cgf.getLoc(expr->getExprLoc()), lhs, opc, rhs),
      ops[0].getType());
}
109+
71110
mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
72111
const CallExpr *expr) {
73112
if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -675,38 +714,86 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
675714
case X86::BI__builtin_ia32_ktestzsi:
676715
case X86::BI__builtin_ia32_ktestcdi:
677716
case X86::BI__builtin_ia32_ktestzdi:
717+
cgm.errorNYI(expr->getSourceRange(),
718+
std::string("unimplemented X86 builtin call: ") +
719+
getContext().BuiltinInfo.getName(builtinID));
720+
return {};
678721
case X86::BI__builtin_ia32_kaddqi:
679722
case X86::BI__builtin_ia32_kaddhi:
680723
case X86::BI__builtin_ia32_kaddsi:
681-
case X86::BI__builtin_ia32_kadddi:
724+
case X86::BI__builtin_ia32_kadddi: {
725+
std::string intrinsicName;
726+
switch (builtinID) {
727+
default:
728+
llvm_unreachable("Unsupported intrinsic!");
729+
case X86::BI__builtin_ia32_kaddqi:
730+
intrinsicName = "x86.avx512.kadd.b";
731+
break;
732+
case X86::BI__builtin_ia32_kaddhi:
733+
intrinsicName = "x86.avx512.kadd.w";
734+
break;
735+
case X86::BI__builtin_ia32_kaddsi:
736+
intrinsicName = "x86.avx512.kadd.d";
737+
break;
738+
case X86::BI__builtin_ia32_kadddi:
739+
intrinsicName = "x86.avx512.kadd.q";
740+
break;
741+
}
742+
auto intTy = cast<cir::IntType>(ops[0].getType());
743+
unsigned numElts = intTy.getWidth();
744+
mlir::Value lhsVec = getMaskVecValue(*this, expr, ops[0], numElts);
745+
mlir::Value rhsVec = getMaskVecValue(*this, expr, ops[1], numElts);
746+
mlir::Type vecTy = lhsVec.getType();
747+
mlir::Value resVec = emitIntrinsicCallOp(*this, expr, intrinsicName, vecTy,
748+
mlir::ValueRange{lhsVec, rhsVec});
749+
return builder.createBitcast(resVec, ops[0].getType());
750+
}
682751
case X86::BI__builtin_ia32_kandqi:
683752
case X86::BI__builtin_ia32_kandhi:
684753
case X86::BI__builtin_ia32_kandsi:
685754
case X86::BI__builtin_ia32_kanddi:
755+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops);
686756
case X86::BI__builtin_ia32_kandnqi:
687757
case X86::BI__builtin_ia32_kandnhi:
688758
case X86::BI__builtin_ia32_kandnsi:
689759
case X86::BI__builtin_ia32_kandndi:
760+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops, true);
690761
case X86::BI__builtin_ia32_korqi:
691762
case X86::BI__builtin_ia32_korhi:
692763
case X86::BI__builtin_ia32_korsi:
693764
case X86::BI__builtin_ia32_kordi:
765+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Or, ops);
694766
case X86::BI__builtin_ia32_kxnorqi:
695767
case X86::BI__builtin_ia32_kxnorhi:
696768
case X86::BI__builtin_ia32_kxnorsi:
697769
case X86::BI__builtin_ia32_kxnordi:
770+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops, true);
698771
case X86::BI__builtin_ia32_kxorqi:
699772
case X86::BI__builtin_ia32_kxorhi:
700773
case X86::BI__builtin_ia32_kxorsi:
701774
case X86::BI__builtin_ia32_kxordi:
775+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops);
702776
case X86::BI__builtin_ia32_knotqi:
703777
case X86::BI__builtin_ia32_knothi:
704778
case X86::BI__builtin_ia32_knotsi:
705-
case X86::BI__builtin_ia32_knotdi:
779+
case X86::BI__builtin_ia32_knotdi: {
780+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
781+
unsigned numElts = intTy.getWidth();
782+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
783+
return builder.createBitcast(builder.createNot(resVec), ops[0].getType());
784+
}
706785
case X86::BI__builtin_ia32_kmovb:
707786
case X86::BI__builtin_ia32_kmovw:
708787
case X86::BI__builtin_ia32_kmovd:
709-
case X86::BI__builtin_ia32_kmovq:
788+
case X86::BI__builtin_ia32_kmovq: {
789+
// Bitcast to vXi1 type and then back to integer. This gets the mask
790+
// register type into the IR, but might be optimized out depending on
791+
// what's around it.
792+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
793+
unsigned numElts = intTy.getWidth();
794+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
795+
return builder.createBitcast(resVec, ops[0].getType());
796+
}
710797
case X86::BI__builtin_ia32_kunpckdi:
711798
case X86::BI__builtin_ia32_kunpcksi:
712799
case X86::BI__builtin_ia32_kunpckhi:

0 commit comments

Comments
 (0)