[CIR] X86 vector masked load builtins #169464
Conversation
@llvm/pr-subscribers-clangir @llvm/pr-subscribers-llvm-ir

Author: woruyu (woruyu)

Changes

Summary
This PR resolves #167752, covering only the masked load part.

Full diff: https://github.com/llvm/llvm-project/pull/169464.diff

6 Files Affected:
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
index 85b38120169fd..e65dcf7531bfe 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h
+++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h
@@ -603,6 +603,34 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
addr.getAlignment().getAsAlign().value());
}
+ /// Create a call to a Masked Load intrinsic.
+ /// \p loc - expression location
+ /// \p ty - vector type to load
+ /// \p ptr - base pointer for the load
+ /// \p alignment - alignment of the source location
+ /// \p mask - vector of booleans which indicates what vector lanes should
+ /// be accessed in memory
+ /// \p passThru - pass-through value that is used to fill the masked-off
+ /// lanes
+ /// of the result
+ mlir::Value createMaskedLoad(mlir::Location loc, mlir::Type ty,
+ mlir::Value ptr, llvm::Align alignment,
+ mlir::Value mask, mlir::Value passThru) {
+
+ assert(mlir::isa<cir::VectorType>(ty) && "Type should be vector");
+ assert(mask && "Mask should not be all-ones (null)");
+
+ if (!passThru)
+ passThru = this->getConstant(loc, cir::PoisonAttr::get(ty));
+
+ mlir::Value ops[] = {ptr, this->getUInt32(int32_t(alignment.value()), loc),
+ mask, passThru};
+
+ return cir::LLVMIntrinsicCallOp::create(
+ *this, loc, getStringAttr("masked.load"), ty, ops)
+ .getResult();
+ }
+
cir::VecShuffleOp
createVecShuffle(mlir::Location loc, mlir::Value vec1, mlir::Value vec2,
llvm::ArrayRef<mlir::Attribute> maskAttrs) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index 978fee7dbec9d..6a73227a7baf7 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -33,6 +33,40 @@ static mlir::Value emitIntrinsicCallOp(CIRGenFunction &cgf, const CallExpr *e,
.getResult();
}
+// Convert the mask from an integer type to a vector of i1.
+static mlir::Value getMaskVecValue(CIRGenFunction &cgf, mlir::Value mask,
+ unsigned numElts, mlir::Location loc) {
+ cir::VectorType maskTy =
+ cir::VectorType::get(cgf.getBuilder().getSIntNTy(1),
+ cast<cir::IntType>(mask.getType()).getWidth());
+
+ mlir::Value maskVec = cgf.getBuilder().createBitcast(mask, maskTy);
+
+ // If we have less than 8 elements, then the starting mask was an i8 and
+ // we need to extract down to the right number of elements.
+ if (numElts < 8) {
+ llvm::SmallVector<int64_t, 4> indices;
+ for (unsigned i = 0; i != numElts; ++i)
+ indices.push_back(i);
+ maskVec = cgf.getBuilder().createVecShuffle(loc, maskVec, maskVec, indices);
+ }
+
+ return maskVec;
+}
+
+static mlir::Value emitX86MaskedLoad(CIRGenFunction &cgf,
+ ArrayRef<mlir::Value> ops,
+ llvm::Align alignment,
+ mlir::Location loc) {
+ mlir::Type ty = ops[1].getType();
+ mlir::Value ptr = ops[0];
+ mlir::Value maskVec =
+ getMaskVecValue(cgf, ops[2], cast<cir::VectorType>(ty).getSize(), loc);
+
+ return cgf.getBuilder().createMaskedLoad(loc, ty, ptr, alignment, maskVec,
+ ops[1]);
+}
+
// OG has unordered comparison as a form of optimization in addition to
// ordered comparison, while CIR doesn't.
//
@@ -327,6 +361,11 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_movdqa64store512_mask:
case X86::BI__builtin_ia32_storeaps512_mask:
case X86::BI__builtin_ia32_storeapd512_mask:
+ cgm.errorNYI(expr->getSourceRange(),
+ std::string("unimplemented X86 builtin call: ") +
+ getContext().BuiltinInfo.getName(builtinID));
+ return {};
+
case X86::BI__builtin_ia32_loadups128_mask:
case X86::BI__builtin_ia32_loadups256_mask:
case X86::BI__builtin_ia32_loadups512_mask:
@@ -345,6 +384,9 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_loaddqudi128_mask:
case X86::BI__builtin_ia32_loaddqudi256_mask:
case X86::BI__builtin_ia32_loaddqudi512_mask:
+ return emitX86MaskedLoad(*this, ops, llvm::Align(1),
+ getLoc(expr->getExprLoc()));
+
case X86::BI__builtin_ia32_loadsbf16128_mask:
case X86::BI__builtin_ia32_loadsh128_mask:
case X86::BI__builtin_ia32_loadss128_mask:
diff --git a/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c b/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c
new file mode 100644
index 0000000000000..2029e3d4b3734
--- /dev/null
+++ b/clang/test/CIR/CodeGen/X86/avx512vl-builtins.c
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
+
+
+#include <immintrin.h>
+
+__m128 test_mm_mask_loadu_ps(__m128 __W, __mmask8 __U, void const *__P) {
+ // CIR-LABEL: _mm_mask_loadu_ps
+ // CIR: {{%.*}} = cir.call_llvm_intrinsic "masked.load" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!cir.ptr<!cir.vector<4 x !cir.float>>, !u32i, !cir.vector<4 x !cir.int<s, 1>>, !cir.vector<4 x !cir.float>) -> !cir.vector<4 x !cir.float>
+
+ // LLVM-LABEL: @test_mm_mask_loadu_ps
+ // LLVM: @llvm.masked.load.v4f32.p0(ptr %{{.*}}, i32 1, <4 x i1> %{{.*}}, <4 x float> %{{.*}})
+ return _mm_mask_loadu_ps(__W, __U, __P);
+}
+
+
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 8f3cc54747074..355a2b85defd4 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2524,9 +2524,10 @@ def int_vp_is_fpclass:
//
def int_masked_load:
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
- [llvm_anyptr_ty,
+ [llvm_anyptr_ty, llvm_i32_ty,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, LLVMMatchType<0>],
- [IntrReadMem, IntrArgMemOnly, NoCapture<ArgIndex<0>>]>;
+ [IntrReadMem, IntrArgMemOnly, ImmArg<ArgIndex<1>>,
+ NoCapture<ArgIndex<0>>]>;
def int_masked_store:
DefaultAttrsIntrinsic<[],
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 59a213b47825a..e2f484fada9ed 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -7151,6 +7151,7 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
switch (IID) {
case Intrinsic::masked_load:
case Intrinsic::masked_gather: {
+
Value *MaskArg = Args[1];
Value *PassthruArg = Args[2];
// If the mask is all zeros or undef, the "passthru" argument is the result.
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 7cc1980d24c33..cd757bcb82bda 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6275,10 +6275,13 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
Check(Call.getType()->isVectorTy(), "masked_load: must return a vector",
Call);
- Value *Mask = Call.getArgOperand(1);
- Value *PassThru = Call.getArgOperand(2);
+ ConstantInt *Alignment = cast<ConstantInt>(Call.getArgOperand(1));
+ Value *Mask = Call.getArgOperand(2);
+ Value *PassThru = Call.getArgOperand(3);
Check(Mask->getType()->isVectorTy(), "masked_load: mask must be vector",
Call);
+ Check(Alignment->getValue().isPowerOf2(),
+ "masked_load: alignment must be a power of 2", Call);
Check(PassThru->getType() == Call.getType(),
"masked_load: pass through and return type must match", Call);
Check(cast<VectorType>(Mask->getType())->getElementCount() ==
You can test this locally with the following command:

git-clang-format --diff origin/main HEAD --extensions cpp,h,c -- clang/test/CIR/CodeGen/X86/avx512vl-builtins.c clang/lib/CIR/CodeGen/CIRGenBuilder.h clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp llvm/lib/Analysis/InstructionSimplify.cpp llvm/lib/IR/Verifier.cpp --diff_from_common_commit
View the diff from clang-format here.

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e2f484fad..de80dfa4b 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -7151,7 +7151,7 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
switch (IID) {
case Intrinsic::masked_load:
case Intrinsic::masked_gather: {
-
+
Value *MaskArg = Args[1];
Value *PassthruArg = Args[2];
// If the mask is all zeros or undef, the "passthru" argument is the result.
Friendly ping @andykaylor. I found that #163802 modified the int_masked_load intrinsic so that it no longer takes the alignment as an explicit argument, which means I need to change the CIR code to emit the intrinsic call without the alignment operand. Is that right? Any suggestions?
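For context, judging from the Intrinsics.td and Verifier hunks above, the call-site difference being discussed is roughly the following sketch. The use of an align parameter attribute on the new form is an assumption about how alignment is now conveyed, not something stated in this PR.

; With the explicit i32 alignment operand (the form this PR's Intrinsics.td hunk restores):
%v = call <4 x float> @llvm.masked.load.v4f32.p0(ptr %p, i32 1, <4 x i1> %m, <4 x float> %pt)

; Without it (current main): only pointer, mask, and pass-through operands;
; alignment presumably comes from an attribute such as `align` on the pointer argument (assumed).
%v = call <4 x float> @llvm.masked.load.v4f32.p0(ptr align 1 %p, <4 x i1> %m, <4 x float> %pt)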
@woruyu Definitely do not revert the LLVM intrinsic back to its previous state. We should probably add masked load and masked store operations to CIR. The LLVM dialect already has these, and they take alignment as an attribute. It's possible that the LLVM dialect lowering will need to be updated. @bcardosolopes Has any work been done towards having masked loads and stores in CIR?
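For reference, a rough sketch of the existing LLVM dialect operation being referred to, with alignment carried as an attribute rather than an operand (exact syntax may vary by MLIR version):

// MLIR LLVM dialect masked load: operands are pointer, mask, and pass-through;
// the alignment is an attribute on the op rather than an operand.
%r = llvm.intr.masked.load %ptr, %mask, %passthru {alignment = 1 : i32}
    : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>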
Summary
This PR addresses #167752, covering only the masked load part.
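For readers unfamiliar with the masked-load builtins, here is a minimal scalar model of the semantics that _mm_mask_loadu_ps (the builtin exercised in the new test) implements. This is illustrative only, not the codegen path, and the helper name is hypothetical.

#include <immintrin.h>

// Scalar model of _mm_mask_loadu_ps: for each lane i, take the memory value
// when mask bit i is set, otherwise keep the pass-through lane from w.
// Note: the real @llvm.masked.load also guarantees that masked-off lanes are
// never read from memory, which this scalar model does not capture.
static __m128 mask_loadu_ps_model(__m128 w, __mmask8 u, const float *p) {
  float src[4], out[4];
  _mm_storeu_ps(src, w);
  for (int i = 0; i < 4; ++i)
    out[i] = ((u >> i) & 1) ? p[i] : src[i];
  return _mm_loadu_ps(out);
}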