Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions backends/aoti/slim/c10/core/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct Device final {
}

/// Constructs a Device from a string description.
/// The string must be "cpu" or "cpu:0".
/// The string must be "cpu", "cpu:0", "cuda", or "cuda:N".
/* implicit */ Device(const std::string& device_string)
: Device(DeviceType::CPU) {
ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty");
Expand All @@ -46,11 +46,19 @@ struct Device final {
index_ = -1;
} else if (device_string == "cpu:0" || device_string == "CPU:0") {
type_ = DeviceType::CPU;
index_ = 0;
} else if (device_string == "cuda" || device_string == "CUDA") {
type_ = DeviceType::CUDA;
index_ = 0;
} else if (
device_string.substr(0, 5) == "cuda:" ||
device_string.substr(0, 5) == "CUDA:") {
type_ = DeviceType::CUDA;
index_ = static_cast<DeviceIndex>(device_string.back() - '0');
} else {
ET_CHECK_MSG(
false,
"Invalid device string: %s. Currently only 'cpu' is supported.",
"Invalid device string: %s. Supported: 'cpu', 'cuda', 'cuda:N'.",
device_string.c_str());
}
validate();
Expand Down Expand Up @@ -92,7 +100,12 @@ struct Device final {
return type_ == DeviceType::CPU;
}

/// Returns a string representation of the device (e.g., "cpu" or "cpu:0").
/// Returns true if the device is of CUDA type.
bool is_cuda() const noexcept {
return type_ == DeviceType::CUDA;
}

/// Returns a string representation of the device (e.g., "cpu" or "cuda:0").
std::string str() const {
std::string str = DeviceTypeName(type(), /* lower_case */ true);
if (has_index()) {
Expand Down
8 changes: 6 additions & 2 deletions backends/aoti/slim/c10/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ namespace executorch::backends::aoti::slim::c10 {
/// Enum representing the type of device.
enum class DeviceType : int8_t {
CPU = 0,
COMPILE_TIME_MAX_DEVICE_TYPES = 1,
CUDA = 1,
COMPILE_TIME_MAX_DEVICE_TYPES = 2,
};

constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;

/// Maximum number of device types at compile time.
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
Expand All @@ -36,6 +38,8 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
switch (d) {
case DeviceType::CPU:
return lower_case ? "cpu" : "CPU";
case DeviceType::CUDA:
return lower_case ? "cuda" : "CUDA";
default:
ET_CHECK_MSG(false, "Unknown device type: %d", static_cast<int>(d));
}
Expand All @@ -45,7 +49,7 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) {
/// @param d The device type to check.
/// @return true if the device type is valid, false otherwise.
inline bool isValidDeviceType(DeviceType d) {
return d == DeviceType::CPU;
return d == DeviceType::CPU || d == DeviceType::CUDA;
}

inline std::ostream& operator<<(std::ostream& stream, DeviceType type) {
Expand Down
122 changes: 122 additions & 0 deletions backends/aoti/slim/c10/core/test/test_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,125 @@ TEST_F(DeviceTest, Hash) {
EXPECT_EQ(hasher(cpu1), hasher(cpu2));
EXPECT_NE(hasher(cpu1), hasher(cpu3));
}

// =============================================================================
// CUDA DeviceType Tests
// =============================================================================

class CUDADeviceTypeTest : public ::testing::Test {};

TEST_F(CUDADeviceTypeTest, CUDAEnumValue) {
// Verify CUDA has the correct enum value (1) to match PyTorch
EXPECT_EQ(static_cast<int>(DeviceType::CUDA), 1);
}

TEST_F(CUDADeviceTypeTest, DeviceTypeName) {
// Verify DeviceTypeName returns correct strings for CUDA
EXPECT_EQ(DeviceTypeName(DeviceType::CUDA, false), "CUDA");
EXPECT_EQ(DeviceTypeName(DeviceType::CUDA, true), "cuda");
}

TEST_F(CUDADeviceTypeTest, IsValidDeviceType) {
// Verify isValidDeviceType works correctly for CUDA
EXPECT_TRUE(isValidDeviceType(DeviceType::CUDA));
}

TEST_F(CUDADeviceTypeTest, KCUDAConstant) {
// Verify kCUDA constant
EXPECT_EQ(kCUDA, DeviceType::CUDA);
}

// =============================================================================
// CUDA Device Tests
// =============================================================================

class CUDADeviceTest : public ::testing::Test {};

TEST_F(CUDADeviceTest, ConstructFromDeviceType) {
// Construct Device from DeviceType
Device cuda_device(DeviceType::CUDA);

EXPECT_TRUE(cuda_device.is_cuda());
EXPECT_FALSE(cuda_device.is_cpu());
EXPECT_EQ(cuda_device.type(), DeviceType::CUDA);
EXPECT_EQ(cuda_device.index(), -1);
EXPECT_FALSE(cuda_device.has_index());
}

TEST_F(CUDADeviceTest, ConstructWithIndex) {
// Construct CUDA Device with explicit index
Device cuda_device(DeviceType::CUDA, 0);

EXPECT_TRUE(cuda_device.is_cuda());
EXPECT_FALSE(cuda_device.is_cpu());
EXPECT_EQ(cuda_device.type(), DeviceType::CUDA);
EXPECT_EQ(cuda_device.index(), 0);
EXPECT_TRUE(cuda_device.has_index());
}

TEST_F(CUDADeviceTest, ConstructWithNonZeroIndex) {
// Construct CUDA Device with non-zero index (multi-GPU)
Device cuda_device(DeviceType::CUDA, 3);

EXPECT_TRUE(cuda_device.is_cuda());
EXPECT_EQ(cuda_device.index(), 3);
EXPECT_TRUE(cuda_device.has_index());
}

TEST_F(CUDADeviceTest, ConstructFromString) {
// Construct CUDA Device from string
Device cuda1("cuda");
EXPECT_TRUE(cuda1.is_cuda());
EXPECT_EQ(cuda1.index(), 0);

Device cuda2("CUDA");
EXPECT_TRUE(cuda2.is_cuda());
EXPECT_EQ(cuda2.index(), 0);

Device cuda3("cuda:0");
EXPECT_TRUE(cuda3.is_cuda());
EXPECT_EQ(cuda3.index(), 0);

Device cuda4("cuda:1");
EXPECT_TRUE(cuda4.is_cuda());
EXPECT_EQ(cuda4.index(), 1);

Device cuda5("CUDA:2");
EXPECT_TRUE(cuda5.is_cuda());
EXPECT_EQ(cuda5.index(), 2);
}

TEST_F(CUDADeviceTest, Equality) {
Device cuda1(DeviceType::CUDA, 0);
Device cuda2(DeviceType::CUDA, 0);
Device cuda3(DeviceType::CUDA, 1);
Device cpu(DeviceType::CPU, 0);

EXPECT_EQ(cuda1, cuda2);
EXPECT_NE(cuda1, cuda3);
EXPECT_NE(cuda1, cpu);
}

TEST_F(CUDADeviceTest, Str) {
Device cuda1(DeviceType::CUDA);
EXPECT_EQ(cuda1.str(), "cuda");

Device cuda2(DeviceType::CUDA, 0);
EXPECT_EQ(cuda2.str(), "cuda:0");

Device cuda3(DeviceType::CUDA, 1);
EXPECT_EQ(cuda3.str(), "cuda:1");
}

TEST_F(CUDADeviceTest, Hash) {
// Verify CUDA Device can be hashed
Device cuda1(DeviceType::CUDA, 0);
Device cuda2(DeviceType::CUDA, 0);
Device cuda3(DeviceType::CUDA, 1);
Device cpu(DeviceType::CPU, 0);

std::hash<Device> hasher;
EXPECT_EQ(hasher(cuda1), hasher(cuda2));
EXPECT_NE(hasher(cuda1), hasher(cuda3));
EXPECT_NE(hasher(cuda1), hasher(cpu));
}
Loading