@@ -3846,18 +3846,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
3846
3846
setOriginForNaryOp (I);
3847
3847
}
3848
3848
3849
- // Instrument multiply-add intrinsics.
3849
+ // Instrument multiply-add(-accumulate)? intrinsics.
3850
3850
//
3851
3851
// e.g., Two operands:
3852
3852
// <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
3853
3853
//
3854
3854
// Two operands which require an EltSizeInBits override:
3855
3855
// <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
3856
3856
//
3857
- // Three operands are not implemented yet :
3857
+ // Three operands:
3858
3858
// <4 x i32> @llvm.x86.avx512.vpdpbusd.128
3859
3859
// (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
3860
- // (the result of multiply-add'ing %a and %b is accumulated with %s)
3860
+ // (this is equivalent to multiply-add on %a and %b, followed by
3861
+ // adding/"accumulating" %s. "Accumulation" stores the result in one
3862
+ // of the source registers, but this accumulate vs. add distinction
3863
+ // is lost when dealing with LLVM intrinsics.)
3861
3864
void handleVectorPmaddIntrinsic (IntrinsicInst &I, unsigned ReductionFactor,
3862
3865
unsigned EltSizeInBits = 0 ) {
3863
3866
IRBuilder<> IRB (&I);
@@ -3866,22 +3869,39 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
3866
3869
cast<FixedVectorType>(I.getType ());
3867
3870
assert (isa<FixedVectorType>(ReturnType));
3868
3871
3869
- assert (I.arg_size () == 2 );
3870
-
3871
3872
// Vectors A and B, and shadows
3872
- Value *Va = I.getOperand (0 );
3873
- Value *Vb = I.getOperand (1 );
3873
+ Value *Va = nullptr ;
3874
+ Value *Vb = nullptr ;
3875
+ Value *Sa = nullptr ;
3876
+ Value *Sb = nullptr ;
3874
3877
3875
- Value *Sa = getShadow (&I, 0 );
3876
- Value *Sb = getShadow (&I, 1 );
3878
+ assert (I.arg_size () == 2 || I.arg_size () == 3 );
3879
+ if (I.arg_size () == 2 ) {
3880
+ Va = I.getOperand (0 );
3881
+ Vb = I.getOperand (1 );
3877
3882
3878
- FixedVectorType *ParamType =
3879
- cast<FixedVectorType>(I.getArgOperand (0 )->getType ());
3880
- assert (ParamType == I.getArgOperand (1 )->getType ());
3883
+ Sa = getShadow (&I, 0 );
3884
+ Sb = getShadow (&I, 1 );
3885
+ } else if (I.arg_size () == 3 ) {
3886
+ // Operand 0 is the accumulator. We will deal with that below.
3887
+ Va = I.getOperand (1 );
3888
+ Vb = I.getOperand (2 );
3889
+
3890
+ Sa = getShadow (&I, 1 );
3891
+ Sb = getShadow (&I, 2 );
3892
+ }
3893
+
3894
+ FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType ());
3895
+ assert (ParamType == Vb->getType ());
3881
3896
3882
3897
assert (ParamType->getPrimitiveSizeInBits () ==
3883
3898
ReturnType->getPrimitiveSizeInBits ());
3884
3899
3900
+ if (I.arg_size () == 3 ) {
3901
+ assert (ParamType == ReturnType);
3902
+ assert (ParamType == I.getArgOperand (0 )->getType ());
3903
+ }
3904
+
3885
3905
FixedVectorType *ImplicitReturnType = ReturnType;
3886
3906
// Step 1: instrument multiplication of corresponding vector elements
3887
3907
if (EltSizeInBits) {
@@ -3944,10 +3964,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
3944
3964
Constant::getNullValue (Horizontal->getType ())),
3945
3965
ImplicitReturnType);
3946
3966
3947
- // For MMX, cast it back to the required fake return type (<1 x i64>).
3967
+ // Cast it back to the required fake return type (<1 x i64>).
3948
3968
if (EltSizeInBits)
3949
3969
OutShadow = CreateShadowCast (IRB, OutShadow, getShadowTy (&I));
3950
3970
3971
+ // Step 3 (if applicable): instrument accumulator
3972
+ if (I.arg_size () == 3 )
3973
+ OutShadow = IRB.CreateOr (OutShadow, getShadow (&I, 0 ));
3974
+
3951
3975
setShadow (&I, OutShadow);
3952
3976
setOriginForNaryOp (I);
3953
3977
}
@@ -5525,6 +5549,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
5525
5549
handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 , /* EltSize=*/ 16 );
5526
5550
break ;
5527
5551
5552
+ // AVX Vector Neural Network Instructions: bytes
5553
+ //
5554
+ // Multiply and Add Packed Signed and Unsigned Bytes
5555
+ // < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
5556
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5557
+ // < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
5558
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5559
+ // <16 x i32> @llvm.x86.avx512.vpdpbusd.512
5560
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5561
+ //
5562
+ // Multiply and Add Unsigned and Signed Bytes With Saturation
5563
+ // < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
5564
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5565
+ // < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
5566
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5567
+ // <16 x i32> @llvm.x86.avx512.vpdpbusds.512
5568
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5569
+ //
5570
+ // < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
5571
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5572
+ // < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
5573
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5574
+ //
5575
+ // < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
5576
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5577
+ // < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
5578
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5579
+ //
5580
+ // <16 x i32> @llvm.x86.avx10.vpdpbssd.512
5581
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5582
+ // <16 x i32> @llvm.x86.avx10.vpdpbssds.512
5583
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5584
+ //
5585
+ // These intrinsics are auto-upgraded into non-masked forms:
5586
+ // <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
5587
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5588
+ // <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
5589
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5590
+ // <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
5591
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5592
+ // <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
5593
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5594
+ // <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
5595
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5596
+ // <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
5597
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5598
+ //
5599
+ // <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
5600
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5601
+ // <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
5602
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5603
+ // <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
5604
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5605
+ // <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
5606
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5607
+ // <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
5608
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5609
+ // <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
5610
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5611
+ case Intrinsic::x86_avx512_vpdpbusd_128:
5612
+ case Intrinsic::x86_avx512_vpdpbusd_256:
5613
+ case Intrinsic::x86_avx512_vpdpbusd_512:
5614
+ case Intrinsic::x86_avx512_vpdpbusds_128:
5615
+ case Intrinsic::x86_avx512_vpdpbusds_256:
5616
+ case Intrinsic::x86_avx512_vpdpbusds_512:
5617
+ case Intrinsic::x86_avx2_vpdpbssd_128:
5618
+ case Intrinsic::x86_avx2_vpdpbssd_256:
5619
+ case Intrinsic::x86_avx2_vpdpbssds_128:
5620
+ case Intrinsic::x86_avx2_vpdpbssds_256:
5621
+ case Intrinsic::x86_avx10_vpdpbssd_512:
5622
+ case Intrinsic::x86_avx10_vpdpbssds_512:
5623
+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 4 , /* EltSize=*/ 8 );
5624
+ break ;
5625
+
5626
+ // AVX Vector Neural Network Instructions: words
5627
+ //
5628
+ // Multiply and Add Signed Word Integers
5629
+ // < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
5630
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5631
+ // < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
5632
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5633
+ // <16 x i32> @llvm.x86.avx512.vpdpwssd.512
5634
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5635
+ //
5636
+ // Multiply and Add Signed Word Integers With Saturation
5637
+ // < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
5638
+ // (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5639
+ // < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
5640
+ // (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5641
+ // <16 x i32> @llvm.x86.avx512.vpdpwssds.512
5642
+ // (<16 x i32>, <16 x i32>, <16 x i32>)
5643
+ //
5644
+ // These intrinsics are auto-upgraded into non-masked forms:
5645
+ // <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
5646
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5647
+ // <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
5648
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5649
+ // <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
5650
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5651
+ // <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
5652
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5653
+ // <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
5654
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5655
+ // <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
5656
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5657
+ //
5658
+ // <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
5659
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5660
+ // <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
5661
+ // (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5662
+ // <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
5663
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5664
+ // <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
5665
+ // (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5666
+ // <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
5667
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5668
+ // <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
5669
+ // (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5670
+ case Intrinsic::x86_avx512_vpdpwssd_128:
5671
+ case Intrinsic::x86_avx512_vpdpwssd_256:
5672
+ case Intrinsic::x86_avx512_vpdpwssd_512:
5673
+ case Intrinsic::x86_avx512_vpdpwssds_128:
5674
+ case Intrinsic::x86_avx512_vpdpwssds_256:
5675
+ case Intrinsic::x86_avx512_vpdpwssds_512:
5676
+ handleVectorPmaddIntrinsic (I, /* ReductionFactor=*/ 2 , /* EltSize=*/ 16 );
5677
+ break ;
5678
+
5679
+ // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
5680
+ // Precision
5681
+ // <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
5682
+ // (<4 x float>, <8 x bfloat>, <8 x bfloat>)
5683
+ // <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
5684
+ // (<8 x float>, <16 x bfloat>, <16 x bfloat>)
5685
+ // <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
5686
+ // (<16 x float>, <32 x bfloat>, <32 x bfloat>)
5687
+ // handleVectorPmaddIntrinsic() currently only handles integer types.
5688
+
5528
5689
case Intrinsic::x86_sse_cmp_ss:
5529
5690
case Intrinsic::x86_sse2_cmp_sd:
5530
5691
case Intrinsic::x86_sse_comieq_ss:
0 commit comments