Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions llvm/include/llvm/ADT/ValueOrSentinel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the ValueOrSentinel class, which is a type akin to a
/// std::optional, but uses a sentinel rather than an additional "valid" flag.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ADT_VALUEORSENTINEL_H
#define LLVM_ADT_VALUEORSENTINEL_H

#include <cassert>
#include <limits>
#include <utility>

namespace llvm {

template <typename T, T Sentinel> class ValueOrSentinel {
public:
ValueOrSentinel() = default;

ValueOrSentinel(T Value) : Value(std::move(Value)) {
assert(Value != Sentinel && "Value is sentinel (use default constructor)");
};

ValueOrSentinel &operator=(const T &NewValue) {
assert(NewValue != Sentinel && "NewValue is sentinel (use .clear())");
Value = NewValue;
return *this;
}

bool operator==(const ValueOrSentinel &Other) const {
return Value == Other.Value;
}

bool operator!=(const ValueOrSentinel &Other) const {
return !(*this == Other);
}

T &value() {
assert(has_value() && ".value() called on sentinel");
return Value;
}
const T &value() const {
return const_cast<ValueOrSentinel &>(*this).value();
}

T &operator*() { return value(); }
const T &operator*() const { return value(); }

bool has_value() const { return Value != Sentinel; }

explicit operator bool() const { return has_value(); }
explicit operator T() const { return value(); }

void clear() { Value = Sentinel; }

private:
T Value{Sentinel};
};

template <typename T>
using ValueOrSentinelIntMax = ValueOrSentinel<T, std::numeric_limits<T>::max()>;

} // namespace llvm

#endif
31 changes: 15 additions & 16 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3479,7 +3479,7 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
}

Register LastReg = 0;
int HazardSlotIndex = std::numeric_limits<int>::max();
ValueOrSentinelIntMax<int> HazardSlotIndex;
for (auto &CS : CSI) {
MCRegister Reg = CS.getReg();
const TargetRegisterClass *RC = RegInfo->getMinimalPhysRegClass(Reg);
Expand All @@ -3488,16 +3488,16 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
if (AFI->hasStackHazardSlotIndex() &&
(!LastReg || !AArch64InstrInfo::isFpOrNEON(LastReg)) &&
AArch64InstrInfo::isFpOrNEON(Reg)) {
assert(HazardSlotIndex == std::numeric_limits<int>::max() &&
assert(!HazardSlotIndex.has_value() &&
"Unexpected register order for hazard slot");
HazardSlotIndex = MFI.CreateStackObject(StackHazardSize, Align(8), true);
LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << HazardSlotIndex
LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << *HazardSlotIndex
<< "\n");
AFI->setStackHazardCSRSlotIndex(HazardSlotIndex);
if ((unsigned)HazardSlotIndex < MinCSFrameIndex)
MinCSFrameIndex = HazardSlotIndex;
if ((unsigned)HazardSlotIndex > MaxCSFrameIndex)
MaxCSFrameIndex = HazardSlotIndex;
AFI->setStackHazardCSRSlotIndex(*HazardSlotIndex);
if (static_cast<unsigned>(*HazardSlotIndex) < MinCSFrameIndex)
MinCSFrameIndex = *HazardSlotIndex;
if (static_cast<unsigned>(*HazardSlotIndex) > MaxCSFrameIndex)
MaxCSFrameIndex = *HazardSlotIndex;
}

unsigned Size = RegInfo->getSpillSize(*RC);
Expand All @@ -3524,16 +3524,15 @@ bool AArch64FrameLowering::assignCalleeSavedSpillSlots(
}

// Add hazard slot in the case where no FPR CSRs are present.
if (AFI->hasStackHazardSlotIndex() &&
HazardSlotIndex == std::numeric_limits<int>::max()) {
if (AFI->hasStackHazardSlotIndex() && !HazardSlotIndex.has_value()) {
HazardSlotIndex = MFI.CreateStackObject(StackHazardSize, Align(8), true);
LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << HazardSlotIndex
LLVM_DEBUG(dbgs() << "Created CSR Hazard at slot " << *HazardSlotIndex
<< "\n");
AFI->setStackHazardCSRSlotIndex(HazardSlotIndex);
if ((unsigned)HazardSlotIndex < MinCSFrameIndex)
MinCSFrameIndex = HazardSlotIndex;
if ((unsigned)HazardSlotIndex > MaxCSFrameIndex)
MaxCSFrameIndex = HazardSlotIndex;
AFI->setStackHazardCSRSlotIndex(*HazardSlotIndex);
if (static_cast<unsigned>(*HazardSlotIndex) < MinCSFrameIndex)
MinCSFrameIndex = *HazardSlotIndex;
if (static_cast<unsigned>(*HazardSlotIndex) > MaxCSFrameIndex)
MaxCSFrameIndex = *HazardSlotIndex;
}

return true;
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3061,10 +3061,10 @@ AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STPXi))
.addReg(MI.getOperand(0).getReg())
.addReg(MI.getOperand(1).getReg())
.addFrameIndex(TPIDR2.FrameIndex)
.addFrameIndex(*TPIDR2.FrameIndex)
.addImm(0);
} else
MFI.RemoveStackObject(TPIDR2.FrameIndex);
MFI.RemoveStackObject(*TPIDR2.FrameIndex);

