Skip to content

Commit b2ff774

Browse files
committed
[msan] Handle AVX512 VCVTPS2PH
This extends maybeExtendVectorShadowWithZeros() from 556c846 (#147377) to handle AVX512 VCVTPS2PH.
1 parent 6127e46 commit b2ff774

File tree

3 files changed

+151
-97
lines changed

3 files changed

+151
-97
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3429,8 +3429,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
34293429
return ShadowType;
34303430
}
34313431

3432-
/// Doubles the length of a vector shadow (filled with zeros) if necessary to
3433-
/// match the length of the shadow for the instruction.
3432+
/// Doubles the length of a vector shadow (extending with zeros) if necessary
3433+
/// to match the length of the shadow for the instruction.
3434+
/// If scalar types of the vectors are different, it will use the type of the
3435+
/// input vector.
34343436
/// This is more type-safe than CreateShadowCast().
34353437
Value *maybeExtendVectorShadowWithZeros(Value *Shadow, IntrinsicInst &I) {
34363438
IRBuilder<> IRB(&I);
@@ -3440,10 +3442,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
34403442
Value *FullShadow = getCleanShadow(&I);
34413443
assert(cast<FixedVectorType>(Shadow->getType())->getNumElements() <=
34423444
cast<FixedVectorType>(FullShadow->getType())->getNumElements());
3443-
assert(cast<FixedVectorType>(Shadow->getType())->getScalarType() ==
3444-
cast<FixedVectorType>(FullShadow->getType())->getScalarType());
34453445

3446-
if (Shadow->getType() == FullShadow->getType()) {
3446+
if (cast<FixedVectorType>(Shadow->getType())->getNumElements() ==
3447+
cast<FixedVectorType>(FullShadow->getType())->getNumElements()) {
34473448
FullShadow = Shadow;
34483449
} else {
34493450
// TODO: generalize beyond 2x?
@@ -4528,55 +4529,93 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
45284529
return isFixedFPVectorTy(V->getType());
45294530
}
45304531

4531-
// e.g., call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
4532-
// (<16 x float> a, <16 x i32> writethru, i16 mask,
4533-
// i32 rounding)
4532+
// e.g., <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
4533+
// (<16 x float> a, <16 x i32> writethru, i16 mask,
4534+
// i32 rounding)
4535+
//
4536+
// Inconveniently, some similar intrinsics have a different operand order:
4537+
// <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512
4538+
// (<16 x float> a, i32 rounding, <16 x i16> writethru,
4539+
// i16 mask)
4540+
//
4541+
// If the return type has more elements than A, the excess elements are
4542+
// zeroed (and the corresponding shadow is initialized).
4543+
// <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.128
4544+
// (<4 x float> a, i32 rounding, <8 x i16> writethru,
4545+
// i8 mask)
45344546
//
45354547
// dst[i] = mask[i] ? convert(a[i]) : writethru[i]
45364548
// dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i]
45374549
// where all_or_nothing(x) is fully uninitialized if x has any
45384550
// uninitialized bits
4539-
void handleAVX512VectorConvertFPToInt(IntrinsicInst &I) {
4551+
void handleAVX512VectorConvertFPToInt(IntrinsicInst &I, bool LastMask) {
45404552
IRBuilder<> IRB(&I);
45414553

45424554
assert(I.arg_size() == 4);
45434555
Value *A = I.getOperand(0);
4544-
Value *WriteThrough = I.getOperand(1);
4545-
Value *Mask = I.getOperand(2);
4546-
Value *RoundingMode = I.getOperand(3);
4556+
Value *WriteThrough;
4557+
Value *Mask;
4558+
Value *RoundingMode;
4559+
if (LastMask) {
4560+
WriteThrough = I.getOperand(2);
4561+
Mask = I.getOperand(3);
4562+
RoundingMode = I.getOperand(1);
4563+
} else {
4564+
WriteThrough = I.getOperand(1);
4565+
Mask = I.getOperand(2);
4566+
RoundingMode = I.getOperand(3);
4567+
}
45474568

45484569
assert(isFixedFPVector(A));
45494570
assert(isFixedIntVector(WriteThrough));
45504571

45514572
unsigned ANumElements =
45524573
cast<FixedVectorType>(A->getType())->getNumElements();
4553-
assert(ANumElements ==
4554-
cast<FixedVectorType>(WriteThrough->getType())->getNumElements());
4574+
unsigned WriteThruNumElements =
4575+
cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
4576+
assert(ANumElements == WriteThruNumElements ||
4577+
ANumElements * 2 == WriteThruNumElements);
45554578

45564579
assert(Mask->getType()->isIntegerTy());
4557-
assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
4580+
unsigned MaskNumElements = Mask->getType()->getScalarSizeInBits();
4581+
assert(ANumElements == MaskNumElements ||
4582+
ANumElements * 2 == MaskNumElements);
4583+
4584+
assert(WriteThruNumElements == MaskNumElements);
4585+
45584586
insertCheckShadowOf(Mask, &I);
45594587

45604588
assert(RoundingMode->getType()->isIntegerTy());
4561-
// Only four bits of the rounding mode are used, though it's very
4589+
// Only some bits of the rounding mode are used, though it's very
45624590
// unusual to have uninitialized bits there (more commonly, it's a
45634591
// constant).
45644592
insertCheckShadowOf(RoundingMode, &I);
45654593

45664594
assert(I.getType() == WriteThrough->getType());
45674595

4596+
Value *AShadow = getShadow(A);
4597+
AShadow = maybeExtendVectorShadowWithZeros(AShadow, I);
4598+
4599+
if (ANumElements * 2 == MaskNumElements) {
4600+
// Ensure that the irrelevant bits of the mask are zero, hence selecting
4601+
// from the zeroed shadow instead of the writethrough's shadow.
4602+
Mask = IRB.CreateTrunc(Mask, IRB.getIntNTy(ANumElements));
4603+
Mask = IRB.CreateZExt(Mask, IRB.getIntNTy(MaskNumElements));
4604+
}
4605+
45684606
// Convert i16 mask to <16 x i1>
45694607
Mask = IRB.CreateBitCast(
4570-
Mask, FixedVectorType::get(IRB.getInt1Ty(), ANumElements));
4608+
Mask, FixedVectorType::get(IRB.getInt1Ty(), MaskNumElements));
45714609

4572-
Value *AShadow = getShadow(A);
4573-
/// For scalars:
4574-
/// Since they are converting from floating-point, the output is:
4610+
/// For floating-point to integer conversion, the output is:
45754611
/// - fully uninitialized if *any* bit of the input is uninitialized
45764612
/// - fully initialized if all bits of the input are initialized
45774613
/// We apply the same principle on a per-element basis for vectors.
4578-
AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(A)),
4579-
getShadowTy(A));
4614+
///
4615+
/// We use the scalar width of the return type instead of A's.
4616+
AShadow = IRB.CreateSExt(
4617+
IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow->getType())),
4618+
getShadowTy(&I));
45804619

