Commit 6576969
[CIR][X86] Implement lowering for AVX512 mask builtins (kadd, kand, kandn, kor, kxor, kxnor, knot, kmov)
This patch adds CIR codegen support for AVX512 mask operations on X86, covering kadd, kand, kandn, kor, kxor, kxnor, knot, and kmov at every supported mask width (8, 16, 32, and 64 bits). Each builtin now lowers to the expected vector<i1> operations in CIR, with bitcasts between the integer mask type and the vector<i1> mask representation, matching the semantics of the corresponding LLVM intrinsics. The ktest builtins remain unimplemented and now emit an explicit NYI diagnostic. The patch also adds comprehensive CIR/LLVM/OGCG tests for AVX512F, AVX512DQ, and AVX512BW to validate the lowering behavior.
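For reference, a minimal sketch of the kind of source such tests exercise, using the standard mask-intrinsic wrappers from immintrin.h. The wrapper names and the AVX512DQ requirement for _kadd_mask16 come from the intrinsics headers, not from this patch, and mask_demo is a hypothetical example function (compile with, e.g., -mavx512f -mavx512dq):

  #include <immintrin.h>

  // Hypothetical demo: each wrapper expands to one of the builtins
  // lowered by this patch (e.g. _kand_mask16 -> __builtin_ia32_kandhi).
  __mmask16 mask_demo(__mmask16 a, __mmask16 b) {
    __mmask16 r = _kand_mask16(a, b);          // kand
    r = _kor_mask16(r, _knot_mask16(b));       // kor, knot
    r = _kxor_mask16(r, _kandn_mask16(a, b));  // kxor, kandn = (~a) & b
    r = _kxnor_mask16(r, b);                   // kxnor = ~(r ^ b)
    return _kadd_mask16(r, a);                 // kadd (AVX512DQ)
  }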
1 parent 97023fb commit 6576969

File tree

4 files changed: +887 -119 lines


clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 60 additions & 2 deletions
@@ -68,6 +68,7 @@ static mlir::Value emitVectorFCmp(CIRGenBuilderTy &builder,
   return bitCast;
 }
 
+
 static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
                                    mlir::Value mask, unsigned numElems) {
 
@@ -90,6 +91,37 @@ static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
   return maskVec;
 }
 
+static mlir::Value emitX86MaskAddLogic(CIRGenFunction &cgf,
+                                       const CallExpr *expr,
+                                       const std::string &intrinsicName,
+                                       SmallVectorImpl<mlir::Value> &ops) {
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  auto intTy = cast<cir::IntType>(ops[0].getType());
+  unsigned numElts = intTy.getWidth();
+  mlir::Value lhsVec = getMaskVecValue(cgf, expr, ops[0], numElts);
+  mlir::Value rhsVec = getMaskVecValue(cgf, expr, ops[1], numElts);
+  mlir::Type vecTy = lhsVec.getType();
+  mlir::Value resVec = emitIntrinsicCallOp(cgf, expr, intrinsicName, vecTy,
+                                           mlir::ValueRange{lhsVec, rhsVec});
+  return builder.createBitcast(resVec, ops[0].getType());
+}
+
+static mlir::Value emitX86MaskLogic(CIRGenFunction &cgf, const CallExpr *expr,
+                                    cir::BinOpKind binOpKind,
+                                    SmallVectorImpl<mlir::Value> &ops,
+                                    bool invertLHS = false) {
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  unsigned numElts = cast<cir::IntType>(ops[0].getType()).getWidth();
+  mlir::Value lhs = getMaskVecValue(cgf, expr, ops[0], numElts);
+  mlir::Value rhs = getMaskVecValue(cgf, expr, ops[1], numElts);
+
+  if (invertLHS)
+    lhs = builder.createNot(lhs);
+  return builder.createBitcast(
+      builder.createBinop(cgf.getLoc(expr->getExprLoc()), lhs, binOpKind, rhs),
+      ops[0].getType());
+}
+
 mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
                                                const CallExpr *expr) {
   if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -743,38 +775,64 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_ktestzsi:
   case X86::BI__builtin_ia32_ktestcdi:
   case X86::BI__builtin_ia32_ktestzdi:
+    cgm.errorNYI(expr->getSourceRange(),
+                 std::string("unimplemented X86 builtin call: ") +
+                     getContext().BuiltinInfo.getName(builtinID));
+    return {};
   case X86::BI__builtin_ia32_kaddqi:
+    return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.b", ops);
   case X86::BI__builtin_ia32_kaddhi:
+    return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.w", ops);
   case X86::BI__builtin_ia32_kaddsi:
+    return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.d", ops);
   case X86::BI__builtin_ia32_kadddi:
+    return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.q", ops);
   case X86::BI__builtin_ia32_kandqi:
   case X86::BI__builtin_ia32_kandhi:
   case X86::BI__builtin_ia32_kandsi:
   case X86::BI__builtin_ia32_kanddi:
+    return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops);
   case X86::BI__builtin_ia32_kandnqi:
   case X86::BI__builtin_ia32_kandnhi:
   case X86::BI__builtin_ia32_kandnsi:
   case X86::BI__builtin_ia32_kandndi:
+    return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops, true);
   case X86::BI__builtin_ia32_korqi:
   case X86::BI__builtin_ia32_korhi:
   case X86::BI__builtin_ia32_korsi:
   case X86::BI__builtin_ia32_kordi:
+    return emitX86MaskLogic(*this, expr, cir::BinOpKind::Or, ops);
   case X86::BI__builtin_ia32_kxnorqi:
   case X86::BI__builtin_ia32_kxnorhi:
   case X86::BI__builtin_ia32_kxnorsi:
   case X86::BI__builtin_ia32_kxnordi:
+    return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops, true);
   case X86::BI__builtin_ia32_kxorqi:
   case X86::BI__builtin_ia32_kxorhi:
   case X86::BI__builtin_ia32_kxorsi:
   case X86::BI__builtin_ia32_kxordi:
+    return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops);
   case X86::BI__builtin_ia32_knotqi:
   case X86::BI__builtin_ia32_knothi:
   case X86::BI__builtin_ia32_knotsi:
-  case X86::BI__builtin_ia32_knotdi:
+  case X86::BI__builtin_ia32_knotdi: {
+    cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
+    unsigned numElts = intTy.getWidth();
+    mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
+    return builder.createBitcast(builder.createNot(resVec), ops[0].getType());
+  }
   case X86::BI__builtin_ia32_kmovb:
   case X86::BI__builtin_ia32_kmovw:
   case X86::BI__builtin_ia32_kmovd:
-  case X86::BI__builtin_ia32_kmovq:
+  case X86::BI__builtin_ia32_kmovq: {
+    // Bitcast to vXi1 type and then back to integer. This gets the mask
+    // register type into the IR, but might be optimized out depending on
+    // what's around it.
+    cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
+    unsigned numElts = intTy.getWidth();
+    mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
+    return builder.createBitcast(resVec, ops[0].getType());
+  }
   case X86::BI__builtin_ia32_kunpckdi:
   case X86::BI__builtin_ia32_kunpcksi:
   case X86::BI__builtin_ia32_kunpckhi:
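A note on the invertLHS flag in emitX86MaskLogic: kandn and kxnor are the only cases that pass invertLHS = true, and inverting only the first operand before the binop reproduces the instruction semantics, since ~a & b is exactly KANDN, and ~a ^ b equals ~(a ^ b), i.e. KXNOR. Below is a scalar reference model of the 16-bit semantics the lowering must match; the helper names are hypothetical, and the real lowering applies the logical ops element-wise on vector<i1> (kadd is routed through the x86.avx512.kadd.* intrinsics rather than a plain add):

  #include <stdint.h>

  // Hypothetical scalar models of the 16-bit mask operations.
  static uint16_t ref_kandn16(uint16_t a, uint16_t b) { return (uint16_t)(~a & b); }
  static uint16_t ref_kxnor16(uint16_t a, uint16_t b) { return (uint16_t)(~a ^ b); } // == ~(a ^ b)
  static uint16_t ref_knot16(uint16_t a) { return (uint16_t)~a; }
  static uint16_t ref_kadd16(uint16_t a, uint16_t b) { return (uint16_t)(a + b); }   // wraps mod 2^16
  static uint16_t ref_kmov16(uint16_t a) { return a; } // the bitcast round-trip is a value no-op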
