[AArch64][SVE] Optimize logical ops with convert.to.svbool. #160408
Conversation
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-transforms

Author: Vladimir Miloserdov (miloserdow)

Changes: Fix redundant AND/OR operations with all-true SVE predicates that were not previously optimized. `isAllActivePredicate` is modified to detect `splat(i1 true)` patterns, and InstCombine folds are added for the `sve.and_z` and `sve.orr_z` intrinsics that apply the logical identities `a AND true = a` and `a OR true = true`.
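As a concrete illustration (a sketch built around the ACLE call quoted in the commit message below; the function name, signature, and compile flags are my own assumptions), the kind of source that produces the redundant predicate AND looks like this:

```cpp
// Sketch only: compile with something like
//   clang++ -O2 --target=aarch64-linux-gnu -march=armv8-a+sve
// (flags are illustrative). The svpnext_b16 result, viewed as an svbool_t,
// is already false in every lane that svptrue_b16() would mask off, so with
// an all-true governing predicate the svand_z (and the extra ptrue) is
// redundant and should fold down to just the svpnext_b16 call.
#include <arm_sve.h>

svbool_t pnext_b16_all_active(svbool_t prev, svbool_t pg) {
  return svand_z(svptrue_b8(), svpnext_b16(prev, pg), svptrue_b16());
}
```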
Full diff: https://github.com/llvm/llvm-project/pull/160408.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 8c4b4f6e4d6de..57d69c356be28 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1587,8 +1587,21 @@ static bool isAllActivePredicate(Value *Pred) {
if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
Pred = UncastedPred;
- auto *C = dyn_cast<Constant>(Pred);
- return (C && C->isAllOnesValue());
+
+ // Also look through just convert.to.svbool if the input is an all-true splat
+ Value *ConvertArg;
+ if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
+ m_Value(ConvertArg))))
+ Pred = ConvertArg;
+ // Check for splat(i1 true) pattern used by svptrue intrinsics
+ if (auto *C = dyn_cast<Constant>(Pred)) {
+ if (C->isAllOnesValue())
+ return true;
+ if (auto *SplatVal = C->getSplatValue())
+ if (auto *CI = dyn_cast<ConstantInt>(SplatVal))
+ return CI->isOne();
+ }
+ return false;
}
// Simplify `V` by only considering the operations that affect active lanes.
@@ -2904,6 +2917,30 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
return instCombineSVEUxt(IC, II, 32);
case Intrinsic::aarch64_sme_in_streaming_mode:
return instCombineInStreamingMode(IC, II);
+ case Intrinsic::aarch64_sve_and_z:
+ case Intrinsic::aarch64_sve_orr_z: {
+ Value *Pred = II.getArgOperand(0);
+ Value *Op1 = II.getArgOperand(1);
+ Value *Op2 = II.getArgOperand(2);
+ if (isAllActivePredicate(Pred)) {
+ if (IID == Intrinsic::aarch64_sve_and_z) {
+ // svand_z(all-true, a, all-true) -> a
+ if (isAllActivePredicate(Op2))
+ return IC.replaceInstUsesWith(II, Op1);
+ // svand_z(all-true, all-true, a) -> a
+ if (isAllActivePredicate(Op1))
+ return IC.replaceInstUsesWith(II, Op2);
+ } else { // orr_z
+ // svorr_z(all-true, a, all-true) -> all-true
+ if (isAllActivePredicate(Op2))
+ return IC.replaceInstUsesWith(II, Op2);
+ // svorr_z(all-true, all-true, a) -> all-true
+ if (isAllActivePredicate(Op1))
+ return IC.replaceInstUsesWith(II, Op1);
+ }
+ }
+ break;
+ }
}
return std::nullopt;
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-and-or-with-all-true.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-and-or-with-all-true.ll
new file mode 100644
index 0000000000000..9eff3acc12c99
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-and-or-with-all-true.ll
@@ -0,0 +1,80 @@
+; RUN: opt -passes=instcombine -mtriple aarch64 -mattr=+sve -S -o - < %s | FileCheck %s
+;
+; Test AArch64-specific InstCombine optimizations for SVE logical operations
+; with all-true predicates.
+; - a AND true = a
+; - a OR true = true
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.pnext.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
+
+define <vscale x 16 x i1> @test_sve_and_z_all_true_right(<vscale x 16 x i1> %a) {
+; CHECK-LABEL: @test_sve_and_z_all_true_right(
+; CHECK-NEXT: ret <vscale x 16 x i1> [[A:%.*]]
+ %all_true = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %a, <vscale x 16 x i1> %all_true)
+ ret <vscale x 16 x i1> %result
+}
+
+define <vscale x 16 x i1> @test_sve_and_z_all_true_left(<vscale x 16 x i1> %a) {
+; CHECK-LABEL: @test_sve_and_z_all_true_left(
+; CHECK-NEXT: ret <vscale x 16 x i1> [[A:%.*]]
+ %all_true = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %all_true, <vscale x 16 x i1> %a)
+ ret <vscale x 16 x i1> %result
+}
+
+define <vscale x 16 x i1> @test_sve_orr_z_all_true_right(<vscale x 16 x i1> %a) {
+; CHECK-LABEL: @test_sve_orr_z_all_true_right(
+; CHECK-NEXT: [[ALL_TRUE:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+; CHECK-NEXT: ret <vscale x 16 x i1> [[ALL_TRUE]]
+ %all_true = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %a, <vscale x 16 x i1> %all_true)
+ ret <vscale x 16 x i1> %result
+}
+
+define <vscale x 16 x i1> @test_sve_orr_z_all_true_left(<vscale x 16 x i1> %a) {
+; CHECK-LABEL: @test_sve_orr_z_all_true_left(
+; CHECK-NEXT: [[ALL_TRUE:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+; CHECK-NEXT: ret <vscale x 16 x i1> [[ALL_TRUE]]
+ %all_true = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %all_true, <vscale x 16 x i1> %a)
+ ret <vscale x 16 x i1> %result
+}
+
+define <vscale x 16 x i1> @test_original_bug_case(<vscale x 16 x i1> %pg, <vscale x 16 x i1> %prev) {
+; CHECK-LABEL: @test_original_bug_case(
+; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PG:%.*]])
+; CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[PREV:%.*]])
+; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.pnext.nxv8i1(<vscale x 8 x i1> [[TMP1]], <vscale x 8 x i1> [[TMP2]])
+; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> [[TMP3]])
+; CHECK-NEXT: ret <vscale x 16 x i1> [[TMP4]]
+ %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %pg)
+ %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %prev)
+ %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.pnext.nxv8i1(<vscale x 8 x i1> %1, <vscale x 8 x i1> %2)
+ %4 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %3)
+ %5 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %6 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %4, <vscale x 16 x i1> %5)
+ ret <vscale x 16 x i1> %6
+}
+
+define <vscale x 16 x i1> @test_sve_and_z_not_all_true_predicate(<vscale x 16 x i1> %pred, <vscale x 16 x i1> %a) {
+; CHECK-LABEL: @test_sve_and_z_not_all_true_predicate(
+; CHECK-NEXT: [[ALL_TRUE:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+; CHECK-NEXT: [[RESULT:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> [[PRED:%.*]], <vscale x 16 x i1> [[A:%.*]], <vscale x 16 x i1> [[ALL_TRUE]])
+; CHECK-NEXT: ret <vscale x 16 x i1> [[RESULT]]
+ %all_true = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> splat (i1 true))
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> %pred, <vscale x 16 x i1> %a, <vscale x 16 x i1> %all_true)
+ ret <vscale x 16 x i1> %result
+}
+
+define <vscale x 16 x i1> @test_sve_and_z_no_all_true_operands(<vscale x 16 x i1> %a, <vscale x 16 x i1> %b) {
+; CHECK-LABEL: @test_sve_and_z_no_all_true_operands(
+; CHECK-NEXT: [[RESULT:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> [[A:%.*]], <vscale x 16 x i1> [[B:%.*]])
+; CHECK-NEXT: ret <vscale x 16 x i1> [[RESULT]]
+ %result = tail call <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i1> %a, <vscale x 16 x i1> %b)
+ ret <vscale x 16 x i1> %result
+}
Rather than adding these explicit combines, do you mind trying to extend
@paulwalker-arm Thanks for the tip! I reworked the simplification to happen through
✅ With the latest revision this PR passed the C/C++ code formatter.
When both operands of a logical operation (and/or/xor) are convert.to.svbool from the same narrower type, unwrap to that type, simplify using simplifyBinOp, and rewrap the result. This eliminates redundant instructions in cases like:

svand_z(svptrue_b8(), svpnext_b16(prev, pg), svptrue_b16());

Fixes llvm#160279.
b201a2f to
da65767
Compare
```llvm
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --tool ../../llvm-build/bin/opt
```
This will prevent others from regenerating the CHECK lines unless the path to their opt binary matches yours. I think you need to delete this NOTE and regenerate using --opt-binary instead of --tool.
```llvm
declare <vscale x 16 x i1> @llvm.aarch64.sve.and.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.orr.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.eor.z.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, <vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.pnext.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
```
We don't need to add these declarations anymore.
```cpp
auto *ConvIntr1 = dyn_cast<IntrinsicInst>(Op1);
auto *ConvIntr2 = dyn_cast<IntrinsicInst>(Op2);
if (ConvIntr1 && ConvIntr2 &&
    ConvIntr1->getIntrinsicID() ==
        Intrinsic::aarch64_sve_convert_to_svbool &&
    ConvIntr2->getIntrinsicID() ==
        Intrinsic::aarch64_sve_convert_to_svbool) {
  Value *NarrowOp1 = ConvIntr1->getArgOperand(0);
  Value *NarrowOp2 = ConvIntr2->getArgOperand(0);
```
Perhaps use PatternMatch, something like:
```cpp
Value *NarrowOp1, *NarrowOp2;
if (match(Op1, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(m_Value(NarrowOp1))) &&
    match(Op2, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(m_Value(NarrowOp2))))
```
When both operands of a logical operation (and/or) are convert.to.svbool
from the same narrower type, unwrap to that type, simplify using simplifyBinOp,
and rewrap the result. This eliminates redundant instructions in cases like:
svand_z(svptrue_b8(), svpnext_b16(prev, pg), svptrue_b16());
Fixes #160279.
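For readers skimming the thread, the following is a rough sketch of what the unwrap/simplify/rewrap combine described above can look like. It is not the actual patch: the helper name is invented, the all-active guard on the governing predicate is one conservative way to keep the zeroing semantics safe, and the snippet assumes it sits in AArch64TargetTransformInfo.cpp next to the other combines (so `isAllActivePredicate`, the file's `using namespace llvm::PatternMatch`, and the InstCombine utilities are already in scope).

```cpp
// Hypothetical sketch, not the actual patch: fold sve.{and,orr,eor}_z when
// both operands are convert.to.svbool of the same narrower predicate type.
// Opc maps and_z -> And, orr_z -> Or, eor_z -> Xor.
static std::optional<Instruction *>
simplifySVEPredLogicalOp(InstCombiner &IC, IntrinsicInst &II,
                         Instruction::BinaryOps Opc) {
  // Conservatively require an all-active governing predicate so the zeroing
  // behaviour of the _z form cannot affect the result.
  if (!isAllActivePredicate(II.getArgOperand(0)))
    return std::nullopt;

  // Both operands must be widenings of predicates of the same narrow type.
  Value *NarrowOp1, *NarrowOp2;
  if (!match(II.getArgOperand(1),
             m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
                 m_Value(NarrowOp1))) ||
      !match(II.getArgOperand(2),
             m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
                 m_Value(NarrowOp2))) ||
      NarrowOp1->getType() != NarrowOp2->getType())
    return std::nullopt;

  // Let the generic simplifier apply identities such as x & true -> x and
  // x | true -> true on the narrow predicates.
  Value *Simplified =
      simplifyBinOp(Opc, NarrowOp1, NarrowOp2, IC.getSimplifyQuery());
  if (!Simplified)
    return std::nullopt;

  // Rewrap the narrow result as an svbool and replace the original call.
  CallInst *Rewrapped = IC.Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {Simplified->getType()},
      {Simplified});
  return IC.replaceInstUsesWith(II, Rewrapped);
}
```

The attraction of this shape is that `simplifyBinOp` already knows identities such as `x & all-true -> x` and `x | all-true -> all-true`, so no SVE-specific identity table needs to be maintained in the backend.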