Skip to content

Commit e02e60d

Browse files
committed
[InstCombine] Allow load to store forwarding for scalable structs
This attempts to fix a regression caused when scalable types started to be emitted as structs. A __builtin_bit_cast will create a load/store pair that we currently can not break up. This patch allows load-store forwarding in InstCombine to split up the structs if it can detect that each part is bitcast-able between the source and dest types. https://alive2.llvm.org/ce/z/ewWXqQ Non-scalable structs are decomposed through unpackStoreToAggregate before they get here. Geps with scalable offsets are not valid so they would require i8 gep types with vscale offsets that do not easily get cleaned up into bitcasts.
1 parent b801860 commit e02e60d

File tree

4 files changed

+55
-12
lines changed

4 files changed

+55
-12
lines changed

llvm/include/llvm/Analysis/Loads.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ Value *FindAvailableLoadedValue(LoadInst *Load, BasicBlock *ScanBB,
156156
/// This overload cannot be used to scan across multiple blocks.
157157
Value *FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
158158
bool *IsLoadCSE,
159-
unsigned MaxInstsToScan = DefMaxInstsToScan);
159+
unsigned MaxInstsToScan = DefMaxInstsToScan,
160+
bool AllowPartwiseBitcastStructs = false);
160161

161162
/// Scan backwards to see if we have the value of the given pointer available
162163
/// locally within a small number of instructions.

