@@ -3903,7 +3903,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
39033903 // adding/"accumulating" %s. "Accumulation" stores the result in one
39043904 // of the source registers, but this accumulate vs. add distinction
39053905 // is lost when dealing with LLVM intrinsics.)
3906+ //
3907+ // ZeroPurifies means that multiplying an initialized zero by an
3908+ // uninitialized value yields an initialized (zero) value. This holds for
3909+ // integer multiplication, but not floating-point (e.g., 0 * NaN is NaN).
39063910 void handleVectorPmaddIntrinsic (IntrinsicInst &I, unsigned ReductionFactor,
3911+ bool ZeroPurifies,
39073912 unsigned EltSizeInBits = 0 ) {
39083913 IRBuilder<> IRB (&I);
39093914
@@ -3945,7 +3950,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
39453950 assert (AccumulatorType == ReturnType);
39463951 }
39473952
3948- FixedVectorType *ImplicitReturnType = ReturnType;
3953+ FixedVectorType *ImplicitReturnType =
3954+ cast<FixedVectorType>(getShadowTy (ReturnType));
39493955 // Step 1: instrument multiplication of corresponding vector elements
39503956 if (EltSizeInBits) {
39513957 ImplicitReturnType = cast<FixedVectorType>(
@@ -3964,30 +3970,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
39643970 ReturnType->getNumElements () * ReductionFactor);
39653971 }
39663972
3967- // Multiplying an *initialized* zero by an uninitialized element results in
3968- // an initialized zero element.
3969- //
3970- // This is analogous to bitwise AND, where "AND" of 0 and a poisoned value
3971- // results in an unpoisoned value. We can therefore adapt the visitAnd()
3972- // instrumentation:
3973- // OutShadow = (SaNonZero & SbNonZero)
3974- // | (VaNonZero & SbNonZero)
3975- // | (SaNonZero & VbNonZero)
3976- // where non-zero is checked on a per-element basis (not per bit).
3977- Value *SZero = Constant::getNullValue (Va->getType ());
3978- Value *VZero = Constant::getNullValue (Sa->getType ());
3979- Value *SaNonZero = IRB.CreateICmpNE (Sa, SZero);
3980- Value *SbNonZero = IRB.CreateICmpNE (Sb, SZero);
3981- Value *VaNonZero = IRB.CreateICmpNE (Va, VZero);
3982- Value *VbNonZero = IRB.CreateICmpNE (Vb, VZero);
3983-
3984- Value *SaAndSbNonZero = IRB.CreateAnd (SaNonZero, SbNonZero);
3985- Value *VaAndSbNonZero = IRB.CreateAnd (VaNonZero, SbNonZero);
3986- Value *SaAndVbNonZero = IRB.CreateAnd (SaNonZero, VbNonZero);
3987-
39883973 // Each element of the vector is represented by a single bit (poisoned or
39893974 // not) e.g., <8 x i1>.
3990- Value *And = IRB.CreateOr ({SaAndSbNonZero, VaAndSbNonZero, SaAndVbNonZero});
3975+ Value *SaNonZero = IRB.CreateIsNotNull (Sa);
3976+ Value *SbNonZero = IRB.CreateIsNotNull (Sb);
3977+ Value *And;
3978+ if (ZeroPurifies) {
3979+ // Multiplying an *initialized* zero by an uninitialized element results
3980+ // in an initialized zero element.
3981+ //
3982+ // This is analogous to bitwise AND, where "AND" of 0 and a poisoned value
3983+ // results in an unpoisoned value. We can therefore adapt the visitAnd()
3984+ // instrumentation:
3985+ // OutShadow = (SaNonZero & SbNonZero)
3986+ // | (VaNonZero & SbNonZero)
3987+ // | (SaNonZero & VbNonZero)
3988+ // where non-zero is checked on a per-element basis (not per bit).
3989+ Value *VaInt = Va;
3990+ Value *VbInt = Vb;
3991+ if (!Va->getType ()->isIntegerTy ()) {
3992+ VaInt = CreateAppToShadowCast (IRB, Va);
3993+ VbInt = CreateAppToShadowCast (IRB, Vb);
3994+ }
3995+
3996+ Value *VaNonZero = IRB.CreateIsNotNull (VaInt);
3997+ Value *VbNonZero = IRB.CreateIsNotNull (VbInt);
3998+
3999+ Value *SaAndSbNonZero = IRB.CreateAnd (SaNonZero, SbNonZero);
4000+ Value *VaAndSbNonZero = IRB.CreateAnd (VaNonZero, SbNonZero);
4001+ Value *SaAndVbNonZero = IRB.CreateAnd (SaNonZero, VbNonZero);
4002+
4003+ And = IRB.CreateOr ({SaAndSbNonZero, VaAndSbNonZero, SaAndVbNonZero});
4004+ } else {
4005+ And = IRB.CreateOr ({SaNonZero, SbNonZero});
4006+ }
39914007
39924008 // Extend <8 x i1> to <8 x i16>.
39934009 // (The real pmadd intrinsic would have computed intermediate values of
@@ -5752,17 +5768,20 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
57525768 case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
57535769 case Intrinsic::x86_avx2_pmadd_ub_sw:
57545770 case Intrinsic::x86_avx512_pmaddubs_w_512:
5755- handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 );
5771+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 ,
5772+ /* ZeroPurifies=*/ true );
57565773 break ;
57575774
57585775 // <1 x i64> @llvm.x86.ssse3.pmadd.ub.sw(<1 x i64>, <1 x i64>)
57595776 case Intrinsic::x86_ssse3_pmadd_ub_sw:
5760- handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 , /* EltSize=*/ 8 );
5777+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 ,
5778+ /* ZeroPurifies=*/ true , /* EltSizeInBits=*/ 8 );
57615779 break ;
57625780
57635781 // <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64>, <1 x i64>)
57645782 case Intrinsic::x86_mmx_pmadd_wd:
5765- handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 , /* EltSize=*/ 16 );
5783+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 ,
5784+ /* ZeroPurifies=*/ true , /* EltSizeInBits=*/ 16 );
57665785 break ;
57675786
57685787 // AVX Vector Neural Network Instructions: bytes
@@ -5848,7 +5867,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
58485867 case Intrinsic::x86_avx2_vpdpbuuds_128:
58495868 case Intrinsic::x86_avx2_vpdpbuuds_256:
58505869 case Intrinsic::x86_avx10_vpdpbuuds_512:
5851- handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 4 , /* EltSize=*/ 8 );
5870+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 4 ,
5871+ /* ZeroPurifies=*/ true , /* EltSizeInBits=*/ 8 );
58525872 break ;
58535873
58545874 // AVX Vector Neural Network Instructions: words
@@ -5901,7 +5921,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
59015921 case Intrinsic::x86_avx512_vpdpwssds_128:
59025922 case Intrinsic::x86_avx512_vpdpwssds_256:
59035923 case Intrinsic::x86_avx512_vpdpwssds_512:
5904- handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 , /* EltSize=*/ 16 );
5924+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 ,
5925+ /* ZeroPurifies=*/ true , /* EltSizeInBits=*/ 16 );
59055926 break ;
59065927
59075928 // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
0 commit comments