
@UsmanNadeem
Contributor

We can narrow `trunc(lshr(i32)) to i8` to `trunc(lshr(i16)) to i8` even when the bits being shifted in are not zero, as long as the MSBs of the shifted value don't matter because they end up being truncated away.

This kind of narrowing does not remove the trunc but can help the vectorizer generate better code in a smaller type.
Motivation: libyuv, functions like ARGBToUV444Row_C().

Proof: https://alive2.llvm.org/ce/z/9Ao2aJ

Change-Id: I8ec210209522573a97773201e08dfea8d6b9d78d
Change-Id: I681a247eac20a4fcf68e54d4a5009f594030a387
@UsmanNadeem requested a review from dtcxzyw May 12, 2025 23:44
@UsmanNadeem requested a review from nikic as a code owner May 12, 2025 23:44
@llvmbot added the llvm:instcombine (Covers the InstCombine, InstSimplify and AggressiveInstCombine passes) and llvm:transforms labels May 12, 2025
@llvmbot
Member

llvmbot commented May 12, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Usman Nadeem (UsmanNadeem)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/139645.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+16-6)
  • (modified) llvm/test/Transforms/InstCombine/cast.ll (+96)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index d6c99366e6f00..b47a82a542bfb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -51,6 +51,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
     Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
     Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
     Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
+    if (Opc == Instruction::LShr || Opc == Instruction::AShr)
+      Res->setIsExact(I->isExact());
     break;
   }
   case Instruction::Trunc:
@@ -319,13 +321,21 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
     //       zero - use AmtKnownBits.getMaxValue().
     uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
     uint32_t BitWidth = Ty->getScalarSizeInBits();
-    KnownBits AmtKnownBits =
-        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+    KnownBits AmtKnownBits = IC.computeKnownBits(I->getOperand(1), 0, CxtI);
+    APInt MaxShiftAmt = AmtKnownBits.getMaxValue();
     APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
-    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
-        IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
-      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
-             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+    if (MaxShiftAmt.ult(BitWidth)) {
+      if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI))
+        return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+               canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+      // If the only user is a trunc then we can narrow the shift if any new
+      // MSBs are not going to be used.
+      if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
+        auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
+        if (MaxShiftAmt.ule(BitWidth - DemandedBits))
+          return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+                 canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+      }
     }
     break;
   }
diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
index 0f957e22ad17b..8485a01e3180a 100644
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -5,6 +5,7 @@
 ; RUN: opt < %s -passes=instcombine -S -data-layout="E-p:64:64:64-p1:32:32:32-p2:64:64:64-p3:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128-n8:16:32:64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ALL,BE
 ; RUN: opt < %s -passes=instcombine -S -data-layout="e-p:64:64:64-p1:32:32:32-p2:64:64:64-p3:64:64:64-a0:0:8-f32:32:32-f64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-v64:64:64-v128:128:128-n8:16:32:64" -use-constant-fp-for-fixed-length-splat -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=ALL,LE
 
+declare void @use_i8(i8)
 declare void @use_i32(i32)
 declare void @use_v2i32(<2 x i32>)
 
@@ -2041,6 +2042,101 @@ define <2 x i8> @trunc_lshr_zext_uses1(<2 x i8> %A) {
   ret <2 x i8> %D
 }
 
+define i8 @trunc_lshr_ext_halfWidth(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth(
+; ALL-NEXT:    [[ADD:%.*]] = add i16 [[A:%.*]], [[B:%.*]]
+; ALL-NEXT:    [[SHR:%.*]] = lshr i16 [[ADD]], [[SHIFTAMT:%.*]]
+; ALL-NEXT:    [[TRUNC:%.*]] = trunc i16 [[SHR]] to i8
+; ALL-NEXT:    ret i8 [[TRUNC]]
+;
+  %zext_a = zext i16 %a to i32
+  %zext_b = zext i16 %b to i32
+  %zext_shiftAmt = zext i16 %shiftAmt to i32
+  %add = add nuw nsw i32 %zext_a, %zext_b
+  %shr = lshr i32 %add, %zext_shiftAmt
+  %trunc = trunc i32 %shr to i8
+  ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_rhsRange_neg(i16 %a, i16 %b, i16 %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_rhsRange_neg(
+; ALL-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT:    [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT:    ret i8 [[TRUNC]]
+;
+  %zext_a = zext i16 %a to i32
+  %zext_b = zext i16 %b to i32
+  %zext_shiftAmt = zext i16 %shiftAmt to i32
+  %add = add nuw nsw i32 %zext_a, %zext_b
+  %shr = lshr i32 %add, %zext_shiftAmt
+  %trunc = trunc i32 %shr to i8
+  ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_twouse_neg1(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_twouse_neg1(
+; ALL-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT:    call void @use_i32(i32 [[ADD]])
+; ALL-NEXT:    [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT:    ret i8 [[TRUNC]]
+;
+  %zext_a = zext i16 %a to i32
+  %zext_b = zext i16 %b to i32
+  %zext_shiftAmt = zext i16 %shiftAmt to i32
+  %add = add nuw nsw i32 %zext_a, %zext_b
+  call void @use_i32(i32 %add)
+  %shr = lshr i32 %add, %zext_shiftAmt
+  %trunc = trunc i32 %shr to i8
+  ret i8 %trunc
+}
+
+define i8 @trunc_lshr_ext_halfWidth_twouse_neg2(i16 %a, i16 %b, i16 range(i16 0, 8) %shiftAmt) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_twouse_neg2(
+; ALL-NEXT:    [[ZEXT_A:%.*]] = zext i16 [[A:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_B:%.*]] = zext i16 [[B:%.*]] to i32
+; ALL-NEXT:    [[ZEXT_SHIFTAMT:%.*]] = zext nneg i16 [[SHIFTAMT:%.*]] to i32
+; ALL-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT:    [[SHR:%.*]] = lshr i32 [[ADD]], [[ZEXT_SHIFTAMT]]
+; ALL-NEXT:    call void @use_i32(i32 [[SHR]])
+; ALL-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHR]] to i8
+; ALL-NEXT:    ret i8 [[TRUNC]]
+;
+  %zext_a = zext i16 %a to i32
+  %zext_b = zext i16 %b to i32
+  %zext_shiftAmt = zext i16 %shiftAmt to i32
+  %add = add nuw nsw i32 %zext_a, %zext_b
+  %shr = lshr i32 %add, %zext_shiftAmt
+  call void @use_i32(i32 %shr)
+  %trunc = trunc i32 %shr to i8
+  ret i8 %trunc
+}
+
+; The narrowing transform only happens for scalar integer types.
+define <2 x i8> @trunc_lshr_ext_halfWidth_vector_neg(<2 x i16> %a, <2 x i16> %b) {
+; ALL-LABEL: @trunc_lshr_ext_halfWidth_vector_neg(
+; ALL-NEXT:    [[ZEXT_A:%.*]] = zext <2 x i16> [[A:%.*]] to <2 x i32>
+; ALL-NEXT:    [[ZEXT_B:%.*]] = zext <2 x i16> [[B:%.*]] to <2 x i32>
+; ALL-NEXT:    [[ADD:%.*]] = add nuw nsw <2 x i32> [[ZEXT_A]], [[ZEXT_B]]
+; ALL-NEXT:    [[SHR:%.*]] = lshr <2 x i32> [[ADD]], splat (i32 6)
+; ALL-NEXT:    [[TRUNC:%.*]] = trunc <2 x i32> [[SHR]] to <2 x i8>
+; ALL-NEXT:    ret <2 x i8> [[TRUNC]]
+;
+  %zext_a = zext <2 x i16> %a to <2 x i32>
+  %zext_b = zext <2 x i16> %b to <2 x i32>
+  %add = add nuw nsw <2 x i32> %zext_a, %zext_b
+  %shr = lshr <2 x i32> %add, <i32 6, i32 6>
+  %trunc = trunc <2 x i32> %shr to <2 x i8>
+  ret <2 x i8> %trunc
+}
+
 ; The following four tests sext + lshr + trunc patterns.
 ; PR33078
 

    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
    Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
    if (Opc == Instruction::LShr || Opc == Instruction::AShr)
      Res->setIsExact(I->isExact());
Member

This change looks unrelated. Is it necessary to avoid regression?

Contributor Author

Yes, some other tests regress without it.

      // MSBs are not going to be used.
      if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
        auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
        if (MaxShiftAmt.ule(BitWidth - DemandedBits))
Member

`BitWidth - DemandedBits` may wrap. Both `DemandedBits` and `BitWidth` are only guaranteed to be less than `OrigBitWidth`, but we don't know which one is larger.

Contributor Author

Rewrote it to `if ((MaxShiftAmt + DemandedBits).ule(BitWidth))`.

Change-Id: Ifee1fa5fe63a88ab40621c252591a0c620225d37
Change-Id: I056c38da9089b57146b33d94d88383d81a3f15de
Change-Id: I33495fe6248a9f8e7220c8a26c3beb81bb30c645
Member

@dtcxzyw left a comment

LGTM. Thank you!
If we want to further generalize this optimization, we may need to add a parameter like DemandedLowBits.

@UsmanNadeem merged commit b931731 into llvm:main May 13, 2025
11 checks passed