llvm/lib/Analysis/Loads.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,8 @@ static bool areNonOverlapSameBaseLoadAndStore(const Value *LoadPtr,
531531

532532
static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
533533
Type *AccessTy, bool AtLeastAtomic,
534-
const DataLayout &DL, bool *IsLoadCSE) {
534+
const DataLayout &DL, bool *IsLoadCSE,
535+
bool AllowPartwiseBitcastStructs = false) {
535536
// If this is a load of Ptr, the loaded value is available.
536537
// (This is true even if the load is volatile or atomic, although
537538
// those cases are unlikely.)
@@ -572,6 +573,19 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
572573
if (CastInst::isBitOrNoopPointerCastable(Val->getType(), AccessTy, DL))
573574
return Val;
574575

576+
if (AllowPartwiseBitcastStructs) {
577+
if (StructType *SrcStructTy = dyn_cast<StructType>(Val->getType())) {
578+
if (StructType *DestStructTy = dyn_cast<StructType>(AccessTy)) {
579+
if (SrcStructTy->getNumElements() == DestStructTy->getNumElements() &&
580+
all_of_zip(SrcStructTy->elements(), DestStructTy->elements(),
581+
[](Type *T1, Type *T2) {
582+
return CastInst::isBitCastable(T1, T2);
583+
}))
584+
return Val;
585+
}
586+
}
587+
}
588+
575589
TypeSize StoreSize = DL.getTypeSizeInBits(Val->getType());
576590
TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
577591
if (TypeSize::isKnownLE(LoadSize, StoreSize))
@@ -704,8 +718,8 @@ Value *llvm::findAvailablePtrLoadStore(
704718
}
705719

706720
Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
707-
bool *IsLoadCSE,
708-
unsigned MaxInstsToScan) {
721+
bool *IsLoadCSE, unsigned MaxInstsToScan,
722+
bool AllowPartwiseBitcastStructs) {
709723
const DataLayout &DL = Load->getDataLayout();
710724
Value *StrippedPtr = Load->getPointerOperand()->stripPointerCasts();
711725
BasicBlock *ScanBB = Load->getParent();
@@ -727,8 +741,9 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
727741
if (MaxInstsToScan-- == 0)
728742
return nullptr;
729743

730-
Available = getAvailableLoadStore(&Inst, StrippedPtr, AccessTy,
731-
AtLeastAtomic, DL, IsLoadCSE);
744+
Available =
745+
getAvailableLoadStore(&Inst, StrippedPtr, AccessTy, AtLeastAtomic, DL,
746+
IsLoadCSE, AllowPartwiseBitcastStructs);
732747
if (Available)
733748
break;
734749

llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,10 +1010,24 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
10101010
// separated by a few arithmetic operations.
10111011
bool IsLoadCSE = false;
10121012
BatchAAResults BatchAA(*AA);
1013-
if (Value *AvailableVal = FindAvailableLoadedValue(&LI, BatchAA, &IsLoadCSE)) {
1013+
if (Value *AvailableVal =
1014+
FindAvailableLoadedValue(&LI, BatchAA, &IsLoadCSE, DefMaxInstsToScan,
1015+
/*AllowPartwiseBitcastStructs=*/true)) {
10141016
if (IsLoadCSE)
10151017
combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI, false);
10161018

1019+
if (AvailableVal->getType() != LI.getType() &&
1020+
isa<StructType>(LI.getType())) {
1021+
StructType *DstST = cast<StructType>(LI.getType());
1022+
Value *R = PoisonValue::get(LI.getType());
1023+
for (unsigned I = 0, E = DstST->getNumElements(); I < E; I++) {
1024+
Value *Ext = Builder.CreateExtractValue(AvailableVal, I);
1025+
Value *BC =
1026+
Builder.CreateBitOrPointerCast(Ext, DstST->getElementType(I));
1027+
R = Builder.CreateInsertValue(R, BC, I);
1028+
}
1029+
return replaceInstUsesWith(LI, R);
1030+
}
10171031
return replaceInstUsesWith(
10181032
LI, Builder.CreateBitOrPointerCast(AvailableVal, LI.getType(),
10191033
LI.getName() + ".cast"));

llvm/test/Transforms/InstCombine/availableloadstruct.ll

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ define {<vscale x 16 x i8>, <vscale x 16 x i8>} @check_nxv16i8_nxv4i32({<vscale
2929
; CHECK-SAME: { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X:%.*]], ptr [[P:%.*]]) #[[ATTR0]] {
3030
; CHECK-NEXT: [[ENTRY:.*:]]
3131
; CHECK-NEXT: store { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], ptr [[P]], align 16
32-
; CHECK-NEXT: [[R:%.*]] = load { <vscale x 16 x i8>, <vscale x 16 x i8> }, ptr [[P]], align 16
32+
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 0
33+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <vscale x 4 x i32> [[TMP0]] to <vscale x 16 x i8>
34+
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> [[TMP1]], 0
35+
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 1
36+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <vscale x 4 x i32> [[TMP3]] to <vscale x 16 x i8>
37+
; CHECK-NEXT: [[R:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[TMP2]], <vscale x 16 x i8> [[TMP4]], 1
3338
; CHECK-NEXT: ret { <vscale x 16 x i8>, <vscale x 16 x i8> } [[R]]
3439
;
3540
entry:
@@ -42,9 +47,12 @@ define {<vscale x 16 x i8>, <vscale x 16 x i8>} @alloca_nxv16i8_nxv4i32({<vscale
4247
; CHECK-LABEL: define { <vscale x 16 x i8>, <vscale x 16 x i8> } @alloca_nxv16i8_nxv4i32(
4348
; CHECK-SAME: { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X:%.*]]) #[[ATTR0]] {
4449
; CHECK-NEXT: [[ENTRY:.*:]]
45-
; CHECK-NEXT: [[P:%.*]] = alloca { <vscale x 4 x i32>, <vscale x 4 x i32> }, align 16
46-
; CHECK-NEXT: store { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], ptr [[P]], align 16
47-
; CHECK-NEXT: [[R:%.*]] = load { <vscale x 16 x i8>, <vscale x 16 x i8> }, ptr [[P]], align 16
50+
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 0
51+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <vscale x 4 x i32> [[TMP0]] to <vscale x 16 x i8>
52+
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> [[TMP1]], 0
53+
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 1
54+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <vscale x 4 x i32> [[TMP3]] to <vscale x 16 x i8>
55+
; CHECK-NEXT: [[R:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[TMP2]], <vscale x 16 x i8> [[TMP4]], 1
4856
; CHECK-NEXT: ret { <vscale x 16 x i8>, <vscale x 16 x i8> } [[R]]
4957
;
5058
entry:
@@ -60,7 +68,12 @@ define { <16 x i8>, <32 x i8> } @differenttypes({ <4 x i32>, <8 x i32> } %a, ptr
6068
; CHECK-NEXT: [[ENTRY:.*:]]
6169
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull [[P]])
6270
; CHECK-NEXT: store { <4 x i32>, <8 x i32> } [[A]], ptr [[P]], align 16
63-
; CHECK-NEXT: [[TMP0:%.*]] = load { <16 x i8>, <32 x i8> }, ptr [[P]], align 16
71+
; CHECK-NEXT: [[TMP5:%.*]] = extractvalue { <4 x i32>, <8 x i32> } [[A]], 0
72+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP5]] to <16 x i8>
73+
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <16 x i8>, <32 x i8> } poison, <16 x i8> [[TMP1]], 0
74+
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <4 x i32>, <8 x i32> } [[A]], 1
75+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
76+
; CHECK-NEXT: [[TMP0:%.*]] = insertvalue { <16 x i8>, <32 x i8> } [[TMP2]], <32 x i8> [[TMP4]], 1
6477
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull [[P]])
6578
; CHECK-NEXT: ret { <16 x i8>, <32 x i8> } [[TMP0]]
6679
;

0 commit comments

Comments
 (0)