142 changes: 116 additions & 26 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -129,7 +129,9 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool scalarizeLoad(Instruction &I);
bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeExtExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
@@ -1845,11 +1847,9 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;

/// Try to scalarize vector loads feeding extractelement or bitcast
/// instructions.
bool VectorCombine::scalarizeLoad(Instruction &I) {
Value *Ptr;
if (!match(&I, m_Load(m_Value(Ptr))))
return false;
@@ -1859,35 +1859,30 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
return false;

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;

bool AllExtracts = true;
bool AllBitcasts = true;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

// Check if all users of the load are extracts with no memory modifications
// between the load and the extract. Compute the cost of both the original
// code and the scalarized version.
// Check what kind of users we have (they must either all be extracts or
// all be bitcasts) and ensure there are no memory modifications between
// the load and its users.
for (User *U : LI->users()) {
auto *UI = dyn_cast<ExtractElementInst>(U);
auto *UI = dyn_cast<Instruction>(U);
if (!UI || UI->getParent() != LI->getParent())
return false;

// If any extract is waiting to be erased, then bail out as this will
// If any user is waiting to be erased, then bail out as this will
// distort the cost calculation and possibly lead to infinite loops.
if (UI->use_empty())
return false;

// Check if any instruction between the load and the extract may modify
// memory.
if (!isa<ExtractElementInst>(UI))
AllExtracts = false;
if (!isa<BitCastInst>(UI))
AllBitcasts = false;

// Check if any instruction between the load and the user may modify memory.
if (LastCheckedInst->comesBefore(UI)) {
for (Instruction &I :
make_range(std::next(LI->getIterator()), UI->getIterator())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
}
LastCheckedInst = UI;
}
}

if (AllExtracts)
return scalarizeLoadExtract(LI, VecTy, Ptr);
if (AllBitcasts)
return scalarizeLoadBitcast(LI, VecTy, Ptr);
return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;

DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;

for (User *U : LI->users()) {
auto *UI = cast<ExtractElementInst>(U);

auto ScalarIdx =
canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
nullptr, nullptr, CostKind);
}

LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
<< "\n LoadExtractCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

@@ -1966,6 +1990,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}

/// Try to scalarize vector loads feeding bitcast instructions.
bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {
InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

Type *TargetScalarType = nullptr;
unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);

for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);

Type *DestTy = BC->getDestTy();
if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
return false;

unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
if (DestBitWidth != VecBitWidth)
return false;

// All bitcasts must target the same scalar type.
if (!TargetScalarType)
TargetScalarType = DestTy;
else if (TargetScalarType != DestTy)
return false;

OriginalCost +=
TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
TTI.getCastContextHint(BC), CostKind, BC);
}

if (!TargetScalarType)
return false;
Comment on lines +2025 to +2026

Contributor:
Suggested change:

    if (!TargetScalarType)
      return false;

Member (Author):
done, thanks

assert(!LI->user_empty() && "Unexpected load without bitcast users");
InstructionCost ScalarizedCost =
TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
<< "\n OriginalCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

if (ScalarizedCost >= OriginalCost)
return false;

// Ensure we add the load back to the worklist BEFORE its users so they can
// be erased in the correct order.
Worklist.push(LI);

Builder.SetInsertPoint(LI);
auto *ScalarLoad =
Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
ScalarLoad->setAlignment(LI->getAlign());
ScalarLoad->copyMetadata(*LI);

// Replace all bitcast users with the scalar load.
for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);
replaceValue(*BC, *ScalarLoad, false);
}

return true;
}

bool VectorCombine::scalarizeExtExtract(Instruction &I) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;
@@ -4578,7 +4668,7 @@ bool VectorCombine::run() {
if (IsVectorType) {
if (scalarizeOpOrCmp(I))
return true;
if (scalarizeLoadExtract(I))
if (scalarizeLoad(I))
return true;
if (scalarizeExtExtract(I))
return true;
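In IR terms, the new scalarizeLoadBitcast fold rewrites a vector load whose users are all bitcasts to the same scalar type into a single scalar load. A minimal before/after sketch (illustrative value names; the tests below exercise the real pass):

  ; before
  %lv = load <4 x i8>, ptr %x, align 4
  %r = bitcast <4 x i8> %lv to i32

  ; after: one scalar load, reusing the vector load's alignment and metadata
  %r = load i32, ptr %x, align 4

The rewrite only fires when TTI reports the scalar load as cheaper than the vector load plus the bitcasts it feeds.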
@@ -0,0 +1,32 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt -O3 -mtriple=arm64-apple-darwinos -S %s | FileCheck %s

define noundef i32 @load_ext_extract(ptr %src) {
; CHECK-LABEL: define noundef range(i32 0, 1021) i32 @load_ext_extract(
; CHECK-SAME: ptr readonly captures(none) [[SRC:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr [[SRC]], align 4
; CHECK-NEXT: [[TMP15:%.*]] = lshr i32 [[TMP14]], 24
; CHECK-NEXT: [[TMP16:%.*]] = lshr i32 [[TMP14]], 16
; CHECK-NEXT: [[TMP17:%.*]] = and i32 [[TMP16]], 255
; CHECK-NEXT: [[TMP18:%.*]] = lshr i32 [[TMP14]], 8
; CHECK-NEXT: [[TMP19:%.*]] = and i32 [[TMP18]], 255
; CHECK-NEXT: [[TMP20:%.*]] = and i32 [[TMP14]], 255
; CHECK-NEXT: [[ADD1:%.*]] = add nuw nsw i32 [[TMP20]], [[TMP19]]
; CHECK-NEXT: [[ADD2:%.*]] = add nuw nsw i32 [[ADD1]], [[TMP17]]
; CHECK-NEXT: [[ADD3:%.*]] = add nuw nsw i32 [[ADD2]], [[TMP15]]
; CHECK-NEXT: ret i32 [[ADD3]]
;
entry:
%x = load <4 x i8>, ptr %src, align 4
%ext = zext nneg <4 x i8> %x to <4 x i32>
%ext.0 = extractelement <4 x i32> %ext, i64 0
%ext.1 = extractelement <4 x i32> %ext, i64 1
%ext.2 = extractelement <4 x i32> %ext, i64 2
%ext.3 = extractelement <4 x i32> %ext, i64 3

%add1 = add i32 %ext.0, %ext.1
%add2 = add i32 %add1, %ext.2
%add3 = add i32 %add2, %ext.3
ret i32 %add3
}
@@ -0,0 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s

define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: ret i32 [[R_SCALAR]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to i32
ret i32 %r
}

define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
; CHECK-NEXT: ret i64 [[R_SCALAR]]
;
%lv = load <2 x i32>, ptr %x
%r = bitcast <2 x i32> %lv to i64
ret i64 %r
}

define float @load_v4i8_bitcast_to_float(ptr %x) {
; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
; CHECK-NEXT: ret float [[R_SCALAR]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to float
ret float %r
}

define float @load_v2i16_bitcast_to_float(ptr %x) {
; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
; CHECK-NEXT: ret float [[R_SCALAR]]
;
%lv = load <2 x i16>, ptr %x
%r = bitcast <2 x i16> %lv to float
ret float %r
}

define double @load_v4i16_bitcast_to_double(ptr %x) {
; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
; CHECK-NEXT: ret double [[R_SCALAR]]
;
%lv = load <4 x i16>, ptr %x
%r = bitcast <4 x i16> %lv to double
ret double %r
}

define double @load_v2i32_bitcast_to_double(ptr %x) {
; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
; CHECK-NEXT: ret double [[R_SCALAR]]
;
%lv = load <2 x i32>, ptr %x
%r = bitcast <2 x i32> %lv to double
ret double %r
}

; Multiple users with the same bitcast type should be scalarized.
define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = bitcast <4 x i8> %lv to i32
%add = add i32 %r1, %r2
ret i32 %add
}
Collaborator:
Aren't other passes likely to have already folded these duplicates, or are you seeing this kind of thing in real-world code?

Contributor:
I think for the larger motivating cases this improves results in combination with extend scalarization. @juliannagele could you add a larger case showing the interaction?

Member (Author):
Yes, now that #164683 is in we can potentially scalarize a load-ext-extract sequence. Added a test that shows this.
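
A rough sketch of that interaction (illustrative IR; the load_ext_extract phase-ordering test above shows the end-to-end result): starting from

  %x = load <4 x i8>, ptr %src, align 4
  %ext = zext <4 x i8> %x to <4 x i32>
  %ext.0 = extractelement <4 x i32> %ext, i64 0

scalarizeExtExtract first turns the extend-plus-extract into a scalar bitcast plus masking:

  %x = load <4 x i8>, ptr %src, align 4
  %cast = bitcast <4 x i8> %x to i32
  %ext.0 = and i32 %cast, 255

which leaves the load with a single bitcast user, so scalarizeLoadBitcast can then fold it into a plain i32 load.
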
; Different bitcast types should not be scalarized.
define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
; CHECK-NEXT: [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
; CHECK-NEXT: [[R2_INT:%.*]] = bitcast float [[R2]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = bitcast <4 x i8> %lv to float
%r2.int = bitcast float %r2 to i32
%add = add i32 %r1, %r2.int
ret i32 %add
}

; Bitcast to vector should not be scalarized.
define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
; CHECK-NEXT: ret <2 x i16> [[R]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to <2 x i16>
ret <2 x i16> %r
}

; Load with both bitcast users and other users should not be scalarized.
define i32 @load_v4i8_mixed_users(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
; CHECK-NEXT: [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
; CHECK-NEXT: [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = extractelement <4 x i8> %lv, i32 0
%r2.ext = zext i8 %r2 to i32
%add = add i32 %r1, %r2.ext
ret i32 %add
}