@@ -3846,15 +3846,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsics.
+  // Instrument multiply-add(-accumulate)? intrinsics.
   //
   // e.g., Two operands:
   //         <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
   //
   //       Two operands which require an EltSizeInBits override:
   //         <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
   //
-  //       Three operands are not implemented yet:
+  //       Three operands:
   //         <4 x i32> @llvm.x86.avx512.vpdpbusd.128
   //                       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
   //         (the result of multiply-add'ing %a and %b is accumulated with %s)
@@ -3866,22 +3866,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
         cast<FixedVectorType>(I.getType());
     assert(isa<FixedVectorType>(ReturnType));
 
-    assert(I.arg_size() == 2);
-
     // Vectors A and B, and shadows
-    Value *Va = I.getOperand(0);
-    Value *Vb = I.getOperand(1);
+    Value *Va = nullptr;
+    Value *Vb = nullptr;
+    Value *Sa = nullptr;
+    Value *Sb = nullptr;
 
-    Value *Sa = getShadow(&I, 0);
-    Value *Sb = getShadow(&I, 1);
+    if (I.arg_size() == 2) {
+      Va = I.getOperand(0);
+      Vb = I.getOperand(1);
+
+      Sa = getShadow(&I, 0);
+      Sb = getShadow(&I, 1);
+    } else if (I.arg_size() == 3) {
+      // Operand 0 is the accumulator. We will deal with that below.
+      Va = I.getOperand(1);
+      Vb = I.getOperand(2);
+
+      Sa = getShadow(&I, 1);
+      Sb = getShadow(&I, 2);
+    } else {
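+      // Unreachable for the intrinsics dispatched here; the assert below
+      // documents and enforces the supported arities in debug builds.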
+      assert(I.arg_size() == 2 || I.arg_size() == 3);
+    }
 
-    FixedVectorType *ParamType =
-        cast<FixedVectorType>(I.getArgOperand(0)->getType());
-    assert(ParamType == I.getArgOperand(1)->getType());
+    FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
+    assert(ParamType == Vb->getType());
 
     assert(ParamType->getPrimitiveSizeInBits() ==
            ReturnType->getPrimitiveSizeInBits());
 
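+    // For the three-operand (accumulate) form, the accumulator, the packed
+    // multiplicands, and the result must all share the same vector type.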
+    if (I.arg_size() == 3) {
+      assert(ParamType == ReturnType);
+      assert(ParamType == I.getArgOperand(0)->getType());
+    }
+
     FixedVectorType *ImplicitReturnType = ReturnType;
     // Step 1: instrument multiplication of corresponding vector elements
     if (EltSizeInBits) {
@@ -3944,10 +3962,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                        Constant::getNullValue(Horizontal->getType())),
         ImplicitReturnType);
 
-    // For MMX, cast it back to the required fake return type (<1 x i64>).
+    // Cast it back to the required fake return type (<1 x i64>).
     if (EltSizeInBits)
      OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
 
+    // Step 3 (if applicable): instrument accumulator
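+    // The output is the accumulator plus the per-lane dot product; since
+    // adding a poisoned value poisons the sum, OR-ing the accumulator's
+    // shadow into the result is MSan's usual approximation for addition.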
+    if (I.arg_size() == 3)
+      OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
+
     setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
@@ -5507,6 +5529,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
+    // AVX Vector Neural Network Instructions: bytes
+    //
+    // Multiply and Add Unsigned and Signed Bytes
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Unsigned and Signed Bytes With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
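+    //
+    // Each i32 lane accumulates four byte-wide products, hence
+    // ReductionFactor=4 with 8-bit elements.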
+    case Intrinsic::x86_avx512_vpdpbusd_128:
+    case Intrinsic::x86_avx512_vpdpbusd_256:
+    case Intrinsic::x86_avx512_vpdpbusd_512:
+    case Intrinsic::x86_avx512_vpdpbusds_128:
+    case Intrinsic::x86_avx512_vpdpbusds_256:
+    case Intrinsic::x86_avx512_vpdpbusds_512:
+    case Intrinsic::x86_avx2_vpdpbssd_128:
+    case Intrinsic::x86_avx2_vpdpbssd_256:
+    case Intrinsic::x86_avx2_vpdpbssds_128:
+    case Intrinsic::x86_avx2_vpdpbssds_256:
+    case Intrinsic::x86_avx10_vpdpbssd_512:
+    case Intrinsic::x86_avx10_vpdpbssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
+      break;
+
+    // AVX Vector Neural Network Instructions: words
+    //
+    // Multiply and Add Signed Word Integers
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Signed Word Integers With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
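+    //
+    // Each i32 lane accumulates two word-wide products, hence
+    // ReductionFactor=2 with 16-bit elements.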
+    case Intrinsic::x86_avx512_vpdpwssd_128:
+    case Intrinsic::x86_avx512_vpdpwssd_256:
+    case Intrinsic::x86_avx512_vpdpwssd_512:
+    case Intrinsic::x86_avx512_vpdpwssds_128:
+    case Intrinsic::x86_avx512_vpdpwssds_256:
+    case Intrinsic::x86_avx512_vpdpwssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
+      break;
+
+    // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
+    //       Precision
+    //   <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
+    //                   (<4 x float>, <8 x bfloat>, <8 x bfloat>)
+    //   <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
+    //                   (<8 x float>, <16 x bfloat>, <16 x bfloat>)
+    //   <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
+    //                    (<16 x float>, <32 x bfloat>, <32 x bfloat>)
+    // handleVectorPmaddIntrinsic() currently only handles integer types.
+
     case Intrinsic::x86_sse_cmp_ss:
     case Intrinsic::x86_sse2_cmp_sd:
     case Intrinsic::x86_sse_comieq_ss: