Skip to content

Commit f186964

Browse files
committed
[CIR][X86] Implement lowering for AVX512 mask builtins (kadd, kand, kandn, kor, kxor, kxnor, knot, kmov)
This patch adds CIR codegen support for AVX512 mask operations on X86, including kadd, kand, kandn, kor, kxor, kxnor, knot, and kmov in all supported mask widths. Each builtin now lowers to the expected vector<i1> form and bitcast representations in CIR, matching the semantics of the corresponding LLVM intrinsics. The patch also adds comprehensive CIR/LLVM/OGCG tests for AVX512F, AVX512DQ, and AVX512BW to validate the lowering behavior.
1 parent 97023fb commit f186964

File tree

4 files changed

+769
-2
lines changed

4 files changed

+769
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,37 @@ static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
9090
return maskVec;
9191
}
9292

93+
// Lower an AVX512 kadd.* mask builtin. Both scalar integer mask operands are
// widened to vector<Ni x i1> (one element per mask bit), the target intrinsic
// is invoked on the vectors, and the vector result is bitcast back to the
// caller's original scalar mask type.
static mlir::Value emitX86MaskAddLogic(CIRGenFunction &cgf,
                                       const CallExpr *expr,
                                       const std::string &intrinsicName,
                                       SmallVectorImpl<mlir::Value> &ops) {
  CIRGenBuilderTy &builder = cgf.getBuilder();
  mlir::Type maskTy = ops[0].getType();
  // The mask width (8/16/32/64) determines the i1-vector element count.
  unsigned maskWidth = cast<cir::IntType>(maskTy).getWidth();
  mlir::Value lhs = getMaskVecValue(cgf, expr, ops[0], maskWidth);
  mlir::Value rhs = getMaskVecValue(cgf, expr, ops[1], maskWidth);
  mlir::Value result = emitIntrinsicCallOp(
      cgf, expr, intrinsicName, lhs.getType(), mlir::ValueRange{lhs, rhs});
  return builder.createBitcast(result, maskTy);
}
107+
108+
// Lower an AVX512 bitwise mask builtin (kand/kandn/kor/kxor/kxnor). The two
// scalar mask operands are widened to vector<Ni x i1>, combined with the
// requested binary op — optionally inverting the LHS first, which yields the
// kandn/kxnor variants — and the result is bitcast back to the scalar mask
// type.
static mlir::Value emitX86MaskLogic(CIRGenFunction &cgf, const CallExpr *expr,
                                    cir::BinOpKind binOpKind,
                                    SmallVectorImpl<mlir::Value> &ops,
                                    bool invertLHS = false) {
  CIRGenBuilderTy &builder = cgf.getBuilder();
  mlir::Type maskTy = ops[0].getType();
  unsigned maskWidth = cast<cir::IntType>(maskTy).getWidth();
  mlir::Value lhs = getMaskVecValue(cgf, expr, ops[0], maskWidth);
  mlir::Value rhs = getMaskVecValue(cgf, expr, ops[1], maskWidth);

  if (invertLHS)
    lhs = builder.createNot(lhs);
  mlir::Value combined =
      builder.createBinop(cgf.getLoc(expr->getExprLoc()), lhs, binOpKind, rhs);
  return builder.createBitcast(combined, maskTy);
}
123+
93124
mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
94125
const CallExpr *expr) {
95126
if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -743,38 +774,64 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
743774
case X86::BI__builtin_ia32_ktestzsi:
744775
case X86::BI__builtin_ia32_ktestcdi:
745776
case X86::BI__builtin_ia32_ktestzdi:
777+
cgm.errorNYI(expr->getSourceRange(),
778+
std::string("unimplemented X86 builtin call: ") +
779+
getContext().BuiltinInfo.getName(builtinID));
780+
return {};
746781
case X86::BI__builtin_ia32_kaddqi:
782+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.b", ops);
747783
case X86::BI__builtin_ia32_kaddhi:
784+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.w", ops);
748785
case X86::BI__builtin_ia32_kaddsi:
786+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.d", ops);
749787
case X86::BI__builtin_ia32_kadddi:
788+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.q", ops);
750789
case X86::BI__builtin_ia32_kandqi:
751790
case X86::BI__builtin_ia32_kandhi:
752791
case X86::BI__builtin_ia32_kandsi:
753792
case X86::BI__builtin_ia32_kanddi:
793+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops);
754794
case X86::BI__builtin_ia32_kandnqi:
755795
case X86::BI__builtin_ia32_kandnhi:
756796
case X86::BI__builtin_ia32_kandnsi:
757797
case X86::BI__builtin_ia32_kandndi:
798+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops, true);
758799
case X86::BI__builtin_ia32_korqi:
759800
case X86::BI__builtin_ia32_korhi:
760801
case X86::BI__builtin_ia32_korsi:
761802
case X86::BI__builtin_ia32_kordi:
803+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Or, ops);
762804
case X86::BI__builtin_ia32_kxnorqi:
763805
case X86::BI__builtin_ia32_kxnorhi:
764806
case X86::BI__builtin_ia32_kxnorsi:
765807
case X86::BI__builtin_ia32_kxnordi:
808+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops, true);
766809
case X86::BI__builtin_ia32_kxorqi:
767810
case X86::BI__builtin_ia32_kxorhi:
768811
case X86::BI__builtin_ia32_kxorsi:
769812
case X86::BI__builtin_ia32_kxordi:
813+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops);
770814
case X86::BI__builtin_ia32_knotqi:
771815
case X86::BI__builtin_ia32_knothi:
772816
case X86::BI__builtin_ia32_knotsi:
773-
case X86::BI__builtin_ia32_knotdi:
817+
case X86::BI__builtin_ia32_knotdi: {
818+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
819+
unsigned numElts = intTy.getWidth();
820+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
821+
return builder.createBitcast(builder.createNot(resVec), ops[0].getType());
822+
}
774823
case X86::BI__builtin_ia32_kmovb:
775824
case X86::BI__builtin_ia32_kmovw:
776825
case X86::BI__builtin_ia32_kmovd:
777-
case X86::BI__builtin_ia32_kmovq:
826+
case X86::BI__builtin_ia32_kmovq: {
827+
// Bitcast to vXi1 type and then back to integer. This gets the mask
828+
// register type into the IR, but might be optimized out depending on
829+
// what's around it.
830+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
831+
unsigned numElts = intTy.getWidth();
832+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
833+
return builder.createBitcast(resVec, ops[0].getType());
834+
}
778835
case X86::BI__builtin_ia32_kunpckdi:
779836
case X86::BI__builtin_ia32_kunpcksi:
780837
case X86::BI__builtin_ia32_kunpckhi:

0 commit comments

Comments
 (0)