Skip to content

Commit 475e0ee

Browse files
authored
[msan][NFCI] Generalize handleAVX512VectorGenericMaskedFP() operands (#159966)
This generalizes handleAVX512VectorGenericMaskedFP() (introduced in #158397), to potentially handle intrinsics that have A/WriteThru/Mask in an operand order that is different to AVX512/AVX10 rcp and rsqrt. Any operands other than A and WriteThru must be fully initialized. For example, the generalized handler could be applied in follow-up work to many of the AVX512 rndscale intrinsics: ``` <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512(<32 x half>, i32, <32 x half>, i32, i32) <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>, i32, <16 x float>, i16, i32) <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>, i32, <8 x double>, i8, i32) A Imm WriteThru Mask Rounding <8 x float> @llvm.x86.avx512.mask.rndscale.ps.256(<8 x float>, i32, <8 x float>, i8) <4 x float> @llvm.x86.avx512.mask.rndscale.ps.128(<4 x float>, i32, <4 x float>, i8) <4 x double> @llvm.x86.avx512.mask.rndscale.pd.256(<4 x double>, i32, <4 x double>, i8) <2 x double> @llvm.x86.avx512.mask.rndscale.pd.128(<2 x double>, i32, <2 x double>, i8) A Imm WriteThru Mask ```
1 parent 370db9c commit 475e0ee

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4926,36 +4926,56 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
49264926
// <2 x double> @llvm.x86.avx512.rcp14.pd.128
49274927
// (<2 x double>, <2 x double>, i8)
49284928
//
4929+
// <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512
4930+
// (<8 x double>, i32, <8 x double>, i8, i32)
4931+
// A Imm WriteThru Mask Rounding
4932+
//
4933+
// All operands other than A and WriteThru (e.g., Mask, Imm, Rounding) must
4934+
// be fully initialized.
4935+
//
49294936
// Dst[i] = Mask[i] ? some_op(A[i]) : WriteThru[i]
49304937
// Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
4931-
void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
4938+
void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I, unsigned AIndex,
4939+
unsigned WriteThruIndex,
4940+
unsigned MaskIndex) {
49324941
IRBuilder<> IRB(&I);
49334942

4934-
assert(I.arg_size() == 3);
4935-
Value *A = I.getOperand(0);
4936-
Value *WriteThrough = I.getOperand(1);
4937-
Value *Mask = I.getOperand(2);
4943+
unsigned NumArgs = I.arg_size();
4944+
assert(AIndex < NumArgs);
4945+
assert(WriteThruIndex < NumArgs);
4946+
assert(MaskIndex < NumArgs);
4947+
assert(AIndex != WriteThruIndex);
4948+
assert(AIndex != MaskIndex);
4949+
assert(WriteThruIndex != MaskIndex);
4950+
4951+
Value *A = I.getOperand(AIndex);
4952+
Value *WriteThru = I.getOperand(WriteThruIndex);
4953+
Value *Mask = I.getOperand(MaskIndex);
49384954

49394955
assert(isFixedFPVector(A));
4940-
assert(isFixedFPVector(WriteThrough));
4956+
assert(isFixedFPVector(WriteThru));
49414957

49424958
[[maybe_unused]] unsigned ANumElements =
49434959
cast<FixedVectorType>(A->getType())->getNumElements();
49444960
unsigned OutputNumElements =
4945-
cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
4961+
cast<FixedVectorType>(WriteThru->getType())->getNumElements();
49464962
assert(ANumElements == OutputNumElements);
49474963

4948-
assert(Mask->getType()->isIntegerTy());
4949-
// Some bits of the mask might be unused, but check them all anyway
4950-
// (typically the mask is an integer constant).
4951-
insertCheckShadowOf(Mask, &I);
4964+
for (unsigned i = 0; i < NumArgs; ++i) {
4965+
if (i != AIndex && i != WriteThruIndex) {
4966+
// Imm, Mask, Rounding etc. are "control" data, hence we require that
4967+
// they be fully initialized.
4968+
assert(I.getOperand(i)->getType()->isIntegerTy());
4969+
insertCheckShadowOf(I.getOperand(i), &I);
4970+
}
4971+
}
49524972

49534973
// The mask has 1 bit per element of A, but a minimum of 8 bits.
49544974
if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
49554975
Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
49564976
assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
49574977

4958-
assert(I.getType() == WriteThrough->getType());
4978+
assert(I.getType() == WriteThru->getType());
49594979

49604980
Mask = IRB.CreateBitCast(
49614981
Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
@@ -4966,9 +4986,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
49664986
AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
49674987
AShadow->getType());
49684988

4969-
Value *WriteThroughShadow = getShadow(WriteThrough);
4989+
Value *WriteThruShadow = getShadow(WriteThru);
49704990

4971-
Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
4991+
Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThruShadow);
49724992
setShadow(&I, Shadow);
49734993

49744994
setOriginForNaryOp(I);
@@ -6202,7 +6222,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
62026222
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
62036223
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
62046224
case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
6205-
handleAVX512VectorGenericMaskedFP(I);
6225+
handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1,
6226+
/*MaskIndex=*/2);
62066227
break;
62076228

62086229
// AVX512/AVX10 Reciprocal Square Root
@@ -6253,7 +6274,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
62536274
case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
62546275
case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
62556276
case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
6256-
handleAVX512VectorGenericMaskedFP(I);
6277+
handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/1,
6278+
/*MaskIndex=*/2);
62576279
break;
62586280

62596281
// AVX512 FP16 Arithmetic

0 commit comments

Comments
 (0)