@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
+  // Handle llvm.x86.avx512.* instructions that take a vector of floating-point
+  // values and perform an operation whose shadow propagation should be handled
+  // as all-or-nothing [*], with masking provided by a vector and a mask
+  // supplied as an integer.
+  //
+  // [*] if all bits of a vector element are initialized, the output is fully
+  //     initialized; otherwise, the output is fully uninitialized
+  //
+  // e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+  //           (<16 x float>, <16 x float>, i16)
+  //            A            WriteThru     Mask
+  //
+  //       <2 x double> @llvm.x86.avx512.rcp14.pd.128
+  //           (<2 x double>, <2 x double>, i8)
+  //
+  // Dst[i]        = Mask[i] ? some_op(A[i])               : WriteThru[i]
+  // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
+  void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
+    IRBuilder<> IRB(&I);
+
+    assert(I.arg_size() == 3);
+    Value *A = I.getOperand(0);
+    Value *WriteThrough = I.getOperand(1);
+    Value *Mask = I.getOperand(2);
+
+    assert(isFixedFPVector(A));
+    assert(isFixedFPVector(WriteThrough));
+
+    [[maybe_unused]] unsigned ANumElements =
+        cast<FixedVectorType>(A->getType())->getNumElements();
+    unsigned OutputNumElements =
+        cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+    assert(ANumElements == OutputNumElements);
+
+    assert(Mask->getType()->isIntegerTy());
+    // Some bits of the mask might be unused, but check them all anyway
+    // (typically the mask is an integer constant).
+    insertCheckShadowOf(Mask, &I);
+
+    // The mask has 1 bit per element of A, but a minimum of 8 bits.
+    if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
+      Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
+    assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+
+    assert(I.getType() == WriteThrough->getType());
+
+    Mask = IRB.CreateBitCast(
+        Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
+
+    Value *AShadow = getShadow(A);
+
+    // All-or-nothing shadow
+    AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
+                             AShadow->getType());
+
+    Value *WriteThroughShadow = getShadow(WriteThrough);
+
+    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+    setShadow(&I, Shadow);
+
+    setOriginForNaryOp(I);
+  }
+
   // For sh.* compiler intrinsics:
   //   llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
   //     (<8 x half>, <8 x half>, <8 x half>, i8, i32)
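
For illustration only (not part of the patch): a minimal C++ sketch of the shadow semantics that `handleAVX512VectorGenericMaskedFP` models, using a hypothetical `maskedAllOrNothingShadow` helper over a 4-element vector of 32-bit shadow words. Each output lane's shadow becomes all-ones if the mask selects it and the corresponding input lane has any uninitialized bit; otherwise it is copied from the write-through operand's shadow.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Reference model of Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i])
//                                            : WriteThru_shadow[i]
// for a 4 x 32-bit vector (element count and width are illustrative only).
std::array<uint32_t, 4>
maskedAllOrNothingShadow(const std::array<uint32_t, 4> &AShadow,
                         const std::array<uint32_t, 4> &WriteThruShadow,
                         uint8_t Mask) {
  std::array<uint32_t, 4> DstShadow{};
  for (size_t I = 0; I < 4; ++I) {
    // All-or-nothing: any poisoned bit in A[i] poisons the whole output lane.
    uint32_t AllOrNothing = (AShadow[I] != 0) ? 0xFFFFFFFFu : 0u;
    // The mask bit selects between the computed lane and the write-through lane.
    DstShadow[I] = ((Mask >> I) & 1) ? AllOrNothing : WriteThruShadow[I];
  }
  return DstShadow;
}

int main() {
  // Lane 1 of A has one uninitialized bit; lane 3 is masked off and takes the
  // (fully uninitialized) write-through lane instead.
  auto Dst = maskedAllOrNothingShadow({0, 0x10u, 0, 0}, {0, 0, 0, 0xFFFFFFFFu},
                                      /*Mask=*/0b0111);
  for (uint32_t S : Dst)
    std::printf("%08x\n", S); // 00000000 ffffffff 00000000 ffffffff
  return 0;
}
```
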
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }
 
+    // AVX512/AVX10 Reciprocal Square Root
+    // <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+    //     (<16 x float>, <16 x float>, i16)
+    // <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
+    //     (<8 x float>, <8 x float>, i8)
+    // <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
+    //     (<4 x float>, <4 x float>, i8)
+    //
+    // <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
+    //     (<8 x double>, <8 x double>, i8)
+    // <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
+    //     (<4 x double>, <4 x double>, i8)
+    // <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
+    //     (<2 x double>, <2 x double>, i8)
+    //
+    // <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
+    //     (<32 x bfloat>, <32 x bfloat>, i32)
+    // <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
+    //     (<16 x bfloat>, <16 x bfloat>, i16)
+    // <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
+    //     (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    // <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
+    //     (<32 x half>, <32 x half>, i32)
+    // <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
+    //     (<16 x half>, <16 x half>, i16)
+    // <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
+    //     (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.sd
+    //       (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ss
+    //       (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
+    //       (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rsqrt14_ps_512:
+    case Intrinsic::x86_avx512_rsqrt14_ps_256:
+    case Intrinsic::x86_avx512_rsqrt14_ps_128:
+    case Intrinsic::x86_avx512_rsqrt14_pd_512:
+    case Intrinsic::x86_avx512_rsqrt14_pd_256:
+    case Intrinsic::x86_avx512_rsqrt14_pd_128:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
+    // AVX512/AVX10 Reciprocal
+    // <16 x float> @llvm.x86.avx512.rcp14.ps.512
+    //     (<16 x float>, <16 x float>, i16)
+    // <8 x float> @llvm.x86.avx512.rcp14.ps.256
+    //     (<8 x float>, <8 x float>, i8)
+    // <4 x float> @llvm.x86.avx512.rcp14.ps.128
+    //     (<4 x float>, <4 x float>, i8)
+    //
+    // <8 x double> @llvm.x86.avx512.rcp14.pd.512
+    //     (<8 x double>, <8 x double>, i8)
+    // <4 x double> @llvm.x86.avx512.rcp14.pd.256
+    //     (<4 x double>, <4 x double>, i8)
+    // <2 x double> @llvm.x86.avx512.rcp14.pd.128
+    //     (<2 x double>, <2 x double>, i8)
+    //
+    // <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
+    //     (<32 x bfloat>, <32 x bfloat>, i32)
+    // <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
+    //     (<16 x bfloat>, <16 x bfloat>, i16)
+    // <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
+    //     (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    // <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
+    //     (<32 x half>, <32 x half>, i32)
+    // <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
+    //     (<16 x half>, <16 x half>, i16)
+    // <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
+    //     (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rcp14.sd
+    //       (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ss
+    //       (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
+    //       (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rcp14_ps_512:
+    case Intrinsic::x86_avx512_rcp14_ps_256:
+    case Intrinsic::x86_avx512_rcp14_ps_128:
+    case Intrinsic::x86_avx512_rcp14_pd_512:
+    case Intrinsic::x86_avx512_rcp14_pd_256:
+    case Intrinsic::x86_avx512_rcp14_pd_128:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_512:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_256:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
     // AVX512 FP16 Arithmetic
     case Intrinsic::x86_avx512fp16_mask_add_sh_round:
     case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
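
As a usage sketch (illustrative, not part of the patch): the masked reciprocal and reciprocal-square-root operations listed above are typically reached from C/C++ through the <immintrin.h> wrappers, which Clang lowers to the corresponding llvm.x86.avx512.* intrinsics, so under MemorySanitizer they now go through handleAVX512VectorGenericMaskedFP. The function names below are placeholders, and an AVX-512F target is assumed.

```cpp
// Illustrative only. Build with, e.g.:
//   clang++ -O2 -mavx512f -fsanitize=memory -c example.cpp
#include <immintrin.h>

// Dst[i] = K[i] ? rsqrt14(A[i]) : Src[i]
__m512 masked_rsqrt14(__m512 Src, __mmask16 K, __m512 A) {
  return _mm512_mask_rsqrt14_ps(Src, K, A);
}

// Dst[i] = K[i] ? rcp14(A[i]) : Src[i]
__m512 masked_rcp14(__m512 Src, __mmask16 K, __m512 A) {
  return _mm512_mask_rcp14_ps(Src, K, A);
}
```
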