187 changes: 174 additions & 13 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3846,18 +3846,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

// Instrument multiply-add intrinsics.
// Instrument multiply-add and multiply-add-accumulate intrinsics.
//
// e.g., Two operands:
// <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
//
// Two operands which require an EltSizeInBits override:
// <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
//
// Three operands are not implemented yet:
// Three operands:
// <4 x i32> @llvm.x86.avx512.vpdpbusd.128
// (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
// (the result of multiply-add'ing %a and %b is accumulated with %s)
// (this is equivalent to a multiply-add on %a and %b, followed by
// adding/"accumulating" %s. "Accumulation" stores the result in one
// of the source registers, but the accumulate vs. add distinction
// is lost at the level of LLVM intrinsics.)
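//
// For illustration (a sketch, not exhaustive), the three-operand form of
// vpdpbusd computes, for each i32 lane i:
//   %result[i] = %s[i] + sum over j in [0,4) of
//                zext(%a.byte[4*i+j]) * sext(%b.byte[4*i+j])
// (other intrinsics differ in signedness, element width, and saturation).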
void handleVectorPmaddIntrinsic(IntrinsicInst &I, unsigned ReductionFactor,
unsigned EltSizeInBits = 0) {
IRBuilder<> IRB(&I);
@@ -3866,22 +3869,39 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
cast<FixedVectorType>(I.getType());
assert(isa<FixedVectorType>(ReturnType));

assert(I.arg_size() == 2);

// Vectors A and B, and shadows
Value *Va = I.getOperand(0);
Value *Vb = I.getOperand(1);
Value *Va = nullptr;
Value *Vb = nullptr;
Value *Sa = nullptr;
Value *Sb = nullptr;

Value *Sa = getShadow(&I, 0);
Value *Sb = getShadow(&I, 1);
assert(I.arg_size() == 2 || I.arg_size() == 3);
if (I.arg_size() == 2) {
Va = I.getOperand(0);
Vb = I.getOperand(1);

FixedVectorType *ParamType =
cast<FixedVectorType>(I.getArgOperand(0)->getType());
assert(ParamType == I.getArgOperand(1)->getType());
Sa = getShadow(&I, 0);
Sb = getShadow(&I, 1);
} else if (I.arg_size() == 3) {
// Operand 0 is the accumulator; it is handled in Step 3 below.
Va = I.getOperand(1);
Vb = I.getOperand(2);

Sa = getShadow(&I, 1);
Sb = getShadow(&I, 2);
}

FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
assert(ParamType == Vb->getType());

assert(ParamType->getPrimitiveSizeInBits() ==
ReturnType->getPrimitiveSizeInBits());

if (I.arg_size() == 3) {
assert(ParamType == ReturnType);
assert(ParamType == I.getArgOperand(0)->getType());
}

FixedVectorType *ImplicitReturnType = ReturnType;
// Step 1: instrument multiplication of corresponding vector elements
if (EltSizeInBits) {
@@ -3944,10 +3964,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
Constant::getNullValue(Horizontal->getType())),
ImplicitReturnType);

// For MMX, cast it back to the required fake return type (<1 x i64>).
// Cast it back to the required fake return type (<1 x i64>).
if (EltSizeInBits)
OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));

// Step 3 (if applicable): instrument accumulator
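// The accumulator's shadow is OR'ed in: any uninitialized bit in the
// accumulator propagates to the corresponding bit of the output shadow.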
if (I.arg_size() == 3)
OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));

setShadow(&I, OutShadow);
setOriginForNaryOp(I);
}
@@ -5507,6 +5531,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
break;

// AVX Vector Neural Network Instructions: bytes
//
// Multiply and Add Packed Signed and Unsigned Bytes
// < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
// <16 x i32> @llvm.x86.avx512.vpdpbusd.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
//
// Multiply and Add Unsigned and Signed Bytes With Saturation
// < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
// <16 x i32> @llvm.x86.avx512.vpdpbusds.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
//
// < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
//
// < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
//
// <16 x i32> @llvm.x86.avx10.vpdpbssd.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
// <16 x i32> @llvm.x86.avx10.vpdpbssds.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
//
// These intrinsics are auto-upgraded into non-masked forms:
// <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
// <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
//
// <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
// <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
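//
// Each i32 result lane accumulates four byte products, hence the
// handler is called with ReductionFactor=4 and EltSize=8.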
case Intrinsic::x86_avx512_vpdpbusd_128:
case Intrinsic::x86_avx512_vpdpbusd_256:
case Intrinsic::x86_avx512_vpdpbusd_512:
case Intrinsic::x86_avx512_vpdpbusds_128:
case Intrinsic::x86_avx512_vpdpbusds_256:
case Intrinsic::x86_avx512_vpdpbusds_512:
case Intrinsic::x86_avx2_vpdpbssd_128:
case Intrinsic::x86_avx2_vpdpbssd_256:
case Intrinsic::x86_avx2_vpdpbssds_128:
case Intrinsic::x86_avx2_vpdpbssds_256:
case Intrinsic::x86_avx10_vpdpbssd_512:
case Intrinsic::x86_avx10_vpdpbssds_512:
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
break;

// AVX Vector Neural Network Instructions: words
//
// Multiply and Add Signed Word Integers
// < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
// <16 x i32> @llvm.x86.avx512.vpdpwssd.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
//
// Multiply and Add Signed Word Integers With Saturation
// < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
// (< 4 x i32>, < 4 x i32>, < 4 x i32>)
// < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
// (< 8 x i32>, < 8 x i32>, < 8 x i32>)
// <16 x i32> @llvm.x86.avx512.vpdpwssds.512
// (<16 x i32>, <16 x i32>, <16 x i32>)
//
// These intrinsics are auto-upgraded into non-masked forms:
// <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
// <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
//
// <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
// (<4 x i32>, <4 x i32>, <4 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
// (<8 x i32>, <8 x i32>, <8 x i32>, i8)
// <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
// <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
// (<16 x i32>, <16 x i32>, <16 x i32>, i16)
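//
// Each i32 result lane accumulates two word products, hence the
// handler is called with ReductionFactor=2 and EltSize=16.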
case Intrinsic::x86_avx512_vpdpwssd_128:
case Intrinsic::x86_avx512_vpdpwssd_256:
case Intrinsic::x86_avx512_vpdpwssd_512:
case Intrinsic::x86_avx512_vpdpwssds_128:
case Intrinsic::x86_avx512_vpdpwssds_256:
case Intrinsic::x86_avx512_vpdpwssds_512:
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
break;

// TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
// Precision
// <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
// (<4 x float>, <8 x bfloat>, <8 x bfloat>)
// <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
// (<8 x float>, <16 x bfloat>, <16 x bfloat>)
// <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
// (<16 x float>, <32 x bfloat>, <32 x bfloat>)
// handleVectorPmaddIntrinsic() currently only handles integer types.

case Intrinsic::x86_sse_cmp_ss:
case Intrinsic::x86_sse2_cmp_sd:
case Intrinsic::x86_sse_comieq_ss: