Skip to content

Conversation

@llvmbot
Copy link
Member

@llvmbot llvmbot commented Nov 28, 2024

Backport 12cefcc

Requested by: @fhahn

lowerDotProduct called above may already lower a matrix multiply and
mark it as procssed by adding it to FusedInsts. Don't try to process it
again in LowerMatrixMultiplyFused by checking if FusedInsts.

Without this change, we trigger an assertion when trying to erase the
same original matrix multiply twice.

(cherry picked from commit 12cefcc)
@llvmbot
Copy link
Member Author

llvmbot commented Nov 28, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (llvmbot)

Changes

Backport 12cefcc

Requested by: @fhahn


Full diff: https://github.com/llvm/llvm-project/pull/118020.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp (+2-1)
  • (added) llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll (+49)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 6a681fd9339717..a44a123fdf8cda 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1014,7 +1014,8 @@ class LowerMatrixIntrinsics {
 
     // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
-      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
+      if (!FusedInsts.contains(CI))
+        LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
     Changed = !FusedInsts.empty();
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
new file mode 100644
index 00000000000000..b78d56646d9e4d
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-int-also-fusable-multiply.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -S %s | FileCheck %s
+
+define void @test(ptr %p, <8 x i32> %x) {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ptr [[P:%.*]], <8 x i32> [[X:%.*]]) {
+; CHECK-NEXT:    [[L:%.*]] = load <8 x i32>, ptr [[P]], align 4
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> zeroinitializer
+; CHECK-NEXT:    [[SPLIT1:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 1>
+; CHECK-NEXT:    [[SPLIT2:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 2>
+; CHECK-NEXT:    [[SPLIT3:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 3>
+; CHECK-NEXT:    [[SPLIT4:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 4>
+; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 5>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 6>
+; CHECK-NEXT:    [[SPLIT7:%.*]] = shufflevector <8 x i32> [[X]], <8 x i32> poison, <1 x i32> <i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <1 x i32> [[SPLIT]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <8 x i32> poison, i32 [[TMP1]], i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <1 x i32> [[SPLIT1]], i64 0
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <8 x i32> [[TMP2]], i32 [[TMP3]], i64 1
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <1 x i32> [[SPLIT2]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <8 x i32> [[TMP4]], i32 [[TMP5]], i64 2
+; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <1 x i32> [[SPLIT3]], i64 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <8 x i32> [[TMP6]], i32 [[TMP7]], i64 3
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <1 x i32> [[SPLIT4]], i64 0
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <8 x i32> [[TMP8]], i32 [[TMP9]], i64 4
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <1 x i32> [[SPLIT5]], i64 0
+; CHECK-NEXT:    [[TMP12:%.*]] = insertelement <8 x i32> [[TMP10]], i32 [[TMP11]], i64 5
+; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <1 x i32> [[SPLIT6]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = insertelement <8 x i32> [[TMP12]], i32 [[TMP13]], i64 6
+; CHECK-NEXT:    [[TMP15:%.*]] = extractelement <1 x i32> [[SPLIT7]], i64 0
+; CHECK-NEXT:    [[TMP16:%.*]] = insertelement <8 x i32> [[TMP14]], i32 [[TMP15]], i64 7
+; CHECK-NEXT:    [[TMP17:%.*]] = mul <8 x i32> [[L]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP17]])
+; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <1 x i32> poison, i32 [[TMP18]], i64 0
+; CHECK-NEXT:    [[E:%.*]] = extractelement <1 x i32> [[TMP19]], i64 0
+; CHECK-NEXT:    store i32 [[E]], ptr [[P]], align 4
+; CHECK-NEXT:    ret void
+;
+  %l = load <8 x i32>, ptr %p, align 4
+  %t = tail call <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32> %x, i32 1, i32 8)
+  %m = tail call <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32> %l, <8 x i32> %t, i32 1, i32 8, i32 1)
+  %e = extractelement <1 x i32> %m, i64 0
+  store i32 %e, ptr %p, align 4
+  ret void
+}
+
+declare <8 x i32> @llvm.matrix.transpose.v8i32(<8 x i32>, i32 immarg, i32 immarg)
+
+declare <1 x i32> @llvm.matrix.multiply.v1i32.v8i32.v8i32(<8 x i32>, <8 x i32>, i32 immarg, i32 immarg, i32 immarg)

@tru
Copy link
Collaborator

tru commented Dec 17, 2024

Who can review this?

@fhahn
Copy link
Contributor

fhahn commented Dec 17, 2024

@anemet or @francisvm , but at this point it may not be worth trying to pick this

@fhahn fhahn closed this Jan 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Development

Successfully merging this pull request may close these issues.

3 participants