45814620
Value *WriteThroughShadow = getShadow(WriteThrough);
45824621
Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
@@ -5920,11 +5959,29 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
59205959
/*trailingVerbatimArgs=*/1);
59215960
break;
59225961

5962+
// Convert Packed Single Precision Floating-Point Values
5963+
// to Packed Signed Doubleword Integer Values
5964+
//
5965+
// <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
5966+
// (<16 x float>, <16 x i32>, i16, i32)
59235967
case Intrinsic::x86_avx512_mask_cvtps2dq_512: {
5924-
handleAVX512VectorConvertFPToInt(I);
5968+
handleAVX512VectorConvertFPToInt(I, /*LastMask=*/false);
59255969
break;
59265970
}
59275971

5972+
// Convert Single-Precision FP Value to 16-bit FP Value
5973+
// <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512
5974+
// (<16 x float>, i32, <16 x i16>, i16)
5975+
// <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.128
5976+
// (<4 x float>, i32, <8 x i16>, i8)
5977+
// <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.256
5978+
// (<8 x float>, i32, <8 x i16>, i8)
5979+
case Intrinsic::x86_avx512_mask_vcvtps2ph_512:
5980+
case Intrinsic::x86_avx512_mask_vcvtps2ph_256:
5981+
case Intrinsic::x86_avx512_mask_vcvtps2ph_128:
5982+
handleAVX512VectorConvertFPToInt(I, /*LastMask=*/true);
5983+
break;
5984+
59285985
// AVX512 PMOV: Packed MOV, with truncation
59295986
// Precisely handled by applying the same intrinsic to the shadow
59305987
case Intrinsic::x86_avx512_mask_pmov_dw_512:

llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,50 +1903,46 @@ define <16 x i16> @test_x86_vcvtps2ph_256(<16 x float> %a0, <16 x i16> %src, i16
19031903
; CHECK-NEXT: [[TMP3:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
19041904
; CHECK-NEXT: [[TMP4:%.*]] = load i64, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 104) to ptr), align 8
19051905
; CHECK-NEXT: call void @llvm.donothing()
1906-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
1907-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP5]], 0
1908-
; CHECK-NEXT: br i1 [[_MSCMP]], label [[TMP6:%.*]], label [[TMP7:%.*]], !prof [[PROF1]]
1909-
; CHECK: 6:
1910-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
1911-
; CHECK-NEXT: unreachable
1912-
; CHECK: 7:
1906+
; CHECK-NEXT: [[TMP6:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
1907+
; CHECK-NEXT: [[TMP7:%.*]] = sext <16 x i1> [[TMP6]] to <16 x i16>
1908+
; CHECK-NEXT: [[TMP8:%.*]] = select <16 x i1> splat (i1 true), <16 x i16> [[TMP7]], <16 x i16> zeroinitializer
19131909
; CHECK-NEXT: [[RES1:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0:%.*]], i32 2, <16 x i16> zeroinitializer, i16 -1)
1914-
; CHECK-NEXT: [[TMP8:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
1915-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP8]], 0
1910+
; CHECK-NEXT: [[TMP10:%.*]] = bitcast i16 [[MASK:%.*]] to <16 x i1>
1911+
; CHECK-NEXT: [[TMP11:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
1912+
; CHECK-NEXT: [[TMP12:%.*]] = sext <16 x i1> [[TMP11]] to <16 x i16>
1913+
; CHECK-NEXT: [[TMP13:%.*]] = select <16 x i1> [[TMP10]], <16 x i16> [[TMP12]], <16 x i16> zeroinitializer
19161914
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i16 [[TMP2]], 0
1917-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP1]], [[_MSCMP2]]
1918-
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP9:%.*]], label [[TMP10:%.*]], !prof [[PROF1]]
1919-
; CHECK: 9:
1915+
; CHECK-NEXT: br i1 [[_MSCMP2]], label [[TMP14:%.*]], label [[TMP15:%.*]], !prof [[PROF1]]
1916+
; CHECK: 12:
19201917
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
19211918
; CHECK-NEXT: unreachable
1922-
; CHECK: 10:
1923-
; CHECK-NEXT: [[RES2:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 11, <16 x i16> zeroinitializer, i16 [[MASK:%.*]])
1924-
; CHECK-NEXT: [[TMP11:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
1925-
; CHECK-NEXT: [[_MSCMP3:%.*]] = icmp ne i512 [[TMP11]], 0
1926-
; CHECK-NEXT: [[TMP12:%.*]] = bitcast <16 x i16> [[TMP3]] to i256
1927-
; CHECK-NEXT: [[_MSCMP4:%.*]] = icmp ne i256 [[TMP12]], 0
1928-
; CHECK-NEXT: [[_MSOR5:%.*]] = or i1 [[_MSCMP3]], [[_MSCMP4]]
1929-
; CHECK-NEXT: [[_MSCMP6:%.*]] = icmp ne i16 [[TMP2]], 0
1930-
; CHECK-NEXT: [[_MSOR7:%.*]] = or i1 [[_MSOR5]], [[_MSCMP6]]
1931-
; CHECK-NEXT: br i1 [[_MSOR7]], label [[TMP13:%.*]], label [[TMP14:%.*]], !prof [[PROF1]]
19321919
; CHECK: 13:
1920+
; CHECK-NEXT: [[RES2:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 11, <16 x i16> zeroinitializer, i16 [[MASK]])
1921+
; CHECK-NEXT: [[TMP25:%.*]] = bitcast i16 [[MASK]] to <16 x i1>
1922+
; CHECK-NEXT: [[TMP26:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
1923+
; CHECK-NEXT: [[TMP27:%.*]] = sext <16 x i1> [[TMP26]] to <16 x i16>
1924+
; CHECK-NEXT: [[TMP20:%.*]] = select <16 x i1> [[TMP25]], <16 x i16> [[TMP27]], <16 x i16> [[TMP3]]
1925+
; CHECK-NEXT: [[_MSCMP6:%.*]] = icmp ne i16 [[TMP2]], 0
1926+
; CHECK-NEXT: br i1 [[_MSCMP6]], label [[TMP22:%.*]], label [[TMP23:%.*]], !prof [[PROF1]]
1927+
; CHECK: 18:
19331928
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
19341929
; CHECK-NEXT: unreachable
1935-
; CHECK: 14:
1930+
; CHECK: 19:
19361931
; CHECK-NEXT: [[RES3:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 12, <16 x i16> [[SRC:%.*]], i16 [[MASK]])
19371932
; CHECK-NEXT: [[_MSCMP8:%.*]] = icmp ne i64 [[TMP4]], 0
1938-
; CHECK-NEXT: br i1 [[_MSCMP8]], label [[TMP15:%.*]], label [[TMP16:%.*]], !prof [[PROF1]]
1939-
; CHECK: 15:
1933+
; CHECK-NEXT: br i1 [[_MSCMP8]], label [[TMP24:%.*]], label [[TMP21:%.*]], !prof [[PROF1]]
1934+
; CHECK: 20:
19401935
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]]
19411936
; CHECK-NEXT: unreachable
1942-
; CHECK: 16:
1937+
; CHECK: 21:
19431938
; CHECK-NEXT: [[TMP17:%.*]] = ptrtoint ptr [[DST:%.*]] to i64
19441939
; CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 87960930222080
19451940
; CHECK-NEXT: [[TMP19:%.*]] = inttoptr i64 [[TMP18]] to ptr
1946-
; CHECK-NEXT: store <16 x i16> zeroinitializer, ptr [[TMP19]], align 32
1941+
; CHECK-NEXT: store <16 x i16> [[TMP8]], ptr [[TMP19]], align 32
19471942
; CHECK-NEXT: store <16 x i16> [[RES1]], ptr [[DST]], align 32
1943+
; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i16> [[TMP13]], [[TMP20]]
19481944
; CHECK-NEXT: [[RES:%.*]] = add <16 x i16> [[RES2]], [[RES3]]
1949-
; CHECK-NEXT: store <16 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
1945+
; CHECK-NEXT: store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
19501946
; CHECK-NEXT: ret <16 x i16> [[RES]]
19511947
;
19521948
%res1 = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> %a0, i32 2, <16 x i16> zeroinitializer, i16 -1)

0 commit comments

Comments
 (0)