Skip to content

Commit bbde6be

Browse files
authored
[llvm] Support multiple save/restore points in mir (#119357)
Currently mir supports only one save and one restore point specification: ``` savePoint: '%bb.1' restorePoint: '%bb.2' ``` This patch provide possibility to have multiple save and multiple restore points in mir: ``` savePoints: - point: '%bb.1' restorePoints: - point: '%bb.2' ``` Shrink-Wrap points split Part 3. RFC: https://discourse.llvm.org/t/shrink-wrap-save-restore-points-splitting/83581 Part 1: #117862 Part 2: #119355 Part 4: #119358 Part 5: #119359
1 parent ef5e65d commit bbde6be

17 files changed

+330
-73
lines changed

llvm/include/llvm/CodeGen/MIRYamlMapping.h

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,55 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::CalledGlobal)
634634
namespace llvm {
635635
namespace yaml {
636636

637+
// Struct representing one save/restore point in the 'savePoint'/'restorePoint'
638+
// list
639+
struct SaveRestorePointEntry {
640+
StringValue Point;
641+
642+
bool operator==(const SaveRestorePointEntry &Other) const {
643+
return Point == Other.Point;
644+
}
645+
};
646+
647+
using SaveRestorePoints =
648+
std::variant<std::vector<SaveRestorePointEntry>, StringValue>;
649+
650+
template <> struct PolymorphicTraits<SaveRestorePoints> {
651+
652+
static NodeKind getKind(const SaveRestorePoints &SRPoints) {
653+
if (std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
654+
return NodeKind::Sequence;
655+
if (std::holds_alternative<StringValue>(SRPoints))
656+
return NodeKind::Scalar;
657+
llvm_unreachable("Unsupported NodeKind of SaveRestorePoints");
658+
}
659+
660+
static SaveRestorePointEntry &getAsMap(SaveRestorePoints &SRPoints) {
661+
llvm_unreachable("SaveRestorePoints can't be represented as Map");
662+
}
663+
664+
static std::vector<SaveRestorePointEntry> &
665+
getAsSequence(SaveRestorePoints &SRPoints) {
666+
if (!std::holds_alternative<std::vector<SaveRestorePointEntry>>(SRPoints))
667+
SRPoints = std::vector<SaveRestorePointEntry>();
668+
669+
return std::get<std::vector<SaveRestorePointEntry>>(SRPoints);
670+
}
671+
672+
static StringValue &getAsScalar(SaveRestorePoints &SRPoints) {
673+
if (!std::holds_alternative<StringValue>(SRPoints))
674+
SRPoints = StringValue();
675+
676+
return std::get<StringValue>(SRPoints);
677+
}
678+
};
679+
680+
template <> struct MappingTraits<SaveRestorePointEntry> {
681+
static void mapping(IO &YamlIO, SaveRestorePointEntry &Entry) {
682+
YamlIO.mapRequired("point", Entry.Point);
683+
}
684+
};
685+
637686
template <> struct MappingTraits<MachineJumpTable> {
638687
static void mapping(IO &YamlIO, MachineJumpTable &JT) {
639688
YamlIO.mapRequired("kind", JT.Kind);
@@ -642,6 +691,14 @@ template <> struct MappingTraits<MachineJumpTable> {
642691
}
643692
};
644693

694+
} // namespace yaml
695+
} // namespace llvm
696+
697+
LLVM_YAML_IS_SEQUENCE_VECTOR(llvm::yaml::SaveRestorePointEntry)
698+
699+
namespace llvm {
700+
namespace yaml {
701+
645702
/// Serializable representation of MachineFrameInfo.
646703
///
647704
/// Doesn't serialize attributes like 'StackAlignment', 'IsStackRealignable' and
@@ -669,8 +726,8 @@ struct MachineFrameInfo {
669726
bool HasTailCall = false;
670727
bool IsCalleeSavedInfoValid = false;
671728
unsigned LocalFrameSize = 0;
672-
StringValue SavePoint;
673-
StringValue RestorePoint;
729+
SaveRestorePoints SavePoints;
730+
SaveRestorePoints RestorePoints;
674731

675732
bool operator==(const MachineFrameInfo &Other) const {
676733
return IsFrameAddressTaken == Other.IsFrameAddressTaken &&
@@ -691,7 +748,8 @@ struct MachineFrameInfo {
691748
HasMustTailInVarArgFunc == Other.HasMustTailInVarArgFunc &&
692749
HasTailCall == Other.HasTailCall &&
693750
LocalFrameSize == Other.LocalFrameSize &&
694-
SavePoint == Other.SavePoint && RestorePoint == Other.RestorePoint &&
751+
SavePoints == Other.SavePoints &&
752+
RestorePoints == Other.RestorePoints &&
695753
IsCalleeSavedInfoValid == Other.IsCalleeSavedInfoValid;
696754
}
697755
};
@@ -723,10 +781,14 @@ template <> struct MappingTraits<MachineFrameInfo> {
723781
YamlIO.mapOptional("isCalleeSavedInfoValid", MFI.IsCalleeSavedInfoValid,
724782
false);
725783
YamlIO.mapOptional("localFrameSize", MFI.LocalFrameSize, (unsigned)0);
726-
YamlIO.mapOptional("savePoint", MFI.SavePoint,
727-
StringValue()); // Don't print it out when it's empty.
728-
YamlIO.mapOptional("restorePoint", MFI.RestorePoint,
729-
StringValue()); // Don't print it out when it's empty.
784+
YamlIO.mapOptional(
785+
"savePoint", MFI.SavePoints,
786+
SaveRestorePoints(
787+
StringValue())); // Don't print it out when it's empty.
788+
YamlIO.mapOptional(
789+
"restorePoint", MFI.RestorePoints,
790+
SaveRestorePoints(
791+
StringValue())); // Don't print it out when it's empty.
730792
}
731793
};
732794

llvm/include/llvm/CodeGen/MachineFrameInfo.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,10 @@ class MachineFrameInfo {
332332
/// stack objects like arguments so we can't treat them as immutable.
333333
bool HasTailCall = false;
334334

335-
/// Not null, if shrink-wrapping found a better place for the prologue.
336-
MachineBasicBlock *Save = nullptr;
337-
/// Not null, if shrink-wrapping found a better place for the epilogue.
338-
MachineBasicBlock *Restore = nullptr;
335+
/// Not empty, if shrink-wrapping found a better place for the prologue.
336+
SmallVector<MachineBasicBlock *, 4> SavePoints;
337+
/// Not empty, if shrink-wrapping found a better place for the epilogue.
338+
SmallVector<MachineBasicBlock *, 4> RestorePoints;
339339

340340
/// Size of the UnsafeStack Frame
341341
uint64_t UnsafeStackSize = 0;
@@ -825,10 +825,25 @@ class MachineFrameInfo {
825825

826826
void setCalleeSavedInfoValid(bool v) { CSIValid = v; }
827827

828-
MachineBasicBlock *getSavePoint() const { return Save; }
829-
void setSavePoint(MachineBasicBlock *NewSave) { Save = NewSave; }
830-
MachineBasicBlock *getRestorePoint() const { return Restore; }
831-
void setRestorePoint(MachineBasicBlock *NewRestore) { Restore = NewRestore; }
828+
ArrayRef<MachineBasicBlock *> getSavePoints() const { return SavePoints; }
829+
void setSavePoints(ArrayRef<MachineBasicBlock *> NewSavePoints) {
830+
SavePoints = SmallVector<MachineBasicBlock *>(NewSavePoints);
831+
}
832+
ArrayRef<MachineBasicBlock *> getRestorePoints() const {
833+
return RestorePoints;
834+
}
835+
void setRestorePoints(ArrayRef<MachineBasicBlock *> NewRestorePoints) {
836+
RestorePoints = SmallVector<MachineBasicBlock *>(NewRestorePoints);
837+
}
838+
839+
static SmallVector<MachineBasicBlock *> constructSaveRestorePoints(
840+
ArrayRef<MachineBasicBlock *> SRPoints,
841+
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &BBMap) {
842+
SmallVector<MachineBasicBlock *, 4> Pts;
843+
for (auto &Src : SRPoints)
844+
Pts.push_back(BBMap.find(Src)->second);
845+
return Pts;
846+
}
832847

833848
uint64_t getUnsafeStackSize() const { return UnsafeStackSize; }
834849
void setUnsafeStackSize(uint64_t Size) { UnsafeStackSize = Size; }

llvm/lib/CodeGen/MIRParser/MIRParser.cpp

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ class MIRParserImpl {
124124
bool initializeFrameInfo(PerFunctionMIParsingState &PFS,
125125
const yaml::MachineFunction &YamlMF);
126126

127+
bool initializeSaveRestorePoints(
128+
PerFunctionMIParsingState &PFS,
129+
const yaml::SaveRestorePoints &YamlSRPoints,
130+
SmallVectorImpl<MachineBasicBlock *> &SaveRestorePoints);
131+
127132
bool initializeCallSiteInfo(PerFunctionMIParsingState &PFS,
128133
const yaml::MachineFunction &YamlMF);
129134

@@ -867,18 +872,14 @@ bool MIRParserImpl::initializeFrameInfo(PerFunctionMIParsingState &PFS,
867872
MFI.setHasTailCall(YamlMFI.HasTailCall);
868873
MFI.setCalleeSavedInfoValid(YamlMFI.IsCalleeSavedInfoValid);
869874
MFI.setLocalFrameSize(YamlMFI.LocalFrameSize);
870-
if (!YamlMFI.SavePoint.Value.empty()) {
871-
MachineBasicBlock *MBB = nullptr;
872-
if (parseMBBReference(PFS, MBB, YamlMFI.SavePoint))
873-
return true;
874-
MFI.setSavePoint(MBB);
875-
}
876-
if (!YamlMFI.RestorePoint.Value.empty()) {
877-
MachineBasicBlock *MBB = nullptr;
878-
if (parseMBBReference(PFS, MBB, YamlMFI.RestorePoint))
879-
return true;
880-
MFI.setRestorePoint(MBB);
881-
}
875+
SmallVector<MachineBasicBlock *, 4> SavePoints;
876+
if (initializeSaveRestorePoints(PFS, YamlMFI.SavePoints, SavePoints))
877+
return true;
878+
MFI.setSavePoints(SavePoints);
879+
SmallVector<MachineBasicBlock *, 4> RestorePoints;
880+
if (initializeSaveRestorePoints(PFS, YamlMFI.RestorePoints, RestorePoints))
881+
return true;
882+
MFI.setRestorePoints(RestorePoints);
882883

883884
std::vector<CalleeSavedInfo> CSIInfo;
884885
// Initialize the fixed frame objects.
@@ -1093,6 +1094,35 @@ bool MIRParserImpl::initializeConstantPool(PerFunctionMIParsingState &PFS,
10931094
return false;
10941095
}
10951096

1097+
// Return true if basic block was incorrectly specified in MIR
1098+
bool MIRParserImpl::initializeSaveRestorePoints(
1099+
PerFunctionMIParsingState &PFS, const yaml::SaveRestorePoints &YamlSRPoints,
1100+
SmallVectorImpl<MachineBasicBlock *> &SaveRestorePoints) {
1101+
MachineBasicBlock *MBB = nullptr;
1102+
if (std::holds_alternative<std::vector<yaml::SaveRestorePointEntry>>(
1103+
YamlSRPoints)) {
1104+
const auto &VectorRepr =
1105+
std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRPoints);
1106+
if (VectorRepr.empty())
1107+
return false;
1108+
for (const yaml::SaveRestorePointEntry &Entry : VectorRepr) {
1109+
const yaml::StringValue &MBBSource = Entry.Point;
1110+
if (parseMBBReference(PFS, MBB, MBBSource.Value))
1111+
return true;
1112+
SaveRestorePoints.push_back(MBB);
1113+
}
1114+
} else {
1115+
yaml::StringValue StringRepr = std::get<yaml::StringValue>(YamlSRPoints);
1116+
if (StringRepr.Value.empty())
1117+
return false;
1118+
if (parseMBBReference(PFS, MBB, StringRepr))
1119+
return true;
1120+
SaveRestorePoints.push_back(MBB);
1121+
}
1122+
1123+
return false;
1124+
}
1125+
10961126
bool MIRParserImpl::initializeJumpTableInfo(PerFunctionMIParsingState &PFS,
10971127
const yaml::MachineJumpTable &YamlJTI) {
10981128
MachineJumpTableInfo *JTI = PFS.MF.getOrCreateJumpTableInfo(YamlJTI.Kind);

llvm/lib/CodeGen/MIRPrinter.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
150150
const MachineJumpTableInfo &JTI);
151151
static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
152152
const MachineFrameInfo &MFI);
153+
static void convertSRPoints(ModuleSlotTracker &MST,
154+
yaml::SaveRestorePoints &YamlSRPoints,
155+
ArrayRef<MachineBasicBlock *> SaveRestorePoints);
153156
static void convertStackObjects(yaml::MachineFunction &YMF,
154157
const MachineFunction &MF,
155158
ModuleSlotTracker &MST, MFPrintState &State);
@@ -355,14 +358,10 @@ static void convertMFI(ModuleSlotTracker &MST, yaml::MachineFrameInfo &YamlMFI,
355358
YamlMFI.HasTailCall = MFI.hasTailCall();
356359
YamlMFI.IsCalleeSavedInfoValid = MFI.isCalleeSavedInfoValid();
357360
YamlMFI.LocalFrameSize = MFI.getLocalFrameSize();
358-
if (MFI.getSavePoint()) {
359-
raw_string_ostream StrOS(YamlMFI.SavePoint.Value);
360-
StrOS << printMBBReference(*MFI.getSavePoint());
361-
}
362-
if (MFI.getRestorePoint()) {
363-
raw_string_ostream StrOS(YamlMFI.RestorePoint.Value);
364-
StrOS << printMBBReference(*MFI.getRestorePoint());
365-
}
361+
if (!MFI.getSavePoints().empty())
362+
convertSRPoints(MST, YamlMFI.SavePoints, MFI.getSavePoints());
363+
if (!MFI.getRestorePoints().empty())
364+
convertSRPoints(MST, YamlMFI.RestorePoints, MFI.getRestorePoints());
366365
}
367366

368367
static void convertEntryValueObjects(yaml::MachineFunction &YMF,
@@ -616,6 +615,22 @@ static void convertMCP(yaml::MachineFunction &MF,
616615
}
617616
}
618617

618+
static void convertSRPoints(ModuleSlotTracker &MST,
619+
yaml::SaveRestorePoints &YamlSRPoints,
620+
ArrayRef<MachineBasicBlock *> SRPoints) {
621+
auto &Points =
622+
std::get<std::vector<yaml::SaveRestorePointEntry>>(YamlSRPoints);
623+
for (const auto &MBB : SRPoints) {
624+
SmallString<16> Str;
625+
yaml::SaveRestorePointEntry Entry;
626+
raw_svector_ostream StrOS(Str);
627+
StrOS << printMBBReference(*MBB);
628+
Entry.Point = StrOS.str().str();
629+
Str.clear();
630+
Points.push_back(Entry);
631+
}
632+
}
633+
619634
static void convertMJTI(ModuleSlotTracker &MST, yaml::MachineJumpTable &YamlJTI,
620635
const MachineJumpTableInfo &JTI) {
621636
YamlJTI.Kind = JTI.getEntryKind();

llvm/lib/CodeGen/MachineFrameInfo.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,22 @@ void MachineFrameInfo::print(const MachineFunction &MF, raw_ostream &OS) const{
244244
}
245245
OS << "\n";
246246
}
247+
OS << "save/restore points:\n";
248+
249+
if (!SavePoints.empty()) {
250+
OS << "save points:\n";
251+
252+
for (auto &item : SavePoints)
253+
OS << printMBBReference(*item) << "\n";
254+
} else
255+
OS << "save points are empty\n";
256+
257+
if (!RestorePoints.empty()) {
258+
OS << "restore points:\n";
259+
for (auto &item : RestorePoints)
260+
OS << printMBBReference(*item) << "\n";
261+
} else
262+
OS << "restore points are empty\n";
247263
}
248264

249265
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

llvm/lib/CodeGen/PrologEpilogInserter.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ bool PEIImpl::run(MachineFunction &MF) {
351351
delete RS;
352352
SaveBlocks.clear();
353353
RestoreBlocks.clear();
354-
MFI.setSavePoint(nullptr);
355-
MFI.setRestorePoint(nullptr);
354+
MFI.setSavePoints({});
355+
MFI.setRestorePoints({});
356356
return true;
357357
}
358358

@@ -423,16 +423,18 @@ void PEIImpl::calculateCallFrameInfo(MachineFunction &MF) {
423423
/// callee-saved registers, and placing prolog and epilog code.
424424
void PEIImpl::calculateSaveRestoreBlocks(MachineFunction &MF) {
425425
const MachineFrameInfo &MFI = MF.getFrameInfo();
426-
427426
// Even when we do not change any CSR, we still want to insert the
428427
// prologue and epilogue of the function.
429428
// So set the save points for those.
430429

431430
// Use the points found by shrink-wrapping, if any.
432-
if (MFI.getSavePoint()) {
433-
SaveBlocks.push_back(MFI.getSavePoint());
434-
assert(MFI.getRestorePoint() && "Both restore and save must be set");
435-
MachineBasicBlock *RestoreBlock = MFI.getRestorePoint();
431+
if (!MFI.getSavePoints().empty()) {
432+
assert(MFI.getSavePoints().size() == 1 &&
433+
"Multiple save points are not yet supported!");
434+
SaveBlocks.push_back(MFI.getSavePoints().front());
435+
assert(MFI.getRestorePoints().size() == 1 &&
436+
"Multiple restore points are not yet supported!");
437+
MachineBasicBlock *RestoreBlock = MFI.getRestorePoints().front();
436438
// If RestoreBlock does not have any successor and is not a return block
437439
// then the end point is unreachable and we do not need to insert any
438440
// epilogue.
@@ -558,7 +560,11 @@ static void updateLiveness(MachineFunction &MF) {
558560
SmallPtrSet<MachineBasicBlock *, 8> Visited;
559561
SmallVector<MachineBasicBlock *, 8> WorkList;
560562
MachineBasicBlock *Entry = &MF.front();
561-
MachineBasicBlock *Save = MFI.getSavePoint();
563+
564+
assert(MFI.getSavePoints().size() < 2 &&
565+
"Multiple save points not yet supported!");
566+
MachineBasicBlock *Save =
567+
MFI.getSavePoints().empty() ? nullptr : MFI.getSavePoints().front();
562568

563569
if (!Save)
564570
Save = Entry;
@@ -569,7 +575,10 @@ static void updateLiveness(MachineFunction &MF) {
569575
}
570576
Visited.insert(Save);
571577

572-
MachineBasicBlock *Restore = MFI.getRestorePoint();
578+
assert(MFI.getRestorePoints().size() < 2 &&
579+
"Multiple restore points not yet supported!");
580+
MachineBasicBlock *Restore =
581+
MFI.getRestorePoints().empty() ? nullptr : MFI.getRestorePoints().front();
573582
if (Restore)
574583
// By construction Restore cannot be visited, otherwise it
575584
// means there exists a path to Restore that does not go

llvm/lib/CodeGen/ShrinkWrap.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,14 @@ bool ShrinkWrapImpl::run(MachineFunction &MF) {
967967
<< "\nRestore: " << printMBBReference(*Restore) << '\n');
968968

969969
MachineFrameInfo &MFI = MF.getFrameInfo();
970-
MFI.setSavePoint(Save);
971-
MFI.setRestorePoint(Restore);
970+
SmallVector<MachineBasicBlock *, 4> SavePoints;
971+
SmallVector<MachineBasicBlock *, 4> RestorePoints;
972+
if (Save) {
973+
SavePoints.push_back(Save);
974+
RestorePoints.push_back(Restore);
975+
}
976+
MFI.setSavePoints(SavePoints);
977+
MFI.setRestorePoints(RestorePoints);
972978
++NumCandidates;
973979
return Changed;
974980
}

0 commit comments

Comments
 (0)