Skip to content

Conversation

thurstond
Copy link
Contributor

This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from #153135

This extends the pmadd handler (recently improved in llvm#153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from llvm#153135
@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Thurston Dang (thurstond)

Changes

This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from #153135


Patch is 266.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153927.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+171-12)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll (+64-22)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2ni-intrinsics.ll (+85-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics-upgrade.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics-upgrade.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx_vnni-intrinsics.ll (+161-33)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avxvnniint8-intrinsics.ll (+165-33)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 6b394f5338687..f3a7fc1e692b8 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3846,7 +3846,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsics.
+  // Instrument multiply-add(-accumulate)? intrinsics.
   //
   // e.g., Two operands:
   //         <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
@@ -3854,7 +3854,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   //       Two operands which require an EltSizeInBits override:
   //         <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
   //
-  //       Three operands are not implemented yet:
+  //       Three operands:
   //         <4 x i32> @llvm.x86.avx512.vpdpbusd.128
   //                       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
   //         (the result of multiply-add'ing %a and %b is accumulated with %s)
@@ -3866,22 +3866,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
         cast<FixedVectorType>(I.getType());
     assert(isa<FixedVectorType>(ReturnType));
 
-    assert(I.arg_size() == 2);
-
     // Vectors A and B, and shadows
-    Value *Va = I.getOperand(0);
-    Value *Vb = I.getOperand(1);
+    Value *Va = nullptr;
+    Value *Vb = nullptr;
+    Value *Sa = nullptr;
+    Value *Sb = nullptr;
 
-    Value *Sa = getShadow(&I, 0);
-    Value *Sb = getShadow(&I, 1);
+    if (I.arg_size() == 2) {
+      Va = I.getOperand(0);
+      Vb = I.getOperand(1);
+
+      Sa = getShadow(&I, 0);
+      Sb = getShadow(&I, 1);
+    } else if (I.arg_size() == 3) {
+      // Operand 0 is the accumulator. We will deal with that below.
+      Va = I.getOperand(1);
+      Vb = I.getOperand(2);
+
+      Sa = getShadow(&I, 1);
+      Sb = getShadow(&I, 2);
+    } else {
+      assert(I.arg_size() == 2 || I.arg_size() == 3);
+    }
 
-    FixedVectorType *ParamType =
-        cast<FixedVectorType>(I.getArgOperand(0)->getType());
-    assert(ParamType == I.getArgOperand(1)->getType());
+    FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
+    assert(ParamType == Vb->getType());
 
     assert(ParamType->getPrimitiveSizeInBits() ==
            ReturnType->getPrimitiveSizeInBits());
 
+    if (I.arg_size() == 3) {
+      assert(ParamType == ReturnType);
+      assert(ParamType == I.getArgOperand(0)->getType());
+    }
+
     FixedVectorType *ImplicitReturnType = ReturnType;
     // Step 1: instrument multiplication of corresponding vector elements
     if (EltSizeInBits) {
@@ -3944,10 +3962,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                          Constant::getNullValue(Horizontal->getType())),
         ImplicitReturnType);
 
-    // For MMX, cast it back to the required fake return type (<1 x i64>).
+    // Cast it back to the required fake return type (<1 x i64>).
     if (EltSizeInBits)
       OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
 
+    // Step 3 (if applicable): instrument accumulator
+    if (I.arg_size() == 3)
+      OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
+
     setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
@@ -5507,6 +5529,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
+    // AVX Vector Neural Network Instructions: bytes
+    //
+    // Multiply and Add Packed Signed and Unsigned Bytes
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Unsigned and Signed Bytes With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpbusd_128:
+    case Intrinsic::x86_avx512_vpdpbusd_256:
+    case Intrinsic::x86_avx512_vpdpbusd_512:
+    case Intrinsic::x86_avx512_vpdpbusds_128:
+    case Intrinsic::x86_avx512_vpdpbusds_256:
+    case Intrinsic::x86_avx512_vpdpbusds_512:
+    case Intrinsic::x86_avx2_vpdpbssd_128:
+    case Intrinsic::x86_avx2_vpdpbssd_256:
+    case Intrinsic::x86_avx2_vpdpbssds_128:
+    case Intrinsic::x86_avx2_vpdpbssds_256:
+    case Intrinsic::x86_avx10_vpdpbssd_512:
+    case Intrinsic::x86_avx10_vpdpbssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
+      break;
+
+    // AVX Vector Neural Network Instructions: words
+    //
+    // Multiply and Add Signed Word Integers
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Signed Word Integers With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpwssd_128:
+    case Intrinsic::x86_avx512_vpdpwssd_256:
+    case Intrinsic::x86_avx512_vpdpwssd_512:
+    case Intrinsic::x86_avx512_vpdpwssds_128:
+    case Intrinsic::x86_avx512_vpdpwssds_256:
+    case Intrinsic::x86_avx512_vpdpwssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
+      break;
+
+      // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
+      // Precision
+      //   <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
+      //                   (<4 x float>, <8 x bfloat>, <8 x bfloat>)
+      //   <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
+      //                   (<8 x float>, <16 x bfloat>, <16 x bfloat>)
+      //   <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
+      //                   (<16 x float>, <32 x bfloat>, <32 x bfloat>)
+      // handleVectorPmaddIntrinsic() currently only handles integer types.
+
     case Intrinsic::x86_sse_cmp_ss:
     case Intrinsic::x86_sse2_cmp_sd:
     case Intrinsic::x86_sse_comieq_ss:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
index 7af8f34d403a0..298dc4b2c853a 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
@@ -7,19 +7,7 @@
 ; - llvm.x86.avx10.vdpphps.512
 ; - llvm.x86.avx10.vmpsadbw.512
 ;
-; Handled heuristically:
-; - llvm.x86.avx10.vpdpbssd.512
-; - llvm.x86.avx10.vpdpbssds.512
-; - llvm.x86.avx10.vpdpbsud.512
-; - llvm.x86.avx10.vpdpbsuds.512
-; - llvm.x86.avx10.vpdpbuud.512
-; - llvm.x86.avx10.vpdpbuuds.512
-; - llvm.x86.avx10.vpdpwsud.512
-; - llvm.x86.avx10.vpdpwsuds.512
-; - llvm.x86.avx10.vpdpwusd.512
-; - llvm.x86.avx10.vpdpwusds.512
-; - llvm.x86.avx10.vpdpwuud.512
-; - llvm.x86.avx10.vpdpwuuds.512
+; Handled heuristically: (none)
 
 target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -140,8 +128,8 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-LABEL: define <16 x i32> @test_mm512_dpbssd_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], ptr [[PB:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
-; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
 ; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i64 [[TMP1]], 0
 ; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB4:.*]], label %[[BB5:.*]], !prof [[PROF1]]
@@ -154,8 +142,26 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-NEXT:    [[TMP7:%.*]] = xor i64 [[TMP6]], 87960930222080
 ; CHECK-NEXT:    [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr
 ; CHECK-NEXT:    [[_MSLD:%.*]] = load <16 x i32>, ptr [[TMP8]], align 64
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[_MSLD]]
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <16 x i32> [[_MSLD]] to <64 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ne <64 x i8> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp ne <64 x i8> [[TMP12]], zeroinitializer
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp ne <64 x i8> [[TMP9]], zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = icmp ne <64 x i8> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP17:%.*]] = and <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = and <64 x i1> [[TMP15]], [[TMP14]]
+; CHECK-NEXT:    [[TMP19:%.*]] = and <64 x i1> [[TMP13]], [[TMP16]]
+; CHECK-NEXT:    [[TMP20:%.*]] = or <64 x i1> [[TMP17]], [[TMP18]]
+; CHECK-NEXT:    [[TMP21:%.*]] = or <64 x i1> [[TMP20]], [[TMP19]]
+; CHECK-NEXT:    [[TMP22:%.*]] = sext <64 x i1> [[TMP21]] to <64 x i8>
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <64 x i8> [[TMP22]] to <32 x i16>
+; CHECK-NEXT:    [[TMP24:%.*]] = icmp ne <32 x i16> [[TMP23]], zeroinitializer
+; CHECK-NEXT:    [[TMP25:%.*]] = sext <32 x i1> [[TMP24]] to <32 x i16>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <32 x i16> [[TMP25]] to i512
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast i512 [[TMP26]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP27]], [[TMP4]]
 ; CHECK-NEXT:    [[RES:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    store <16 x i32> [[_MSPROP1]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i32> [[RES]]
@@ -168,13 +174,31 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %__U, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_mask_dpbssds_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], i16 zeroext [[__U:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP24]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP28]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP28]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast i512 [[TMP22]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP23]], [[TMP1]]
 ; CHECK-NEXT:    [[DPI:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssds.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i16 [[TMP4]] to <16 x i1>
 ; CHECK-NEXT:    [[BST:%.*]] = bitcast i16 [[__U]] to <16 x i1>
@@ -196,13 +220,31 @@ define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %_
 define <16 x i32> @test_mm512_maskz_dpbssd_epi32(i16 zeroext %__U, <16 x i32> %__W, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_maskz_dpbssd_epi32(
 ; CHECK-SAME: i16 zeroext [[__U:%.*]], <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP24:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP29:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP28]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP29]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP29]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Thurston Dang (thurstond)

Changes

This extends the pmadd handler (recently improved in #153353) to three-operand intrinsics (multiply-add-accumulate), and applies it to the AVX Vector Neural Network Instructions.

Updates the tests from #153135


Patch is 266.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153927.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp (+171-12)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll (+64-22)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2ni-intrinsics.ll (+85-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics-upgrade.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vl_vnni-intrinsics.ll (+473-73)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics-upgrade.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx512vnni-intrinsics.ll (+237-37)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avx_vnni-intrinsics.ll (+161-33)
  • (modified) llvm/test/Instrumentation/MemorySanitizer/X86/avxvnniint8-intrinsics.ll (+165-33)
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index 6b394f5338687..f3a7fc1e692b8 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -3846,7 +3846,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsics.
+  // Instrument multiply-add(-accumulate)? intrinsics.
   //
   // e.g., Two operands:
   //         <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
@@ -3854,7 +3854,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   //       Two operands which require an EltSizeInBits override:
   //         <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
   //
-  //       Three operands are not implemented yet:
+  //       Three operands:
   //         <4 x i32> @llvm.x86.avx512.vpdpbusd.128
   //                       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
   //         (the result of multiply-add'ing %a and %b is accumulated with %s)
@@ -3866,22 +3866,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
         cast<FixedVectorType>(I.getType());
     assert(isa<FixedVectorType>(ReturnType));
 
-    assert(I.arg_size() == 2);
-
     // Vectors A and B, and shadows
-    Value *Va = I.getOperand(0);
-    Value *Vb = I.getOperand(1);
+    Value *Va = nullptr;
+    Value *Vb = nullptr;
+    Value *Sa = nullptr;
+    Value *Sb = nullptr;
 
-    Value *Sa = getShadow(&I, 0);
-    Value *Sb = getShadow(&I, 1);
+    if (I.arg_size() == 2) {
+      Va = I.getOperand(0);
+      Vb = I.getOperand(1);
+
+      Sa = getShadow(&I, 0);
+      Sb = getShadow(&I, 1);
+    } else if (I.arg_size() == 3) {
+      // Operand 0 is the accumulator. We will deal with that below.
+      Va = I.getOperand(1);
+      Vb = I.getOperand(2);
+
+      Sa = getShadow(&I, 1);
+      Sb = getShadow(&I, 2);
+    } else {
+      assert(I.arg_size() == 2 || I.arg_size() == 3);
+    }
 
-    FixedVectorType *ParamType =
-        cast<FixedVectorType>(I.getArgOperand(0)->getType());
-    assert(ParamType == I.getArgOperand(1)->getType());
+    FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
+    assert(ParamType == Vb->getType());
 
     assert(ParamType->getPrimitiveSizeInBits() ==
            ReturnType->getPrimitiveSizeInBits());
 
+    if (I.arg_size() == 3) {
+      assert(ParamType == ReturnType);
+      assert(ParamType == I.getArgOperand(0)->getType());
+    }
+
     FixedVectorType *ImplicitReturnType = ReturnType;
     // Step 1: instrument multiplication of corresponding vector elements
     if (EltSizeInBits) {
@@ -3944,10 +3962,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                          Constant::getNullValue(Horizontal->getType())),
         ImplicitReturnType);
 
-    // For MMX, cast it back to the required fake return type (<1 x i64>).
+    // Cast it back to the required fake return type (<1 x i64>).
     if (EltSizeInBits)
       OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
 
+    // Step 3 (if applicable): instrument accumulator
+    if (I.arg_size() == 3)
+      OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
+
     setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
@@ -5507,6 +5529,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
+    // AVX Vector Neural Network Instructions: bytes
+    //
+    // Multiply and Add Packed Signed and Unsigned Bytes
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Unsigned and Signed Bytes With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
+    //                  (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
+    //                  (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpbusd_128:
+    case Intrinsic::x86_avx512_vpdpbusd_256:
+    case Intrinsic::x86_avx512_vpdpbusd_512:
+    case Intrinsic::x86_avx512_vpdpbusds_128:
+    case Intrinsic::x86_avx512_vpdpbusds_256:
+    case Intrinsic::x86_avx512_vpdpbusds_512:
+    case Intrinsic::x86_avx2_vpdpbssd_128:
+    case Intrinsic::x86_avx2_vpdpbssd_256:
+    case Intrinsic::x86_avx2_vpdpbssds_128:
+    case Intrinsic::x86_avx2_vpdpbssds_256:
+    case Intrinsic::x86_avx10_vpdpbssd_512:
+    case Intrinsic::x86_avx10_vpdpbssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
+      break;
+
+    // AVX Vector Neural Network Instructions: words
+    //
+    // Multiply and Add Signed Word Integers
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssd.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Signed Word Integers With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
+    //                  (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
+    //                  (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssds.512
+    //                  (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
+    //                 (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
+    //                 (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
+    //                 (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpwssd_128:
+    case Intrinsic::x86_avx512_vpdpwssd_256:
+    case Intrinsic::x86_avx512_vpdpwssd_512:
+    case Intrinsic::x86_avx512_vpdpwssds_128:
+    case Intrinsic::x86_avx512_vpdpwssds_256:
+    case Intrinsic::x86_avx512_vpdpwssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
+      break;
+
+      // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
+      // Precision
+      //   <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
+      //                   (<4 x float>, <8 x bfloat>, <8 x bfloat>)
+      //   <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
+      //                   (<8 x float>, <16 x bfloat>, <16 x bfloat>)
+      //   <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
+      //                   (<16 x float>, <32 x bfloat>, <32 x bfloat>)
+      // handleVectorPmaddIntrinsic() currently only handles integer types.
+
     case Intrinsic::x86_sse_cmp_ss:
     case Intrinsic::x86_sse2_cmp_sd:
     case Intrinsic::x86_sse_comieq_ss:
diff --git a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
index 7af8f34d403a0..298dc4b2c853a 100644
--- a/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
+++ b/llvm/test/Instrumentation/MemorySanitizer/X86/avx10_2_512ni-intrinsics.ll
@@ -7,19 +7,7 @@
 ; - llvm.x86.avx10.vdpphps.512
 ; - llvm.x86.avx10.vmpsadbw.512
 ;
-; Handled heuristically:
-; - llvm.x86.avx10.vpdpbssd.512
-; - llvm.x86.avx10.vpdpbssds.512
-; - llvm.x86.avx10.vpdpbsud.512
-; - llvm.x86.avx10.vpdpbsuds.512
-; - llvm.x86.avx10.vpdpbuud.512
-; - llvm.x86.avx10.vpdpbuuds.512
-; - llvm.x86.avx10.vpdpwsud.512
-; - llvm.x86.avx10.vpdpwsuds.512
-; - llvm.x86.avx10.vpdpwusd.512
-; - llvm.x86.avx10.vpdpwusds.512
-; - llvm.x86.avx10.vpdpwuud.512
-; - llvm.x86.avx10.vpdpwuuds.512
+; Handled heuristically: (none)
 
 target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -140,8 +128,8 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-LABEL: define <16 x i32> @test_mm512_dpbssd_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], ptr [[PB:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
-; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
+; CHECK-NEXT:    [[TMP4:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
 ; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i64 [[TMP1]], 0
 ; CHECK-NEXT:    br i1 [[_MSCMP]], label %[[BB4:.*]], label %[[BB5:.*]], !prof [[PROF1]]
@@ -154,8 +142,26 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 ; CHECK-NEXT:    [[TMP7:%.*]] = xor i64 [[TMP6]], 87960930222080
 ; CHECK-NEXT:    [[TMP8:%.*]] = inttoptr i64 [[TMP7]] to ptr
 ; CHECK-NEXT:    [[_MSLD:%.*]] = load <16 x i32>, ptr [[TMP8]], align 64
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[_MSLD]]
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <16 x i32> [[_MSLD]] to <64 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = icmp ne <64 x i8> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    [[TMP14:%.*]] = icmp ne <64 x i8> [[TMP12]], zeroinitializer
+; CHECK-NEXT:    [[TMP15:%.*]] = icmp ne <64 x i8> [[TMP9]], zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = icmp ne <64 x i8> [[TMP10]], zeroinitializer
+; CHECK-NEXT:    [[TMP17:%.*]] = and <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP18:%.*]] = and <64 x i1> [[TMP15]], [[TMP14]]
+; CHECK-NEXT:    [[TMP19:%.*]] = and <64 x i1> [[TMP13]], [[TMP16]]
+; CHECK-NEXT:    [[TMP20:%.*]] = or <64 x i1> [[TMP17]], [[TMP18]]
+; CHECK-NEXT:    [[TMP21:%.*]] = or <64 x i1> [[TMP20]], [[TMP19]]
+; CHECK-NEXT:    [[TMP22:%.*]] = sext <64 x i1> [[TMP21]] to <64 x i8>
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <64 x i8> [[TMP22]] to <32 x i16>
+; CHECK-NEXT:    [[TMP24:%.*]] = icmp ne <32 x i16> [[TMP23]], zeroinitializer
+; CHECK-NEXT:    [[TMP25:%.*]] = sext <32 x i1> [[TMP24]] to <32 x i16>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <32 x i16> [[TMP25]] to i512
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast i512 [[TMP26]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP27]], [[TMP4]]
 ; CHECK-NEXT:    [[RES:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssd.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    store <16 x i32> [[_MSPROP1]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i32> [[RES]]
@@ -168,13 +174,31 @@ define <16 x i32> @test_mm512_dpbssd_epi32(<16 x i32> %__W, <16 x i32> %__A, ptr
 define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %__U, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_mask_dpbssds_epi32(
 ; CHECK-SAME: <16 x i32> [[__W:%.*]], i16 zeroext [[__U:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP24]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP28]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP28]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast i512 [[TMP22]] to <16 x i32>
+; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[TMP23]], [[TMP1]]
 ; CHECK-NEXT:    [[DPI:%.*]] = tail call <16 x i32> @llvm.x86.avx10.vpdpbssds.512(<16 x i32> [[__W]], <16 x i32> [[__A]], <16 x i32> [[__B]])
 ; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i16 [[TMP4]] to <16 x i1>
 ; CHECK-NEXT:    [[BST:%.*]] = bitcast i16 [[__U]] to <16 x i1>
@@ -196,13 +220,31 @@ define <16 x i32> @test_mm512_mask_dpbssds_epi32(<16 x i32> %__W, i16 zeroext %_
 define <16 x i32> @test_mm512_maskz_dpbssd_epi32(i16 zeroext %__U, <16 x i32> %__W, <16 x i32> %__A, <16 x i32> %__B) sanitize_memory {
 ; CHECK-LABEL: define <16 x i32> @test_mm512_maskz_dpbssd_epi32(
 ; CHECK-SAME: i16 zeroext [[__U:%.*]], <16 x i32> [[__W:%.*]], <16 x i32> [[__A:%.*]], <16 x i32> [[__B:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 72) to ptr), align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 136) to ptr), align 8
+; CHECK-NEXT:    [[TMP24:%.*]] = load <16 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 8) to ptr), align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i16, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i32> [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[_MSPROP1:%.*]] = or <16 x i32> [[_MSPROP]], [[TMP3]]
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast <16 x i32> [[__A]] to <64 x i8>
+; CHECK-NEXT:    [[TMP26:%.*]] = bitcast <16 x i32> [[__B]] to <64 x i8>
+; CHECK-NEXT:    [[TMP27:%.*]] = bitcast <16 x i32> [[TMP2]] to <64 x i8>
+; CHECK-NEXT:    [[TMP28:%.*]] = bitcast <16 x i32> [[TMP3]] to <64 x i8>
+; CHECK-NEXT:    [[TMP29:%.*]] = icmp ne <64 x i8> [[TMP27]], zeroinitializer
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <64 x i8> [[TMP28]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <64 x i8> [[TMP25]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = icmp ne <64 x i8> [[TMP26]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = and <64 x i1> [[TMP29]], [[TMP10]]
+; CHECK-NEXT:    [[TMP14:%.*]] = and <64 x i1> [[TMP11]], [[TMP10]]
+; CHECK-NEXT:    [[TMP15:%.*]] = and <64 x i1> [[TMP29]], [[TMP12]]
+; CHECK-NEXT:    [[TMP16:%.*]] = or <64 x i1> [[TMP13]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = or <64 x i1> [[TMP16]], [[TMP15]]
+; CHECK-NEXT:    [[TMP18:%.*]] = sext <64 x i1> [[TMP17]] to <64 x i8>
+; CHECK-NEXT:    [[TMP19:%.*]] = bitcast <64 x i8> [[TMP18]] to <32 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = icmp ne <32 x i16> [[TMP19]], zeroinitializer
+; CHECK-NEXT:    [[TMP21:%.*]] = sext <32 x i1> [[TMP20]] to <32 x i16>
+; CHECK-NEXT:    [[TMP22:%.*]] = bitcast <32 x i16> [[TMP21]] to i512
...
[truncated]

Sa = getShadow(&I, 1);
Sb = getShadow(&I, 2);
} else {
assert(I.arg_size() == 2 || I.arg_size() == 3);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this in front of the if and leave out the else branch? It's cleaner dot png?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Three operands:
// <4 x i32> @llvm.x86.avx512.vpdpbusd.128
// (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
// (the result of multiply-add'ing %a and %b is accumulated with %s)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What operation is "accumulated" here? Can you clarify in the comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've elaborated

@thurstond thurstond changed the title [msan] Handle AVX Vector Neural Network Instructions (VNNI) [msan] Handle multiply-add-accumulate; apply to AVX Vector Neural Network Instructions (VNNI) Aug 18, 2025
@thurstond thurstond requested a review from fmayer August 18, 2025 18:37
@thurstond thurstond merged commit 4220538 into llvm:main Aug 18, 2025
9 checks passed
thurstond added a commit to thurstond/llvm-project that referenced this pull request Aug 28, 2025
llvm#153927 incorrectly cast using a hardcoded reduction factor of two, rather than using the parameter.

This caused false negatives but not false positives. (The only incorrect case was a reduction factor of four; if four values {A,B,C,D} are being reduced, the result is fully zero iff {A,B} and {C,D} are both zero after pairwise reduction. If only one of those reduced pairs is zero, then the quadwise reduction is non-zero.)
thurstond added a commit that referenced this pull request Sep 2, 2025
…155748)

#153927 incorrectly cast using
a hardcoded reduction factor of two, rather than using the parameter.

This caused false negatives but not false positives. (The only incorrect
case was a reduction factor of four; if four values {A,B,C,D} are being
reduced, the result is fully zero iff {A,B} and {C,D} are both zero
after pairwise reduction. If only one of those reduced pairs is zero,
then the quadwise reduction is non-zero.)
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 2, 2025
…ionFactor (#155748)

llvm/llvm-project#153927 incorrectly cast using
a hardcoded reduction factor of two, rather than using the parameter.

This caused false negatives but not false positives. (The only incorrect
case was a reduction factor of four; if four values {A,B,C,D} are being
reduced, the result is fully zero iff {A,B} and {C,D} are both zero
after pairwise reduction. If only one of those reduced pairs is zero,
then the quadwise reduction is non-zero.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants