Skip to content

Commit 4220538

Browse files
authored
[msan] Handle multiply-add-accumulate; apply to AVX Vector Neural Network Instructions (VNNI) (#153927)
This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions. It also updates the tests from #153135.
1 parent 4629291 commit 4220538

File tree

9 files changed

+2069
-358
lines changed

9 files changed

+2069
-358
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 174 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3846,18 +3846,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
38463846
setOriginForNaryOp(I);
38473847
}
38483848

3849-
// Instrument multiply-add intrinsics.
3849+
// Instrument multiply-add(-accumulate)? intrinsics.
38503850
//
38513851
// e.g., Two operands:
38523852
// <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
38533853
//
38543854
// Two operands which require an EltSizeInBits override:
38553855
// <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
38563856
//
3857-
// Three operands are not implemented yet:
3857+
// Three operands:
38583858
// <4 x i32> @llvm.x86.avx512.vpdpbusd.128
38593859
// (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
3860-
// (the result of multiply-add'ing %a and %b is accumulated with %s)
3860+
// (this is equivalent to multiply-add on %a and %b, followed by
3861+
// adding/"accumulating" %s. "Accumulation" stores the result in one
3862+
// of the source registers, but this accumulate vs. add distinction
3863+
// is lost when dealing with LLVM intrinsics.)
38613864
void handleVectorPmaddIntrinsic(IntrinsicInst &I, unsigned ReductionFactor,
38623865
unsigned EltSizeInBits = 0) {
38633866
IRBuilder<> IRB(&I);
@@ -3866,22 +3869,39 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
38663869
cast<FixedVectorType>(I.getType());
38673870
assert(isa<FixedVectorType>(ReturnType));
38683871

3869-
assert(I.arg_size() == 2);
3870-
38713872
// Vectors A and B, and shadows
3872-
Value *Va = I.getOperand(0);
3873-
Value *Vb = I.getOperand(1);
3873+
Value *Va = nullptr;
3874+
Value *Vb = nullptr;
3875+
Value *Sa = nullptr;
3876+
Value *Sb = nullptr;
38743877

3875-
Value *Sa = getShadow(&I, 0);
3876-
Value *Sb = getShadow(&I, 1);
3878+
assert(I.arg_size() == 2 || I.arg_size() == 3);
3879+
if (I.arg_size() == 2) {
3880+
Va = I.getOperand(0);
3881+
Vb = I.getOperand(1);
38773882

3878-
FixedVectorType *ParamType =
3879-
cast<FixedVectorType>(I.getArgOperand(0)->getType());
3880-
assert(ParamType == I.getArgOperand(1)->getType());
3883+
Sa = getShadow(&I, 0);
3884+
Sb = getShadow(&I, 1);
3885+
} else if (I.arg_size() == 3) {
3886+
// Operand 0 is the accumulator. We will deal with that below.
3887+
Va = I.getOperand(1);
3888+
Vb = I.getOperand(2);
3889+
3890+
Sa = getShadow(&I, 1);
3891+
Sb = getShadow(&I, 2);
3892+
}
3893+
3894+
FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
3895+
assert(ParamType == Vb->getType());
38813896

38823897
assert(ParamType->getPrimitiveSizeInBits() ==
38833898
ReturnType->getPrimitiveSizeInBits());
38843899

3900+
if (I.arg_size() == 3) {
3901+
assert(ParamType == ReturnType);
3902+
assert(ParamType == I.getArgOperand(0)->getType());
3903+
}
3904+
38853905
FixedVectorType *ImplicitReturnType = ReturnType;
38863906
// Step 1: instrument multiplication of corresponding vector elements
38873907
if (EltSizeInBits) {
@@ -3944,10 +3964,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
39443964
Constant::getNullValue(Horizontal->getType())),
39453965
ImplicitReturnType);
39463966

3947-
// For MMX, cast it back to the required fake return type (<1 x i64>).
3967+
// Cast it back to the required fake return type (<1 x i64>).
39483968
if (EltSizeInBits)
39493969
OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
39503970

3971+
// Step 3 (if applicable): instrument accumulator
3972+
if (I.arg_size() == 3)
3973+
OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
3974+
39513975
setShadow(&I, OutShadow);
39523976
setOriginForNaryOp(I);
39533977
}
@@ -5525,6 +5549,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
55255549
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
55265550
break;
55275551

5552+
// AVX Vector Neural Network Instructions: bytes
5553+
//
5554+
// Multiply and Add Packed Signed and Unsigned Bytes
5555+
// < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
5556+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5557+
// < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
5558+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5559+
// <16 x i32> @llvm.x86.avx512.vpdpbusd.512
5560+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5561+
//
5562+
// Multiply and Add Unsigned and Signed Bytes With Saturation
5563+
// < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
5564+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5565+
// < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
5566+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5567+
// <16 x i32> @llvm.x86.avx512.vpdpbusds.512
5568+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5569+
//
5570+
// < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
5571+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5572+
// < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
5573+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5574+
//
5575+
// < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
5576+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5577+
// < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
5578+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5579+
//
5580+
// <16 x i32> @llvm.x86.avx10.vpdpbssd.512
5581+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5582+
// <16 x i32> @llvm.x86.avx10.vpdpbssds.512
5583+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5584+
//
5585+
// These intrinsics are auto-upgraded into non-masked forms:
5586+
// <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
5587+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5588+
// <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
5589+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5590+
// <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
5591+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5592+
// <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
5593+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5594+
// <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
5595+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5596+
// <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
5597+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5598+
//
5599+
// <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
5600+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5601+
// <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
5602+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5603+
// <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
5604+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5605+
// <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
5606+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5607+
// <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
5608+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5609+
// <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
5610+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5611+
case Intrinsic::x86_avx512_vpdpbusd_128:
5612+
case Intrinsic::x86_avx512_vpdpbusd_256:
5613+
case Intrinsic::x86_avx512_vpdpbusd_512:
5614+
case Intrinsic::x86_avx512_vpdpbusds_128:
5615+
case Intrinsic::x86_avx512_vpdpbusds_256:
5616+
case Intrinsic::x86_avx512_vpdpbusds_512:
5617+
case Intrinsic::x86_avx2_vpdpbssd_128:
5618+
case Intrinsic::x86_avx2_vpdpbssd_256:
5619+
case Intrinsic::x86_avx2_vpdpbssds_128:
5620+
case Intrinsic::x86_avx2_vpdpbssds_256:
5621+
case Intrinsic::x86_avx10_vpdpbssd_512:
5622+
case Intrinsic::x86_avx10_vpdpbssds_512:
5623+
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
5624+
break;
5625+
5626+
// AVX Vector Neural Network Instructions: words
5627+
//
5628+
// Multiply and Add Signed Word Integers
5629+
// < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
5630+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5631+
// < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
5632+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5633+
// <16 x i32> @llvm.x86.avx512.vpdpwssd.512
5634+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5635+
//
5636+
// Multiply and Add Signed Word Integers With Saturation
5637+
// < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
5638+
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
5639+
// < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
5640+
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
5641+
// <16 x i32> @llvm.x86.avx512.vpdpwssds.512
5642+
// (<16 x i32>, <16 x i32>, <16 x i32>)
5643+
//
5644+
// These intrinsics are auto-upgraded into non-masked forms:
5645+
// <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
5646+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5647+
// <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
5648+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5649+
// <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
5650+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5651+
// <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
5652+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5653+
// <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
5654+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5655+
// <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
5656+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5657+
//
5658+
// <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
5659+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5660+
// <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
5661+
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
5662+
// <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
5663+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5664+
// <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
5665+
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
5666+
// <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
5667+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5668+
// <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
5669+
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
5670+
case Intrinsic::x86_avx512_vpdpwssd_128:
5671+
case Intrinsic::x86_avx512_vpdpwssd_256:
5672+
case Intrinsic::x86_avx512_vpdpwssd_512:
5673+
case Intrinsic::x86_avx512_vpdpwssds_128:
5674+
case Intrinsic::x86_avx512_vpdpwssds_256:
5675+
case Intrinsic::x86_avx512_vpdpwssds_512:
5676+
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
5677+
break;
5678+
5679+
// TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
5680+
// Precision
5681+
// <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
5682+
// (<4 x float>, <8 x bfloat>, <8 x bfloat>)
5683+
// <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
5684+
// (<8 x float>, <16 x bfloat>, <16 x bfloat>)
5685+
// <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
5686+
// (<16 x float>, <32 x bfloat>, <32 x bfloat>)
5687+
// handleVectorPmaddIntrinsic() currently only handles integer types.
5688+
55285689
case Intrinsic::x86_sse_cmp_ss:
55295690
case Intrinsic::x86_sse2_cmp_sd:
55305691
case Intrinsic::x86_sse_comieq_ss:

0 commit comments

Comments
 (0)