@@ -4911,6 +4911,69 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }

+  // Handle llvm.x86.avx512.* instructions that take a vector of floating-point
+  // values and perform an operation whose shadow propagation should be handled
+  // as all-or-nothing [*], with masking provided by a vector and a mask
+  // supplied as an integer.
+  //
+  // [*] if all bits of a vector element are initialized, the output is fully
+  //     initialized; otherwise, the output is fully uninitialized
+  //
+  // e.g., <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+  //                        (<16 x float>, <16 x float>, i16)
+  //                         A             WriteThru     Mask
+  //
+  //       <2 x double> @llvm.x86.avx512.rcp14.pd.128
+  //                        (<2 x double>, <2 x double>, i8)
+  //
+  // Dst[i]        = Mask[i] ? some_op(A[i])               : WriteThru[i]
+  // Dst_shadow[i] = Mask[i] ? all_or_nothing(A_shadow[i]) : WriteThru_shadow[i]
+  void handleAVX512VectorGenericMaskedFP(IntrinsicInst &I) {
+    IRBuilder<> IRB(&I);
+
+    assert(I.arg_size() == 3);
+    Value *A = I.getOperand(0);
+    Value *WriteThrough = I.getOperand(1);
+    Value *Mask = I.getOperand(2);
+
+    assert(isFixedFPVector(A));
+    assert(isFixedFPVector(WriteThrough));
+
+    [[maybe_unused]] unsigned ANumElements =
+        cast<FixedVectorType>(A->getType())->getNumElements();
+    unsigned OutputNumElements =
+        cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+    assert(ANumElements == OutputNumElements);
+
+    assert(Mask->getType()->isIntegerTy());
+    // Some bits of the mask might be unused, but check them all anyway
+    // (typically the mask is an integer constant).
+    insertCheckShadowOf(Mask, &I);
+
+    // The mask has 1 bit per element of A, but a minimum of 8 bits.
+    if (Mask->getType()->getScalarSizeInBits() == 8 && ANumElements < 8)
+      Mask = IRB.CreateTrunc(Mask, Type::getIntNTy(*MS.C, ANumElements));
+    assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+
+    assert(I.getType() == WriteThrough->getType());
+
+    Mask = IRB.CreateBitCast(
+        Mask, FixedVectorType::get(IRB.getInt1Ty(), OutputNumElements));
+
+    Value *AShadow = getShadow(A);
+
+    // All-or-nothing shadow
+    AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow)),
+                             AShadow->getType());
+
+    Value *WriteThroughShadow = getShadow(WriteThrough);
+
+    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+    setShadow(&I, Shadow);
+
+    setOriginForNaryOp(I);
+  }
+

   // For sh.* compiler intrinsics:
   //   llvm.x86.avx512fp16.mask.{add/sub/mul/div/max/min}.sh.round
   //      (<8 x half>, <8 x half>, <8 x half>, i8, i32)
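
For illustration only: a hand-written sketch (assumed value names, not output produced by the pass) of the shadow code that handleAVX512VectorGenericMaskedFP would emit for a 128-bit call such as

  %r = call <4 x float> @llvm.x86.avx512.rsqrt14.ps.128(<4 x float> %a, <4 x float> %wt, i8 %mask)

Here %s_a, %s_wt and %s_r stand for the shadows of %a, %wt and %r; the pass obtains them via getShadow()/setShadow(), and it also checks the shadow of %mask (insertCheckShadowOf) before this sequence.

  ; The i8 mask covers only 4 lanes, so it is truncated before the bitcast.
  %mask.trunc = trunc i8 %mask to i4
  %mask.vec   = bitcast i4 %mask.trunc to <4 x i1>
  ; All-or-nothing: any poisoned bit of a lane poisons the whole output lane.
  %any.poison = icmp ne <4 x i32> %s_a, zeroinitializer
  %s_a.aon    = sext <4 x i1> %any.poison to <4 x i32>
  ; Masked-off lanes keep the write-through operand's shadow.
  %s_r        = select <4 x i1> %mask.vec, <4 x i32> %s_a.aon, <4 x i32> %s_wt
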
@@ -6091,6 +6154,108 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }

+    // AVX512/AVX10 Reciprocal Square Root
+    //   <16 x float> @llvm.x86.avx512.rsqrt14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rsqrt14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rsqrt14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rsqrt14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rsqrt.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rsqrt14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rsqrt14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rsqrt.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rsqrt14_ps_512:
+    case Intrinsic::x86_avx512_rsqrt14_ps_256:
+    case Intrinsic::x86_avx512_rsqrt14_ps_128:
+    case Intrinsic::x86_avx512_rsqrt14_pd_512:
+    case Intrinsic::x86_avx512_rsqrt14_pd_256:
+    case Intrinsic::x86_avx512_rsqrt14_pd_128:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_512:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_256:
+    case Intrinsic::x86_avx10_mask_rsqrt_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rsqrt_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
+    // AVX512/AVX10 Reciprocal
+    //   <16 x float> @llvm.x86.avx512.rcp14.ps.512
+    //                    (<16 x float>, <16 x float>, i16)
+    //   <8 x float> @llvm.x86.avx512.rcp14.ps.256
+    //                    (<8 x float>, <8 x float>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ps.128
+    //                    (<4 x float>, <4 x float>, i8)
+    //
+    //   <8 x double> @llvm.x86.avx512.rcp14.pd.512
+    //                    (<8 x double>, <8 x double>, i8)
+    //   <4 x double> @llvm.x86.avx512.rcp14.pd.256
+    //                    (<4 x double>, <4 x double>, i8)
+    //   <2 x double> @llvm.x86.avx512.rcp14.pd.128
+    //                    (<2 x double>, <2 x double>, i8)
+    //
+    //   <32 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.512
+    //                    (<32 x bfloat>, <32 x bfloat>, i32)
+    //   <16 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.256
+    //                    (<16 x bfloat>, <16 x bfloat>, i16)
+    //   <8 x bfloat> @llvm.x86.avx10.mask.rcp.bf16.128
+    //                    (<8 x bfloat>, <8 x bfloat>, i8)
+    //
+    //   <32 x half> @llvm.x86.avx512fp16.mask.rcp.ph.512
+    //                    (<32 x half>, <32 x half>, i32)
+    //   <16 x half> @llvm.x86.avx512fp16.mask.rcp.ph.256
+    //                    (<16 x half>, <16 x half>, i16)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.ph.128
+    //                    (<8 x half>, <8 x half>, i8)
+    //
+    // TODO: 3-operand variants are not handled:
+    //   <2 x double> @llvm.x86.avx512.rcp14.sd
+    //                    (<2 x double>, <2 x double>, <2 x double>, i8)
+    //   <4 x float> @llvm.x86.avx512.rcp14.ss
+    //                    (<4 x float>, <4 x float>, <4 x float>, i8)
+    //   <8 x half> @llvm.x86.avx512fp16.mask.rcp.sh
+    //                    (<8 x half>, <8 x half>, <8 x half>, i8)
+    case Intrinsic::x86_avx512_rcp14_ps_512:
+    case Intrinsic::x86_avx512_rcp14_ps_256:
+    case Intrinsic::x86_avx512_rcp14_ps_128:
+    case Intrinsic::x86_avx512_rcp14_pd_512:
+    case Intrinsic::x86_avx512_rcp14_pd_256:
+    case Intrinsic::x86_avx512_rcp14_pd_128:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_512:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_256:
+    case Intrinsic::x86_avx10_mask_rcp_bf16_128:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rcp_ph_128:
+      handleAVX512VectorGenericMaskedFP(I);
+      break;
+
     // AVX512 FP16 Arithmetic
     case Intrinsic::x86_avx512fp16_mask_add_sh_round:
     case Intrinsic::x86_avx512fp16_mask_sub_sh_round: