Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 116 additions & 28 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ class VectorCombine {
bool foldExtractedCmps(Instruction &I);
bool foldBinopOfReductions(Instruction &I);
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool scalarizeLoad(Instruction &I);
bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
bool scalarizeExtExtract(Instruction &I);
bool foldConcatOfBoolMasks(Instruction &I);
bool foldPermuteOfBinops(Instruction &I);
Expand Down Expand Up @@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;

/// Try to scalarize vector loads feeding extractelement or bitcast
/// instructions.
bool VectorCombine::scalarizeLoad(Instruction &I) {
Value *Ptr;
if (!match(&I, m_Load(m_Value(Ptr))))
return false;

auto *LI = cast<LoadInst>(&I);
auto *VecTy = cast<VectorType>(LI->getType());
if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
if (!VecTy || LI->isVolatile() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need checking now, can the function be called for scalar loads?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no, left-over from a first try, dropped, thanks!

!DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
return false;

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;

// Check what type of users we have and ensure no memory modifications between
// the load and its users.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably be better to keep the comment at the loop, as it is not directly related to the variables here

Suggested change
// Check what type of users we have and ensure no memory modifications betwwen
// the load and its users.
// Check what type of users we have (must either all be extracts or bitcasts) and ensure no memory modifications between
// the load and its users.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!

bool AllExtracts = true;
bool AllBitcasts = true;
Instruction *LastCheckedInst = LI;
unsigned NumInstChecked = 0;
DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

// Check if all users of the load are extracts with no memory modifications
// between the load and the extract. Compute the cost of both the original
// code and the scalarized version.
for (User *U : LI->users()) {
auto *UI = dyn_cast<ExtractElementInst>(U);
if (!UI || UI->getParent() != LI->getParent())
auto *UI = dyn_cast<Instruction>(U);
if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!UI || UI->getParent() != LI->getParent() || UI->use_empty())
return false;
if (!UI || UI->getParent() != LI->getParent())
return false;

also checked below, with explanation in comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, yes, dropped, thanks!


// If any extract is waiting to be erased, then bail out as this will
// If any user is waiting to be erased, then bail out as this will
// distort the cost calculation and possibly lead to infinite loops.
if (UI->use_empty())
return false;

// Check if any instruction between the load and the extract may modify
// memory.
if (!isa<ExtractElementInst>(UI))
AllExtracts = false;
if (!isa<BitCastInst>(UI))
AllBitcasts = false;

// Check if any instruction between the load and the user may modify memory.
if (LastCheckedInst->comesBefore(UI)) {
for (Instruction &I :
make_range(std::next(LI->getIterator()), UI->getIterator())) {
Expand All @@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
}
LastCheckedInst = UI;
}
}

if (AllExtracts)
return scalarizeLoadExtract(LI, VecTy, Ptr);
if (AllBitcasts)
return scalarizeLoadBitcast(LI, VecTy, Ptr);
return false;
}

/// Try to scalarize vector loads feeding extractelement instructions.
bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;

InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);
InstructionCost ScalarizedCost = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can sink this after the FailGuard, closer to the loop that sets them

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks


DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
auto FailureGuard = make_scope_exit([&]() {
// If the transform is aborted, discard the ScalarizationResults.
for (auto &Pair : NeedFreeze)
Pair.second.discard();
});

for (User *U : LI->users()) {
auto *UI = cast<ExtractElementInst>(U);

auto ScalarIdx =
canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
Expand All @@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
nullptr, nullptr, CostKind);
}

LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
<< "\n LoadExtractCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

Expand Down Expand Up @@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
return true;
}

/// Try to scalarize vector loads feeding bitcast instructions.
bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
Value *Ptr) {
InstructionCost OriginalCost =
TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

Type *TargetScalarType = nullptr;
unsigned VecBitWidth = DL->getTypeSizeInBits(VecTy);

for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);

Type *DestTy = BC->getDestTy();
if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
return false;

unsigned DestBitWidth = DL->getTypeSizeInBits(DestTy);
if (DestBitWidth != VecBitWidth)
return false;

// All bitcasts should target the same scalar type.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// All bitcasts should target the same scalar type.
// All bitcasts must target the same scalar type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks

if (!TargetScalarType)
TargetScalarType = DestTy;
else if (TargetScalarType != DestTy)
return false;

OriginalCost +=
TTI.getCastInstrCost(Instruction::BitCast, TargetScalarType, VecTy,
TTI.getCastContextHint(BC), CostKind, BC);
}

if (!TargetScalarType || LI->user_empty())
return false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TargetScalarType != nullptr should imply that there are users, right?

Suggested change
if (!TargetScalarType || LI->user_empty())
return false;
if (!TargetScalarType)
return false;
assert(!LI->user_empty() && "...");

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True, updated as suggested -- thanks!

InstructionCost ScalarizedCost =
TTI.getMemoryOpCost(Instruction::Load, TargetScalarType, LI->getAlign(),
LI->getPointerAddressSpace(), CostKind);

LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
<< "\n OriginalCost: " << OriginalCost
<< " vs ScalarizedCost: " << ScalarizedCost << "\n");

if (ScalarizedCost >= OriginalCost)
return false;

// Ensure we add the load back to the worklist BEFORE its users so they can
// be erased in the correct order.
Worklist.push(LI);

Builder.SetInsertPoint(LI);
auto *ScalarLoad =
Builder.CreateLoad(TargetScalarType, Ptr, LI->getName() + ".scalar");
ScalarLoad->setAlignment(LI->getAlign());
ScalarLoad->copyMetadata(*LI);

// Replace all bitcast users with the scalar load.
for (User *U : LI->users()) {
auto *BC = cast<BitCastInst>(U);
replaceValue(*BC, *ScalarLoad, false);
}

return true;
}

bool VectorCombine::scalarizeExtExtract(Instruction &I) {
if (!TTI.allowVectorElementIndexingUsingGEP())
return false;
Expand Down Expand Up @@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
if (IsVectorType) {
if (scalarizeOpOrCmp(I))
return true;
if (scalarizeLoadExtract(I))
if (scalarizeLoad(I))
return true;
if (scalarizeExtExtract(I))
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6

latest version should be 6, I think

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated

; RUN: opt -passes=vector-combine -mtriple=arm64-apple-darwinos -S %s | FileCheck %s

; A <4 x i8> load (32 bits) whose only user is a same-width bitcast to i32:
; the CHECK lines show it is scalarized to a plain i32 load.
define i32 @load_v4i8_bitcast_to_i32(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_to_i32(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: ret i32 [[R_SCALAR]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to i32
ret i32 %r
}

; A <2 x i32> load (64 bits) feeding a single bitcast to i64: scalarized to
; an i64 load per the CHECK lines.
define i64 @load_v2i32_bitcast_to_i64(ptr %x) {
; CHECK-LABEL: define i64 @load_v2i32_bitcast_to_i64(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load i64, ptr [[X]], align 8
; CHECK-NEXT: ret i64 [[R_SCALAR]]
;
%lv = load <2 x i32>, ptr %x
%r = bitcast <2 x i32> %lv to i64
ret i64 %r
}

; Bitcast target may be floating point: a <4 x i8> load feeding a bitcast to
; float is scalarized to a float load (same 32-bit width).
define float @load_v4i8_bitcast_to_float(ptr %x) {
; CHECK-LABEL: define float @load_v4i8_bitcast_to_float(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
; CHECK-NEXT: ret float [[R_SCALAR]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to float
ret float %r
}

; Same as above with a different element type: <2 x i16> (32 bits) bitcast to
; float is scalarized to a float load.
define float @load_v2i16_bitcast_to_float(ptr %x) {
; CHECK-LABEL: define float @load_v2i16_bitcast_to_float(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[R_SCALAR:%.*]] = load float, ptr [[X]], align 4
; CHECK-NEXT: ret float [[R_SCALAR]]
;
%lv = load <2 x i16>, ptr %x
%r = bitcast <2 x i16> %lv to float
ret float %r
}

; Negative test: the CHECK lines keep the vector load + bitcast for
; <4 x i16> -> double — presumably not profitable under this target's cost
; model (TODO confirm against AArch64 getCastInstrCost).
define double @load_v4i16_bitcast_to_double(ptr %x) {
; CHECK-LABEL: define double @load_v4i16_bitcast_to_double(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i16>, ptr [[X]], align 8
; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <4 x i16> [[LV]] to double
; CHECK-NEXT: ret double [[R_SCALAR]]
;
%lv = load <4 x i16>, ptr %x
%r = bitcast <4 x i16> %lv to double
ret double %r
}

; Negative test: <2 x i32> -> double also keeps the vector load + bitcast
; per the CHECK lines (not scalarized on this target).
define double @load_v2i32_bitcast_to_double(ptr %x) {
; CHECK-LABEL: define double @load_v2i32_bitcast_to_double(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <2 x i32>, ptr [[X]], align 8
; CHECK-NEXT: [[R_SCALAR:%.*]] = bitcast <2 x i32> [[LV]] to double
; CHECK-NEXT: ret double [[R_SCALAR]]
;
%lv = load <2 x i32>, ptr %x
%r = bitcast <2 x i32> %lv to double
ret double %r
}

; Multiple users with the same bitcast type should be scalarized.
; Both bitcasts are replaced by the one scalar i32 load (see the single
; [[LV_SCALAR]] feeding both add operands in the CHECK lines).
define i32 @load_v4i8_bitcast_multiple_users_same_type(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_same_type(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV_SCALAR:%.*]] = load i32, ptr [[X]], align 4
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[LV_SCALAR]], [[LV_SCALAR]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = bitcast <4 x i8> %lv to i32
%add = add i32 %r1, %r2
ret i32 %add
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't other passes likely to have already folded these duplications, or are you seeing this kind of thing in real world code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for larger the motivating cases this improves results in combination with extend scalarization. @juliannagele could you add a larger case showing the interaction?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, now that #164683 is in we can potentially scalarize a load-ext-extend sequence. Added a test that show this.


; Different bitcast types should not be scalarized.
; Negative test: one user casts to i32 and the other to float, so there is no
; single scalar type to load; the CHECK lines keep the vector load.
define i32 @load_v4i8_bitcast_multiple_users_different_types(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_bitcast_multiple_users_different_types(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
; CHECK-NEXT: [[R2:%.*]] = bitcast <4 x i8> [[LV]] to float
; CHECK-NEXT: [[R2_INT:%.*]] = bitcast float [[R2]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_INT]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = bitcast <4 x i8> %lv to float
%r2.int = bitcast float %r2 to i32
%add = add i32 %r1, %r2.int
ret i32 %add
}

; Bitcast to vector should not be scalarized.
; Negative test: the destination of the bitcast is <2 x i16>, not a scalar,
; so the transform must not fire; CHECK lines keep the original IR.
define <2 x i16> @load_v4i8_bitcast_to_vector(ptr %x) {
; CHECK-LABEL: define <2 x i16> @load_v4i8_bitcast_to_vector(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R:%.*]] = bitcast <4 x i8> [[LV]] to <2 x i16>
; CHECK-NEXT: ret <2 x i16> [[R]]
;
%lv = load <4 x i8>, ptr %x
%r = bitcast <4 x i8> %lv to <2 x i16>
ret <2 x i16> %r
}

; Load with both bitcast users and other users should not be scalarized.
; Negative test: the load feeds one bitcast and one extractelement, so the
; users are neither all bitcasts nor all extracts; CHECK lines keep the
; vector load.
define i32 @load_v4i8_mixed_users(ptr %x) {
; CHECK-LABEL: define i32 @load_v4i8_mixed_users(
; CHECK-SAME: ptr [[X:%.*]]) {
; CHECK-NEXT: [[LV:%.*]] = load <4 x i8>, ptr [[X]], align 4
; CHECK-NEXT: [[R1:%.*]] = bitcast <4 x i8> [[LV]] to i32
; CHECK-NEXT: [[R2:%.*]] = extractelement <4 x i8> [[LV]], i32 0
; CHECK-NEXT: [[R2_EXT:%.*]] = zext i8 [[R2]] to i32
; CHECK-NEXT: [[ADD:%.*]] = add i32 [[R1]], [[R2_EXT]]
; CHECK-NEXT: ret i32 [[ADD]]
;
%lv = load <4 x i8>, ptr %x
%r1 = bitcast <4 x i8> %lv to i32
%r2 = extractelement <4 x i8> %lv, i32 0
%r2.ext = zext i8 %r2 to i32
%add = add i32 %r1, %r2.ext
ret i32 %add
}