BB->remove_instr(&MI);
return BB;
Expand Down Expand Up @@ -9399,7 +9399,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
if (RequiresLazySave) {
TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
TPIDR2.FrameIndex,
*TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Chain = DAG.getNode(
ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
Expand Down Expand Up @@ -9956,7 +9956,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// RESTORE_ZA pseudo.
SDValue Glue;
SDValue TPIDR2Block = DAG.getFrameIndex(
TPIDR2.FrameIndex,
*TPIDR2.FrameIndex,
DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
Result =
Expand Down
31 changes: 17 additions & 14 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/ValueOrSentinel.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MIRYamlMapping.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
Expand All @@ -38,7 +39,7 @@ class AArch64Subtarget;
class MachineInstr;

struct TPIDR2Object {
int FrameIndex = std::numeric_limits<int>::max();
ValueOrSentinelIntMax<int> FrameIndex;
unsigned Uses = 0;
};

Expand Down Expand Up @@ -114,8 +115,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
/// The stack slots used to add space between FPR and GPR accesses when using
/// hazard padding. StackHazardCSRSlotIndex is added between GPR and FPR CSRs.
/// StackHazardSlotIndex is added between (sorted) stack objects.
int StackHazardSlotIndex = std::numeric_limits<int>::max();
int StackHazardCSRSlotIndex = std::numeric_limits<int>::max();
ValueOrSentinelIntMax<int> StackHazardSlotIndex;
ValueOrSentinelIntMax<int> StackHazardCSRSlotIndex;

/// True if this function has a subset of CSRs that is handled explicitly via
/// copies.
Expand Down Expand Up @@ -205,7 +206,7 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
bool HasSwiftAsyncContext = false;

/// The stack slot where the Swift asynchronous context is stored.
int SwiftAsyncContextFrameIdx = std::numeric_limits<int>::max();
ValueOrSentinelIntMax<int> SwiftAsyncContextFrameIdx;

bool IsMTETagged = false;

Expand Down Expand Up @@ -372,16 +373,16 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
MaxOffset = std::max<int64_t>(Offset + ObjSize, MaxOffset);
}

if (SwiftAsyncContextFrameIdx != std::numeric_limits<int>::max()) {
if (SwiftAsyncContextFrameIdx.has_value()) {
int64_t Offset = MFI.getObjectOffset(getSwiftAsyncContextFrameIdx());
int64_t ObjSize = MFI.getObjectSize(getSwiftAsyncContextFrameIdx());
MinOffset = std::min<int64_t>(Offset, MinOffset);
MaxOffset = std::max<int64_t>(Offset + ObjSize, MaxOffset);
}

if (StackHazardCSRSlotIndex != std::numeric_limits<int>::max()) {
int64_t Offset = MFI.getObjectOffset(StackHazardCSRSlotIndex);
int64_t ObjSize = MFI.getObjectSize(StackHazardCSRSlotIndex);
if (StackHazardCSRSlotIndex.has_value()) {
int64_t Offset = MFI.getObjectOffset(*StackHazardCSRSlotIndex);
int64_t ObjSize = MFI.getObjectSize(*StackHazardCSRSlotIndex);
MinOffset = std::min<int64_t>(Offset, MinOffset);
MaxOffset = std::max<int64_t>(Offset + ObjSize, MaxOffset);
}
Expand Down Expand Up @@ -447,16 +448,16 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
void setVarArgsFPRSize(unsigned Size) { VarArgsFPRSize = Size; }

bool hasStackHazardSlotIndex() const {
return StackHazardSlotIndex != std::numeric_limits<int>::max();
return StackHazardSlotIndex.has_value();
}
int getStackHazardSlotIndex() const { return StackHazardSlotIndex; }
int getStackHazardSlotIndex() const { return *StackHazardSlotIndex; }
void setStackHazardSlotIndex(int Index) {
assert(StackHazardSlotIndex == std::numeric_limits<int>::max());
assert(!StackHazardSlotIndex.has_value());
StackHazardSlotIndex = Index;
}
int getStackHazardCSRSlotIndex() const { return StackHazardCSRSlotIndex; }
int getStackHazardCSRSlotIndex() const { return *StackHazardCSRSlotIndex; }
void setStackHazardCSRSlotIndex(int Index) {
assert(StackHazardCSRSlotIndex == std::numeric_limits<int>::max());
assert(!StackHazardCSRSlotIndex.has_value());
StackHazardCSRSlotIndex = Index;
}

Expand Down Expand Up @@ -574,7 +575,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
void setSwiftAsyncContextFrameIdx(int FI) {
SwiftAsyncContextFrameIdx = FI;
}
int getSwiftAsyncContextFrameIdx() const { return SwiftAsyncContextFrameIdx; }
int getSwiftAsyncContextFrameIdx() const {
return *SwiftAsyncContextFrameIdx;
}

bool needsDwarfUnwindInfo(const MachineFunction &MF) const;
bool needsAsyncDwarfUnwindInfo(const MachineFunction &MF) const;
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/ADT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ add_llvm_unittest(ADTTests
LazyAtomicPointerTest.cpp
MappedIteratorTest.cpp
MapVectorTest.cpp
ValueOrSentinelTest.cpp
PackedVectorTest.cpp
PagedVectorTest.cpp
PointerEmbeddedIntTest.cpp
Expand Down
62 changes: 62 additions & 0 deletions llvm/unittests/ADT/ValueOrSentinelTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/ValueOrSentinel.h"
#include "gtest/gtest.h"

using namespace llvm;

namespace {

TEST(ValueOrSentinelTest, Basic) {
// Default constructor should equal sentinel.
ValueOrSentinelIntMax<int> Value;
EXPECT_FALSE(Value.has_value());
EXPECT_FALSE(bool(Value));

// Assignment operator.
Value = 1000;
EXPECT_TRUE(Value.has_value());

// .value(), operator*, implicit constructor, explicit conversion
EXPECT_EQ(Value, 1000);
EXPECT_EQ(Value.value(), 1000);
EXPECT_EQ(*Value, 1000);
EXPECT_EQ(int(Value), 1000);

// .clear() should set value to sentinel
Value.clear();
EXPECT_FALSE(Value.has_value());
EXPECT_FALSE(bool(Value));

// construction from value, comparison operators
ValueOrSentinelIntMax<int> OtherValue(99);
EXPECT_TRUE(OtherValue.has_value());
EXPECT_TRUE(bool(OtherValue));
EXPECT_EQ(OtherValue, 99);
EXPECT_NE(Value, OtherValue);

Value = OtherValue;
EXPECT_EQ(Value, OtherValue);
}

TEST(ValueOrSentinelTest, PointerType) {
ValueOrSentinel<int *, nullptr> Value;
EXPECT_FALSE(Value.has_value());

int A = 10;
Value = &A;
EXPECT_TRUE(Value.has_value());

EXPECT_EQ(*Value.value(), 10);

Value.clear();
EXPECT_FALSE(Value.has_value());
}

} // end anonymous namespace
4 changes: 2 additions & 2 deletions llvm/unittests/Support/CommandLineTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ TEST_F(PrintOptionInfoTest, PrintOptionInfoValueOptionalWithoutSentinel) {
// clang-format on
}

TEST_F(PrintOptionInfoTest, PrintOptionInfoValueOptionalWithSentinel) {
TEST_F(PrintOptionInfoTest, PrintOptionInfoValueValueOrSentinel) {
std::string Output = runTest(
cl::ValueOptional, cl::values(clEnumValN(OptionValue::Val, "v1", "desc1"),
clEnumValN(OptionValue::Val, "", "")));
Expand All @@ -1426,7 +1426,7 @@ TEST_F(PrintOptionInfoTest, PrintOptionInfoValueOptionalWithSentinel) {
// clang-format on
}

TEST_F(PrintOptionInfoTest, PrintOptionInfoValueOptionalWithSentinelWithHelp) {
TEST_F(PrintOptionInfoTest, PrintOptionInfoValueValueOrSentinelWithHelp) {
std::string Output = runTest(
cl::ValueOptional, cl::values(clEnumValN(OptionValue::Val, "v1", "desc1"),
clEnumValN(OptionValue::Val, "", "desc2")));
Expand Down
Loading