From dcb49a088ad6b34071a181923acdb142e281e7a4 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Fri, 14 Feb 2025 21:17:15 +0530 Subject: [PATCH 1/5] [ValueTracking] Pre-Commit Tests --- .../InstCombine/compute-sign-bits-bitcast.ll | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll diff --git a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll new file mode 100644 index 0000000000000..887720000477a --- /dev/null +++ b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll @@ -0,0 +1,35 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -passes=instcombine -S < %s | FileCheck %s + +define i32 @test_compute_sign_bits() { +; CHECK-LABEL: define i32 @test_compute_sign_bits() { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: ret i32 -1 +; +entry: + %a = add i8 -1, 0 + %b = bitcast i8 %a to <4 x i2> + %c = ashr <4 x i2> %b, + %d = bitcast <4 x i2> %c to i8 + %e = sext i8 %d to i32 + ret i32 %e +} + +; Test with sign extension to ensure proper sign bit tracking +define <4 x i2> @test_sext_bitcast(<1 x i8> %a0, <1 x i8> %a1) { +; CHECK-LABEL: define <4 x i2> @test_sext_bitcast( +; CHECK-SAME: <1 x i8> [[A0:%.*]], <1 x i8> [[A1:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <1 x i8> [[A0]], [[A1]] +; CHECK-NEXT: [[EXT:%.*]] = sext <1 x i1> [[CMP]] to <1 x i8> +; CHECK-NEXT: [[SUB:%.*]] = bitcast <1 x i8> [[EXT]] to <4 x i2> +; CHECK-NEXT: [[RESULT:%.*]] = ashr <4 x i2> [[SUB]], splat (i2 1) +; CHECK-NEXT: ret <4 x i2> [[RESULT]] +; +entry: + %cmp = icmp sgt <1 x i8> %a0, %a1 + %ext = sext <1 x i1> %cmp to <1 x i8> + %sub = bitcast <1 x i8> %ext to <4 x i2> + %result = ashr <4 x i2> %sub, + ret <4 x i2> %result +} From 9db0ba72aa532d9bf4e9949f783589f738678045 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Sat, 15 Feb 2025 01:54:15 +0530 Subject: [PATCH 2/5] [ValueTracking] add basic handling of BITCAST nodes --- llvm/lib/Analysis/ValueTracking.cpp | 29 +++++++++++++++++++ .../InstCombine/compute-sign-bits-bitcast.ll | 3 +- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 8a9ad55366ee7..c28c91b0a511e 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3922,6 +3922,35 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, if (auto *U = dyn_cast(V)) { switch (Operator::getOpcode(V)) { default: break; + case Instruction::BitCast: { + Value *Src = U->getOperand(0); + Type *SrcTy = Src->getType(); + Type *SrcScalarTy = SrcTy->getScalarType(); + + if (!SrcScalarTy->isIntegerTy()) + break; + + unsigned SrcBits = SrcTy->getScalarSizeInBits(); + + if ((SrcBits % TyBits) != 0) + break; + + if (auto *DstVTy = dyn_cast(Ty)) { + unsigned Scale = SrcBits / TyBits; + + APInt SrcDemandedElts = + APInt::getSplat(DstVTy->getNumElements() / Scale, APInt(1, 1)); + + Tmp = ComputeNumSignBits(Src, SrcDemandedElts, Depth + 1, Q); + if (Tmp == SrcBits) + return TyBits; + } else { + Tmp = ComputeNumSignBits(Src, APInt(1, 1), Depth + 1, Q); + if (Tmp == SrcBits) + return TyBits; + } + break; + } case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) + diff --git a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll index 887720000477a..35ba53687c646 100644 --- a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll +++ b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll @@ -23,8 +23,7 @@ define <4 x i2> @test_sext_bitcast(<1 x i8> %a0, <1 x i8> %a1) { ; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <1 x i8> [[A0]], [[A1]] ; CHECK-NEXT: [[EXT:%.*]] = sext <1 x i1> [[CMP]] to <1 x i8> ; CHECK-NEXT: [[SUB:%.*]] = bitcast <1 x i8> [[EXT]] to <4 x i2> -; CHECK-NEXT: [[RESULT:%.*]] = ashr <4 x i2> [[SUB]], splat (i2 1) -; CHECK-NEXT: ret <4 x i2> [[RESULT]] +; CHECK-NEXT: ret <4 x i2> [[SUB]] ; entry: %cmp = icmp sgt <1 x i8> %a0, %a1 From a5445e7cd44f63ab4fba704f4f595f06faabcc07 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Mon, 3 Mar 2025 18:14:28 +0530 Subject: [PATCH 3/5] Check for scalar type --- llvm/lib/Analysis/ValueTracking.cpp | 22 +++--- .../InstCombine/compute-sign-bits-bitcast.ll | 69 +++++++++++++------ 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index c28c91b0a511e..a90f12ae4b1e6 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3936,18 +3936,18 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, break; if (auto *DstVTy = dyn_cast(Ty)) { - unsigned Scale = SrcBits / TyBits; + if (auto *SrcVTy = dyn_cast(SrcTy)) { + APInt SrcDemandedElts = + APInt::getSplat(SrcVTy->getNumElements(), APInt(1, 1)); - APInt SrcDemandedElts = - APInt::getSplat(DstVTy->getNumElements() / Scale, APInt(1, 1)); - - Tmp = ComputeNumSignBits(Src, SrcDemandedElts, Depth + 1, Q); - if (Tmp == SrcBits) - return TyBits; - } else { - Tmp = ComputeNumSignBits(Src, APInt(1, 1), Depth + 1, Q); - if (Tmp == SrcBits) - return TyBits; + Tmp = ComputeNumSignBits(Src, SrcDemandedElts, Depth + 1, Q); + if (Tmp == SrcBits) + return TyBits; + } else { + Tmp = ComputeNumSignBits(Src, APInt(1, 1), Depth + 1, Q); + if (Tmp == SrcBits) + return TyBits; + } } break; } diff --git a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll index 35ba53687c646..f80e4355e834e 100644 --- a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll +++ b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll @@ -1,34 +1,61 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 ; RUN: opt -passes=instcombine -S < %s | FileCheck %s -define i32 @test_compute_sign_bits() { -; CHECK-LABEL: define i32 @test_compute_sign_bits() { -; CHECK-NEXT: [[ENTRY:.*:]] -; CHECK-NEXT: ret i32 -1 -; -entry: - %a = add i8 -1, 0 - %b = bitcast i8 %a to <4 x i2> - %c = ashr <4 x i2> %b, - %d = bitcast <4 x i2> %c to i8 - %e = sext i8 %d to i32 - ret i32 %e -} - -; Test with sign extension to ensure proper sign bit tracking -define <4 x i2> @test_sext_bitcast(<1 x i8> %a0, <1 x i8> %a1) { -; CHECK-LABEL: define <4 x i2> @test_sext_bitcast( +; Case 1: Vector to Vector bitcast +define <4 x i2> @test_vector_to_vector(<1 x i8> %a0, <1 x i8> %a1) { +; CHECK-LABEL: define <4 x i2> @test_vector_to_vector( ; CHECK-SAME: <1 x i8> [[A0:%.*]], <1 x i8> [[A1:%.*]]) { -; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[CMP:%.*]] = icmp sgt <1 x i8> [[A0]], [[A1]] ; CHECK-NEXT: [[EXT:%.*]] = sext <1 x i1> [[CMP]] to <1 x i8> ; CHECK-NEXT: [[SUB:%.*]] = bitcast <1 x i8> [[EXT]] to <4 x i2> ; CHECK-NEXT: ret <4 x i2> [[SUB]] ; -entry: %cmp = icmp sgt <1 x i8> %a0, %a1 %ext = sext <1 x i1> %cmp to <1 x i8> %sub = bitcast <1 x i8> %ext to <4 x i2> - %result = ashr <4 x i2> %sub, - ret <4 x i2> %result + %sra = ashr <4 x i2> %sub, + ret <4 x i2> %sra +} + +; Case 2: Scalar to Vector bitcast +define <2 x i16> @test_scalar_to_vector(i1 %cond) { +; CHECK-LABEL: define <2 x i16> @test_scalar_to_vector( +; CHECK-SAME: i1 [[COND:%.*]]) { +; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[COND]] to i32 +; CHECK-NEXT: [[BC:%.*]] = bitcast i32 [[EXT]] to <2 x i16> +; CHECK-NEXT: ret <2 x i16> [[BC]] +; + %ext = sext i1 %cond to i32 + %bc = bitcast i32 %ext to <2 x i16> + %sra = ashr <2 x i16> %bc, + ret <2 x i16> %sra +} + + +; Case 3: Multiple right shifts +define <8 x i8> @test_multiple_shifts(i1 %cond) { +; CHECK-LABEL: define <8 x i8> @test_multiple_shifts( +; CHECK-SAME: i1 [[COND:%.*]]) { +; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[COND]] to i64 +; CHECK-NEXT: [[BC:%.*]] = bitcast i64 [[EXT]] to <8 x i8> +; CHECK-NEXT: ret <8 x i8> [[BC]] +; + %ext = sext i1 %cond to i64 + %bc = bitcast i64 %ext to <8 x i8> + %sra1 = ashr <8 x i8> %bc, + %sra2 = ashr <8 x i8> %sra1, + ret <8 x i8> %sra2 +} + +; Case 4: Test with non-sign-extended source +define <4 x i8> @test_non_sign_extended(i32 %val) { +; CHECK-LABEL: define <4 x i8> @test_non_sign_extended( +; CHECK-SAME: i32 [[VAL:%.*]]) { +; CHECK-NEXT: [[BC:%.*]] = bitcast i32 [[VAL]] to <4 x i8> +; CHECK-NEXT: [[SRA:%.*]] = ashr <4 x i8> [[BC]], splat (i8 1) +; CHECK-NEXT: ret <4 x i8> [[SRA]] +; + %bc = bitcast i32 %val to <4 x i8> + %sra = ashr <4 x i8> %bc, + ret <4 x i8> %sra } From be31853ae2adc963ec9eeca4c0bbee655da3697f Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Tue, 4 Mar 2025 23:00:50 +0530 Subject: [PATCH 4/5] refactor --- llvm/lib/Analysis/ValueTracking.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index a90f12ae4b1e6..ca18effaedf87 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3925,9 +3925,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, case Instruction::BitCast: { Value *Src = U->getOperand(0); Type *SrcTy = Src->getType(); - Type *SrcScalarTy = SrcTy->getScalarType(); - if (!SrcScalarTy->isIntegerTy()) + if (!SrcTy->isIntOrIntVectorTy()) break; unsigned SrcBits = SrcTy->getScalarSizeInBits(); @@ -3935,7 +3934,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, if ((SrcBits % TyBits) != 0) break; - if (auto *DstVTy = dyn_cast(Ty)) { + if (isa(Ty)) { if (auto *SrcVTy = dyn_cast(SrcTy)) { APInt SrcDemandedElts = APInt::getSplat(SrcVTy->getNumElements(), APInt(1, 1)); From 4a4602150700941df1ef1e03567e7e48cc166ae9 Mon Sep 17 00:00:00 2001 From: Narayan Sreekumar Date: Wed, 5 Mar 2025 18:20:51 +0530 Subject: [PATCH 5/5] offload demandedElts to wrapper function --- llvm/lib/Analysis/ValueTracking.cpp | 21 ++++++++----------- .../InstCombine/compute-sign-bits-bitcast.ll | 2 +- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index ca18effaedf87..e1196aa6f22ed 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3926,27 +3926,24 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, Value *Src = U->getOperand(0); Type *SrcTy = Src->getType(); + // Skip if the source type is not an integer or integer vector type + // This ensures we only process integer-like types if (!SrcTy->isIntOrIntVectorTy()) break; unsigned SrcBits = SrcTy->getScalarSizeInBits(); + // Bitcast 'large element' scalar/vector to 'small element' vector. if ((SrcBits % TyBits) != 0) break; + // Only proceed if the destination type is a fixed-size vector if (isa(Ty)) { - if (auto *SrcVTy = dyn_cast(SrcTy)) { - APInt SrcDemandedElts = - APInt::getSplat(SrcVTy->getNumElements(), APInt(1, 1)); - - Tmp = ComputeNumSignBits(Src, SrcDemandedElts, Depth + 1, Q); - if (Tmp == SrcBits) - return TyBits; - } else { - Tmp = ComputeNumSignBits(Src, APInt(1, 1), Depth + 1, Q); - if (Tmp == SrcBits) - return TyBits; - } + // Fast case - sign splat can be simply split across the small elements. + // This works for both vector and scalar sources + Tmp = ComputeNumSignBits(Src, Depth + 1, Q); + if (Tmp == SrcBits) + return TyBits; } break; } diff --git a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll index f80e4355e834e..1da304f64a1ee 100644 --- a/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll +++ b/llvm/test/Transforms/InstCombine/compute-sign-bits-bitcast.ll @@ -47,7 +47,7 @@ define <8 x i8> @test_multiple_shifts(i1 %cond) { ret <8 x i8> %sra2 } -; Case 4: Test with non-sign-extended source +; (Negative) Case 4: Test with non-sign-extended source define <4 x i8> @test_non_sign_extended(i32 %val) { ; CHECK-LABEL: define <4 x i8> @test_non_sign_extended( ; CHECK-SAME: i32 [[VAL:%.*]]) {