
Commit 6fcd377

[VectorCombine] Try to scalarize vector loads feeding bitcast instructions.
1 parent 15d11eb commit 6fcd377
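
The transform this adds: when every user of a vector load is a bitcast to the same scalar type of equal bit width, the vector load plus bitcasts collapse into a single scalar load. A minimal before/after sketch in LLVM IR, mirroring the first new test below:

    ; before
    %lv = load <4 x i8>, ptr %x
    %r = bitcast <4 x i8> %lv to i32

    ; after
    %r.scalar = load i32, ptr %x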

2 files changed: +252 -28 lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 116 additions & 28 deletions
@@ -129,7 +129,9 @@ class VectorCombine {
   bool foldExtractedCmps(Instruction &I);
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
-  bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeLoad(Instruction &I);
+  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
+  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
   bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
   return false;
 }
 
-/// Try to scalarize vector loads feeding extractelement instructions.
-bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
-  if (!TTI.allowVectorElementIndexingUsingGEP())
-    return false;
-
+/// Try to scalarize vector loads feeding extractelement or bitcast
+/// instructions.
+bool VectorCombine::scalarizeLoad(Instruction &I) {
   Value *Ptr;
   if (!match(&I, m_Load(m_Value(Ptr))))
     return false;
 
   auto *LI = cast<LoadInst>(&I);
   auto *VecTy = cast<VectorType>(LI->getType());
-  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
+  if (!VecTy || LI->isVolatile() ||
+      !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
     return false;
 
-  InstructionCost OriginalCost =
-      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
-                          LI->getPointerAddressSpace(), CostKind);
-  InstructionCost ScalarizedCost = 0;
-
+  // Check what type of users we have and ensure no memory modifications
+  // between the load and its users.
+  bool AllExtracts = true;
+  bool AllBitcasts = true;
   Instruction *LastCheckedInst = LI;
   unsigned NumInstChecked = 0;
-  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
-  auto FailureGuard = make_scope_exit([&]() {
-    // If the transform is aborted, discard the ScalarizationResults.
-    for (auto &Pair : NeedFreeze)
-      Pair.second.discard();
-  });
 
-  // Check if all users of the load are extracts with no memory modifications
-  // between the load and the extract. Compute the cost of both the original
-  // code and the scalarized version.
   for (User *U : LI->users()) {
-    auto *UI = dyn_cast<ExtractElementInst>(U);
-    if (!UI || UI->getParent() != LI->getParent())
+    auto *UI = dyn_cast<Instruction>(U);
+    if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
       return false;
 
-    // If any extract is waiting to be erased, then bail out as this will
+    // If any user is waiting to be erased, then bail out as this will
     // distort the cost calculation and possibly lead to infinite loops.
     if (UI->use_empty())
       return false;
 
-    // Check if any instruction between the load and the extract may modify
-    // memory.
+    if (!isa<ExtractElementInst>(UI))
+      AllExtracts = false;
+    if (!isa<BitCastInst>(UI))
+      AllBitcasts = false;
+
+    // Check if any instruction between the load and the user may modify
+    // memory.
     if (LastCheckedInst->comesBefore(UI)) {
       for (Instruction &I :
            make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
       }
       LastCheckedInst = UI;
     }
+  }
+
+  if (AllExtracts)
+    return scalarizeLoadExtract(LI, VecTy, Ptr);
+  if (AllBitcasts)
+    return scalarizeLoadBitcast(LI, VecTy, Ptr);
+  return false;
+}
+
+/// Try to scalarize vector loads feeding extractelement instructions.
+bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  if (!TTI.allowVectorElementIndexingUsingGEP())
+    return false;
+
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+  InstructionCost ScalarizedCost = 0;
+
+  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
+  auto FailureGuard = make_scope_exit([&]() {
+    // If the transform is aborted, discard the ScalarizationResults.
+    for (auto &Pair : NeedFreeze)
+      Pair.second.discard();
+  });
+
+  for (User *U : LI->users()) {
+    auto *UI = cast<ExtractElementInst>(U);
 
     auto ScalarIdx =
         canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
                                 nullptr, nullptr, CostKind);
   }
 
-  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
+  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                     << "\n LoadExtractCost: " << OriginalCost
                     << " vs ScalarizedCost: " << ScalarizedCost << "\n");
 
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+/// Try to scalarize vector loads feeding bitcast instructions.
+bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
+                                         Value *Ptr) {
+  InstructionCost OriginalCost =
+      TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  Type *TargetScalarType = nullptr;
+  unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);
+
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+
+    Type *DestTy = BC->getDestTy();
+    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
+      return false;
+
+    unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
+    if (DestBitWidth != VecBitWidth)
+      return false;
+
+    // All bitcasts should target the same scalar type.
+    if (!TargetScalarType)
+      TargetScalarType = DestTy;
+    else if (TargetScalarType != DestTy)
+      return false;
+
+    OriginalCost +=
+        TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
+                             TTI.getCastContextHint(BC), CostKind, BC);
+  }
+
+  if (!TargetScalarType || LI->user_empty())
+    return false;
+  InstructionCost ScalarizedCost =
+      TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
+                          LI->getPointerAddressSpace(), CostKind);
+
+  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
+                    << "\n OriginalCost: " << OriginalCost
+                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");
+
+  if (ScalarizedCost >= OriginalCost)
+    return false;
+
+  // Ensure we add the load back to the worklist BEFORE its users so they can
+  // be erased in the correct order.
+  Worklist.push(LI);
+
+  Builder.SetInsertPoint(LI);
+  auto *ScalarLoad =
+      Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
+  ScalarLoad->setAlignment(LI->getAlign());
+  ScalarLoad->copyMetadata(*LI);
+
+  // Replace all bitcast users with the scalar load.
+  for (User *U : LI->users()) {
+    auto *BC = cast<BitCastInst>(U);
+    replaceValue(*BC, *ScalarLoad, false);
+  }
+
+  return true;
+}
+
 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
   if (!TTI.allowVectorElementIndexingUsingGEP())
     return false;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       if (scalarizeOpOrCmp(I))
         return true;
-      if (scalarizeLoadExtract(I))
+      if (scalarizeLoad(I))
         return true;
       if (scalarizeExtExtract(I))
         return true;
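
As before, the combine is driven from VectorCombine::run(); the pass can be exercised on a standalone IR file with the standard opt invocation used by the new test:

    opt -passes=vector-combine -S input.ll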
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s
+
+define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    ret i32 [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to i32
+  ret i32 %r
+}
+
+define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
+; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
+; CHECK-NEXT:    ret i64 [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to i64
+  ret i64 %r
+}
+
+define float @load_v4i8_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to float
+  ret float %r
+}
+
+define float @load_v2i16_bitcast_to_float(ptr %x) {
+; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
+; CHECK-NEXT:    ret float [[R_SCALAR]]
+;
+  %lv = load <2 x i16>, ptr %x
+  %r = bitcast <2 x i16> %lv to float
+  ret float %r
+}
+
+define double @load_v4i16_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <4 x i16>, ptr %x
+  %r = bitcast <4 x i16> %lv to double
+  ret double %r
+}
+
+define double @load_v2i32_bitcast_to_double(ptr %x) {
+; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
+; CHECK-NEXT:    [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
+; CHECK-NEXT:    ret double [[R_SCALAR]]
+;
+  %lv = load <2 x i32>, ptr %x
+  %r = bitcast <2 x i32> %lv to double
+  ret double %r
+}
+
+; Multiple users with the same bitcast type should be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to i32
+  %add = add i32 %r1, %r2
+  ret i32 %add
+}
+
+; Different bitcast types should not be scalarized.
+define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
+; CHECK-NEXT:    [[R2_INT:%.*]] = bitcast float [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = bitcast <4 x i8> %lv to float
+  %r2.int = bitcast float %r2 to i32
+  %add = add i32 %r1, %r2.int
+  ret i32 %add
+}
+
+; Bitcast to vector should not be scalarized.
+define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
+; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[R]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r = bitcast <4 x i8> %lv to <2 x i16>
+  ret <2 x i16> %r
+}
+
+; Load with both bitcast users and other users should not be scalarized.
+define i32 @load_v4i8_mixed_users(ptr %x) {
+; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
+; CHECK-SAME: ptr [[X:%.*]]) {
+; CHECK-NEXT:    [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
+; CHECK-NEXT:    [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
+; CHECK-NEXT:    [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
+; CHECK-NEXT:    [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
+; CHECK-NEXT:    ret i32 [[ADD]]
+;
+  %lv = load <4 x i8>, ptr %x
+  %r1 = bitcast <4 x i8> %lv to i32
+  %r2 = extractelement <4 x i8> %lv, i32 0
+  %r2.ext = zext i8 %r2 to i32
+  %add = add i32 %r1, %r2.ext
+  ret i32 %add
+}
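
The CHECK lines in this test are autogenerated. If the transform's output changes, they can be regenerated with the script named in the NOTE line above (the opt binary path here is an assumed example):

    llvm/utils/update_test_checks.py --opt-binary=build/bin/opt <test-file>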
