diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index ec94dcaa2c051..ebc2163f491ec 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -4342,6 +4342,61 @@ struct MemorySanitizerVisitor : public InstVisitor { setOriginForNaryOp(I); } + // e.g., call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512 + // (<16 x float> a, <16 x i32> writethru, i16 mask, + // i32 rounding) + // + // dst[i] = mask[i] ? convert(a[i]) : writethru[i] + // dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i] + // where all_or_nothing(x) is fully uninitialized if x has any + // uninitialized bits + void handleAVX512VectorConvertFPToInt(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + + assert(I.arg_size() == 4); + Value *A = I.getOperand(0); + Value *WriteThrough = I.getOperand(1); + Value *Mask = I.getOperand(2); + [[maybe_unused]] Value *RoundingMode = I.getOperand(3); + + assert(isa(A->getType())); + assert(A->getType()->isFPOrFPVectorTy()); + + assert(isa(WriteThrough->getType())); + assert(WriteThrough->getType()->isIntOrIntVectorTy()); + + unsigned ANumElements = + cast(A->getType())->getNumElements(); + assert(ANumElements == + cast(WriteThrough->getType())->getNumElements()); + + assert(Mask->getType()->isIntegerTy()); + assert(Mask->getType()->getScalarSizeInBits() == ANumElements); + + assert(RoundingMode->getType()->isIntegerTy()); + + assert(I.getType() == WriteThrough->getType()); + + // Convert i16 mask to <16 x i1> + Mask = IRB.CreateBitCast( + Mask, FixedVectorType::get(IRB.getInt1Ty(), ANumElements)); + + Value *AShadow = getShadow(A); + /// For scalars: + /// Since they are converting from floating-point, the output is: + /// - fully uninitialized if *any* bit of the input is uninitialized + /// - fully ininitialized if all bits of the input are ininitialized + /// We apply the same principle on a per-element basis for vectors. + AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(A)), + getShadowTy(A)); + + Value *WriteThroughShadow = getShadow(WriteThrough); + Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow); + + setShadow(&I, Shadow); + setOriginForNaryOp(I); + } + // Instrument BMI / BMI2 intrinsics. // All of these intrinsics are Z = I(X, Y) // where the types of all operands and the result match, and are either i32 or @@ -5318,6 +5373,11 @@ struct MemorySanitizerVisitor : public InstVisitor { handleAVXVpermi2var(I); break; + case Intrinsic::x86_avx512_mask_cvtps2dq_512: { + handleAVX512VectorConvertFPToInt(I); + break; + } + case Intrinsic::x86_avx512fp16_mask_add_sh_round: case Intrinsic::x86_avx512fp16_mask_sub_sh_round: case Intrinsic::x86_avx512fp16_mask_mul_sh_round: diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll index 43595dccc35bc..10144509d96a6 100644 --- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll +++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll @@ -8152,34 +8152,19 @@ define <16 x i32>@test_int_x86_avx512_mask_cvt_ps2dq_512(<16 x float> %x0, <16 x ; CHECK-LABEL: @test_int_x86_avx512_mask_cvt_ps2dq_512( ; CHECK-NEXT: [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8 ; CHECK-NEXT: [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8 -; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8 ; CHECK-NEXT: call void @llvm.donothing() -; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP1]] to i512 -; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <16 x i32> [[TMP2]] to i512 -; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0 -; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]] -; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i16 [[TMP3]], 0 -; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]] -; CHECK-NEXT: br i1 [[_MSOR3]], label [[TMP6:%.*]], label [[TMP7:%.*]], !prof [[PROF1]] -; CHECK: 6: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]] -; CHECK-NEXT: unreachable -; CHECK: 7: -; CHECK-NEXT: [[RES:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0:%.*]], <16 x i32> [[X1:%.*]], i16 [[X2:%.*]], i32 10) -; CHECK-NEXT: [[TMP8:%.*]] = bitcast <16 x i32> [[TMP1]] to i512 -; CHECK-NEXT: [[_MSCMP4:%.*]] = icmp ne i512 [[TMP8]], 0 -; CHECK-NEXT: [[TMP9:%.*]] = bitcast <16 x i32> [[TMP2]] to i512 -; CHECK-NEXT: [[_MSCMP5:%.*]] = icmp ne i512 [[TMP9]], 0 -; CHECK-NEXT: [[_MSOR6:%.*]] = or i1 [[_MSCMP4]], [[_MSCMP5]] -; CHECK-NEXT: br i1 [[_MSOR6]], label [[TMP10:%.*]], label [[TMP11:%.*]], !prof [[PROF1]] -; CHECK: 10: -; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR10]] -; CHECK-NEXT: unreachable -; CHECK: 11: +; CHECK-NEXT: [[TMP3:%.*]] = bitcast i16 [[X2:%.*]] to <16 x i1> +; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP5:%.*]] = sext <16 x i1> [[TMP4]] to <16 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = select <16 x i1> [[TMP3]], <16 x i32> [[TMP5]], <16 x i32> [[TMP2]] +; CHECK-NEXT: [[RES:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0:%.*]], <16 x i32> [[X1:%.*]], i16 [[X2]], i32 10) +; CHECK-NEXT: [[TMP7:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = sext <16 x i1> [[TMP7]] to <16 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP8]], <16 x i32> [[TMP2]] ; CHECK-NEXT: [[RES1:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0]], <16 x i32> [[X1]], i16 -1, i32 8) +; CHECK-NEXT: [[_MSPROP:%.*]] = or <16 x i32> [[TMP6]], [[TMP9]] ; CHECK-NEXT: [[RES2:%.*]] = add <16 x i32> [[RES]], [[RES1]] -; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8 +; CHECK-NEXT: store <16 x i32> [[_MSPROP]], ptr @__msan_retval_tls, align 8 ; CHECK-NEXT: ret <16 x i32> [[RES2]] ; %res = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> %x0, <16 x i32> %x1, i16 %x2, i32 10)