Skip to content

Commit cd9eb6c

Browse files
committed
[CIR][X86] Implement lowering for AVX512 mask builtins (kadd, kand, kandn, kor, kxor, knot, kmov)
This patch adds CIR codegen support for AVX512 mask operations on X86, including kadd, kand, kandn, kor, kxor, knot, and kmov for all supported mask widths. Each builtin now lowers to the expected vector<i1> form and bitcast representation in CIR, matching the semantics of the corresponding LLVM intrinsics. The patch also adds comprehensive CIR/LLVM/OGCG tests for AVX512F, AVX512DQ, and AVX512BW to validate the lowering behavior.
1 parent 8baa5bf commit cd9eb6c

File tree

4 files changed

+807
-2
lines changed

4 files changed

+807
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,60 @@ static mlir::Value emitVectorFCmp(CIRGenBuilderTy &builder,
6868
return bitCast;
6969
}
7070

71+
// Convert the mask from an integer type to a vector of i1.
72+
static mlir::Value getMaskVecValue(CIRGenFunction &cgf, const CallExpr *expr,
73+
mlir::Value mask, unsigned numElems) {
74+
auto &builder = cgf.getBuilder();
75+
76+
cir::VectorType maskTy =
77+
cir::VectorType::get(cgf.getBuilder().getSIntNTy(1),
78+
cast<cir::IntType>(mask.getType()).getWidth());
79+
mlir::Value maskVec = builder.createBitcast(mask, maskTy);
80+
81+
// If we have less than 8 elements, then the starting mask was an i8 and
82+
// we need to extract down to the right number of elements.
83+
if (numElems < 8) {
84+
SmallVector<mlir::Attribute, 4> indices;
85+
mlir::Type i32Ty = builder.getI32Type();
86+
for (auto i : llvm::seq<unsigned>(0, numElems))
87+
indices.push_back(cir::IntAttr::get(i32Ty, i));
88+
maskVec = builder.createVecShuffle(cgf.getLoc(expr->getExprLoc()), maskVec,
89+
maskVec, indices);
90+
}
91+
return maskVec;
92+
}
93+
94+
static mlir::Value emitX86MaskAddLogic(CIRGenFunction &cgf,
95+
const CallExpr *expr,
96+
const std::string &intrinsicName,
97+
SmallVectorImpl<mlir::Value> &ops) {
98+
CIRGenBuilderTy &builder = cgf.getBuilder();
99+
auto intTy = cast<cir::IntType>(ops[0].getType());
100+
unsigned numElts = intTy.getWidth();
101+
mlir::Value lhsVec = getMaskVecValue(cgf, expr, ops[0], numElts);
102+
mlir::Value rhsVec = getMaskVecValue(cgf, expr, ops[1], numElts);
103+
mlir::Type vecTy = lhsVec.getType();
104+
mlir::Value resVec = emitIntrinsicCallOp(cgf, expr, intrinsicName, vecTy,
105+
mlir::ValueRange{lhsVec, rhsVec});
106+
return builder.createBitcast(resVec, ops[0].getType());
107+
}
108+
109+
static mlir::Value emitX86MaskLogic(CIRGenFunction &cgf, const CallExpr *expr,
110+
cir::BinOpKind binOpKind,
111+
SmallVectorImpl<mlir::Value> &ops,
112+
bool invertLHS = false) {
113+
CIRGenBuilderTy &builder = cgf.getBuilder();
114+
unsigned numElts = cast<cir::IntType>(ops[0].getType()).getWidth();
115+
mlir::Value lhs = getMaskVecValue(cgf, expr, ops[0], numElts);
116+
mlir::Value rhs = getMaskVecValue(cgf, expr, ops[1], numElts);
117+
118+
if (invertLHS)
119+
lhs = builder.createNot(lhs);
120+
return builder.createBitcast(
121+
builder.createBinop(cgf.getLoc(expr->getExprLoc()), lhs, binOpKind, rhs),
122+
ops[0].getType());
123+
}
124+
71125
mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
72126
const CallExpr *expr) {
73127
if (builtinID == Builtin::BI__builtin_cpu_is) {
@@ -675,38 +729,64 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
675729
case X86::BI__builtin_ia32_ktestzsi:
676730
case X86::BI__builtin_ia32_ktestcdi:
677731
case X86::BI__builtin_ia32_ktestzdi:
732+
cgm.errorNYI(expr->getSourceRange(),
733+
std::string("unimplemented X86 builtin call: ") +
734+
getContext().BuiltinInfo.getName(builtinID));
735+
return {};
678736
case X86::BI__builtin_ia32_kaddqi:
737+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.b", ops);
679738
case X86::BI__builtin_ia32_kaddhi:
739+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.w", ops);
680740
case X86::BI__builtin_ia32_kaddsi:
741+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.d", ops);
681742
case X86::BI__builtin_ia32_kadddi:
743+
return emitX86MaskAddLogic(*this, expr, "x86.avx512.kadd.q", ops);
682744
case X86::BI__builtin_ia32_kandqi:
683745
case X86::BI__builtin_ia32_kandhi:
684746
case X86::BI__builtin_ia32_kandsi:
685747
case X86::BI__builtin_ia32_kanddi:
748+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops);
686749
case X86::BI__builtin_ia32_kandnqi:
687750
case X86::BI__builtin_ia32_kandnhi:
688751
case X86::BI__builtin_ia32_kandnsi:
689752
case X86::BI__builtin_ia32_kandndi:
753+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::And, ops, true);
690754
case X86::BI__builtin_ia32_korqi:
691755
case X86::BI__builtin_ia32_korhi:
692756
case X86::BI__builtin_ia32_korsi:
693757
case X86::BI__builtin_ia32_kordi:
758+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Or, ops);
694759
case X86::BI__builtin_ia32_kxnorqi:
695760
case X86::BI__builtin_ia32_kxnorhi:
696761
case X86::BI__builtin_ia32_kxnorsi:
697762
case X86::BI__builtin_ia32_kxnordi:
763+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops, true);
698764
case X86::BI__builtin_ia32_kxorqi:
699765
case X86::BI__builtin_ia32_kxorhi:
700766
case X86::BI__builtin_ia32_kxorsi:
701767
case X86::BI__builtin_ia32_kxordi:
768+
return emitX86MaskLogic(*this, expr, cir::BinOpKind::Xor, ops);
702769
case X86::BI__builtin_ia32_knotqi:
703770
case X86::BI__builtin_ia32_knothi:
704771
case X86::BI__builtin_ia32_knotsi:
705-
case X86::BI__builtin_ia32_knotdi:
772+
case X86::BI__builtin_ia32_knotdi: {
773+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
774+
unsigned numElts = intTy.getWidth();
775+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
776+
return builder.createBitcast(builder.createNot(resVec), ops[0].getType());
777+
}
706778
case X86::BI__builtin_ia32_kmovb:
707779
case X86::BI__builtin_ia32_kmovw:
708780
case X86::BI__builtin_ia32_kmovd:
709-
case X86::BI__builtin_ia32_kmovq:
781+
case X86::BI__builtin_ia32_kmovq: {
782+
// Bitcast to vXi1 type and then back to integer. This gets the mask
783+
// register type into the IR, but might be optimized out depending on
784+
// what's around it.
785+
cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
786+
unsigned numElts = intTy.getWidth();
787+
mlir::Value resVec = getMaskVecValue(*this, expr, ops[0], numElts);
788+
return builder.createBitcast(resVec, ops[0].getType());
789+
}
710790
case X86::BI__builtin_ia32_kunpckdi:
711791
case X86::BI__builtin_ia32_kunpcksi:
712792
case X86::BI__builtin_ia32_kunpckhi:

0 commit comments

Comments
 (0)