Skip to content

Conversation

@juliannagele
Copy link
Member

@juliannagele juliannagele commented Oct 22, 2025

This change aims to convert vector loads to scalar loads when the loaded vectors are only converted to scalars afterwards anyway.

alive2 proof: https://alive2.llvm.org/ce/z/U_rvht

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Julian Nagele (juliannagele)

Changes

This change aims to convert vector loads to scalar loads when the loaded vectors are only converted to scalars afterwards anyway.


Full diff: https://github.com/llvm/llvm-project/pull/164682.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+116-28)
  • (added) llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll (+136)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d6eb00da11dc8..e045282c387fe 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
   auto *LI = cast<LoadInst>(&I);
   auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications betwwen
+  // the load and its users.
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
-    if (!UI || UI->getParent() != LI->getParent())
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
       return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                                     nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n  LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts should target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType || LI->user_empty())
+    return false;
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n  OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
         return true;
       if (scalarizeExtExtract(I))
         return true;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
new file mode 100644
index 0000000000000..464e5129262bc
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-vectorizers

Author: Julian Nagele (juliannagele)

Changes

This change aims to convert vector loads to scalar loads when the loaded vectors are only converted to scalars afterwards anyway.


Full diff: https://github.com/llvm/llvm-project/pull/164682.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+116-28)
  • (added) llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll (+136)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d6eb00da11dc8..e045282c387fe 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
   auto *LI = cast<LoadInst>(&I);
   auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications betwwen
+  // the load and its users.
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
-    if (!UI || UI->getParent() != LI->getParent())
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
       return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                                     nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n  LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts should target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType || LI->user_empty())
+    return false;
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n  OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
         return true;
       if (scalarizeExtExtract(I))
         return true;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
new file mode 100644
index 0000000000000..464e5129262bc
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/AArch64/load-bitcast-scalarization.ll
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add an Alive2 proof to this one?

Comment on lines 2025 to 2026
if (!TargetScalarType || LI->user_empty())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TargetScalarType != nullptr should imply that there are users, right?

Suggested change
if (!TargetScalarType || LI->user_empty())
return false;
if (!TargetScalarType)
return false;
assert(!LI->user_empty() && "...");

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, updated as suggested -- thanks!

@@ -0,0 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6

The latest version should be 6, I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be much more comfortable with the patch if it were merely:

m_BitCast(m_OneUse(m_Load()))

I'm not clear on the need for iterating across users etc.

%r2 = bitcast <4 x i8> %lv to i32
%add = add i32 %r1, %r2
ret i32 %add
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't other passes likely to have already folded this duplication, or are you seeing this kind of thing in real-world code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the larger motivating cases this improves results in combination with extend scalarization. @juliannagele could you add a larger case showing the interaction?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, now that #164683 is in, we can potentially scalarize a load-ext-extend sequence. I added a test that shows this.

@juliannagele
Copy link
Member Author

I'd be much more comfortable with the patch if it were merely:

m_BitCast(m_OneUse(m_Load()))

I'm not clear on the need for iterating across users etc.

It was just straightforward to support multiple users. The load-ext scalarization already looks at all users anyway, so checking whether they are all bitcasts at the same time seemed natural. I also didn't think that restricting this transform more than the other load scalarization would make it any simpler. That being said, a single user should catch the case we're after, so if you strongly prefer that, I'm fine with changing it.

@RKSimon RKSimon self-requested a review October 31, 2025 17:45
juliannagele added a commit that referenced this pull request Nov 4, 2025
…t if all extracts would lead to UB on poison. (#164683)

This change aims to avoid inserting a freeze instruction between the
load and bitcast when scalarizing extend-extract. This is particularly
useful in combination with
#164682, which can then
potentially further scalarize, provided there is no freeze.

alive2 proof: https://alive2.llvm.org/ce/z/W-GD88
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 4, 2025
…tend-extract if all extracts would lead to UB on poison. (#164683)

This change aims to avoid inserting a freeze instruction between the
load and bitcast when scalarizing extend-extract. This is particularly
useful in combination with
llvm/llvm-project#164682, which can then
potentially further scalarize, provided there is no freeze.

alive2 proof: https://alive2.llvm.org/ce/z/W-GD88
auto *LI = cast<LoadInst>(&I);
auto *VecTy = cast<VectorType>(LI->getType());
if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
if (!VecTy || LI->isVolatile() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need checking now, can the function be called for scalar loads?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no, left-over from a first try, dropped, thanks!

Comment on lines 1863 to 1864
// Check what type of users we have and ensure no memory modifications betwwen
// the load and its users.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably be better to keep the comment at the loop, as it is not directly related to the variables here

Suggested change
// Check what type of users we have and ensure no memory modifications betwwen
// the load and its users.
// Check what type of users we have (must either all be extracts or bitcasts) and ensure no memory modifications between
// the load and its users.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!

Comment on lines 1872 to 1873
if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
return false;
if (!UI || UI->getParent() != LI->getParent())
return false;

also checked below, with explanation in comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, yes, dropped, thanks!

Comment on lines 1912 to 1915
InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can sink this after the FailGuard, closer to the loop that sets them

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

if (DestBitWidth != VecBitWidth)
return false;

// All bitcasts should target the same scalar type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// All bitcasts should target the same scalar type.
// All bitcasts must target the same scalar type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

Comment on lines +2025 to +2026
if (!TargetScalarType)
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!TargetScalarType)
return false;
if (!TargetScalarType)
return false;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

@@ -0,0 +1,32 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt -passes=vector-combine,dce,vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to run vector-combine twice to get the transform? If so, it would probably be better to add this as a PhaseOrdering test, making sure we get the expected fold with -O3.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do yes, the first scalarization leaves dead users that need to be removed before the second one can kick in -- moved to phaseordering with O3.

@juliannagele juliannagele force-pushed the scalarize-load-bitcast branch from 504f3e8 to c2db71e Compare November 7, 2025 13:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants