diff --git a/backends/aoti/slim/c10/core/Device.h b/backends/aoti/slim/c10/core/Device.h index 5638f6f80e8..02e41d3d221 100644 --- a/backends/aoti/slim/c10/core/Device.h +++ b/backends/aoti/slim/c10/core/Device.h @@ -36,7 +36,7 @@ struct Device final { } /// Constructs a Device from a string description. - /// The string must be "cpu" or "cpu:0". + /// The string must be "cpu", "cpu:0", "cuda", or "cuda:N". /* implicit */ Device(const std::string& device_string) : Device(DeviceType::CPU) { ET_CHECK_MSG(!device_string.empty(), "Device string must not be empty"); @@ -46,11 +46,19 @@ struct Device final { index_ = -1; } else if (device_string == "cpu:0" || device_string == "CPU:0") { type_ = DeviceType::CPU; + index_ = 0; + } else if (device_string == "cuda" || device_string == "CUDA") { + type_ = DeviceType::CUDA; + index_ = 0; + } else if ( + device_string.substr(0, 5) == "cuda:" || + device_string.substr(0, 5) == "CUDA:") { + type_ = DeviceType::CUDA; index_ = static_cast(device_string.back() - '0'); } else { ET_CHECK_MSG( false, - "Invalid device string: %s. Currently only 'cpu' is supported.", + "Invalid device string: %s. Supported: 'cpu', 'cuda', 'cuda:N'.", device_string.c_str()); } validate(); @@ -92,7 +100,12 @@ struct Device final { return type_ == DeviceType::CPU; } - /// Returns a string representation of the device (e.g., "cpu" or "cpu:0"). + /// Returns true if the device is of CUDA type. + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + /// Returns a string representation of the device (e.g., "cpu" or "cuda:0"). std::string str() const { std::string str = DeviceTypeName(type(), /* lower_case */ true); if (has_index()) { diff --git a/backends/aoti/slim/c10/core/DeviceType.h b/backends/aoti/slim/c10/core/DeviceType.h index c8c36c7faab..21ccd7976d3 100644 --- a/backends/aoti/slim/c10/core/DeviceType.h +++ b/backends/aoti/slim/c10/core/DeviceType.h @@ -19,10 +19,12 @@ namespace executorch::backends::aoti::slim::c10 { /// Enum representing the type of device. enum class DeviceType : int8_t { CPU = 0, - COMPILE_TIME_MAX_DEVICE_TYPES = 1, + CUDA = 1, + COMPILE_TIME_MAX_DEVICE_TYPES = 2, }; constexpr DeviceType kCPU = DeviceType::CPU; +constexpr DeviceType kCUDA = DeviceType::CUDA; /// Maximum number of device types at compile time. constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = @@ -36,6 +38,8 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { switch (d) { case DeviceType::CPU: return lower_case ? "cpu" : "CPU"; + case DeviceType::CUDA: + return lower_case ? "cuda" : "CUDA"; default: ET_CHECK_MSG(false, "Unknown device type: %d", static_cast(d)); } @@ -45,7 +49,7 @@ inline std::string DeviceTypeName(DeviceType d, bool lower_case = false) { /// @param d The device type to check. /// @return true if the device type is valid, false otherwise. inline bool isValidDeviceType(DeviceType d) { - return d == DeviceType::CPU; + return d == DeviceType::CPU || d == DeviceType::CUDA; } inline std::ostream& operator<<(std::ostream& stream, DeviceType type) { diff --git a/backends/aoti/slim/c10/core/test/test_device.cpp b/backends/aoti/slim/c10/core/test/test_device.cpp index 57123589775..50ac76a8a3e 100644 --- a/backends/aoti/slim/c10/core/test/test_device.cpp +++ b/backends/aoti/slim/c10/core/test/test_device.cpp @@ -109,3 +109,125 @@ TEST_F(DeviceTest, Hash) { EXPECT_EQ(hasher(cpu1), hasher(cpu2)); EXPECT_NE(hasher(cpu1), hasher(cpu3)); } + +// ============================================================================= +// CUDA DeviceType Tests +// ============================================================================= + +class CUDADeviceTypeTest : public ::testing::Test {}; + +TEST_F(CUDADeviceTypeTest, CUDAEnumValue) { + // Verify CUDA has the correct enum value (1) to match PyTorch + EXPECT_EQ(static_cast(DeviceType::CUDA), 1); +} + +TEST_F(CUDADeviceTypeTest, DeviceTypeName) { + // Verify DeviceTypeName returns correct strings for CUDA + EXPECT_EQ(DeviceTypeName(DeviceType::CUDA, false), "CUDA"); + EXPECT_EQ(DeviceTypeName(DeviceType::CUDA, true), "cuda"); +} + +TEST_F(CUDADeviceTypeTest, IsValidDeviceType) { + // Verify isValidDeviceType works correctly for CUDA + EXPECT_TRUE(isValidDeviceType(DeviceType::CUDA)); +} + +TEST_F(CUDADeviceTypeTest, KCUDAConstant) { + // Verify kCUDA constant + EXPECT_EQ(kCUDA, DeviceType::CUDA); +} + +// ============================================================================= +// CUDA Device Tests +// ============================================================================= + +class CUDADeviceTest : public ::testing::Test {}; + +TEST_F(CUDADeviceTest, ConstructFromDeviceType) { + // Construct Device from DeviceType + Device cuda_device(DeviceType::CUDA); + + EXPECT_TRUE(cuda_device.is_cuda()); + EXPECT_FALSE(cuda_device.is_cpu()); + EXPECT_EQ(cuda_device.type(), DeviceType::CUDA); + EXPECT_EQ(cuda_device.index(), -1); + EXPECT_FALSE(cuda_device.has_index()); +} + +TEST_F(CUDADeviceTest, ConstructWithIndex) { + // Construct CUDA Device with explicit index + Device cuda_device(DeviceType::CUDA, 0); + + EXPECT_TRUE(cuda_device.is_cuda()); + EXPECT_FALSE(cuda_device.is_cpu()); + EXPECT_EQ(cuda_device.type(), DeviceType::CUDA); + EXPECT_EQ(cuda_device.index(), 0); + EXPECT_TRUE(cuda_device.has_index()); +} + +TEST_F(CUDADeviceTest, ConstructWithNonZeroIndex) { + // Construct CUDA Device with non-zero index (multi-GPU) + Device cuda_device(DeviceType::CUDA, 3); + + EXPECT_TRUE(cuda_device.is_cuda()); + EXPECT_EQ(cuda_device.index(), 3); + EXPECT_TRUE(cuda_device.has_index()); +} + +TEST_F(CUDADeviceTest, ConstructFromString) { + // Construct CUDA Device from string + Device cuda1("cuda"); + EXPECT_TRUE(cuda1.is_cuda()); + EXPECT_EQ(cuda1.index(), 0); + + Device cuda2("CUDA"); + EXPECT_TRUE(cuda2.is_cuda()); + EXPECT_EQ(cuda2.index(), 0); + + Device cuda3("cuda:0"); + EXPECT_TRUE(cuda3.is_cuda()); + EXPECT_EQ(cuda3.index(), 0); + + Device cuda4("cuda:1"); + EXPECT_TRUE(cuda4.is_cuda()); + EXPECT_EQ(cuda4.index(), 1); + + Device cuda5("CUDA:2"); + EXPECT_TRUE(cuda5.is_cuda()); + EXPECT_EQ(cuda5.index(), 2); +} + +TEST_F(CUDADeviceTest, Equality) { + Device cuda1(DeviceType::CUDA, 0); + Device cuda2(DeviceType::CUDA, 0); + Device cuda3(DeviceType::CUDA, 1); + Device cpu(DeviceType::CPU, 0); + + EXPECT_EQ(cuda1, cuda2); + EXPECT_NE(cuda1, cuda3); + EXPECT_NE(cuda1, cpu); +} + +TEST_F(CUDADeviceTest, Str) { + Device cuda1(DeviceType::CUDA); + EXPECT_EQ(cuda1.str(), "cuda"); + + Device cuda2(DeviceType::CUDA, 0); + EXPECT_EQ(cuda2.str(), "cuda:0"); + + Device cuda3(DeviceType::CUDA, 1); + EXPECT_EQ(cuda3.str(), "cuda:1"); +} + +TEST_F(CUDADeviceTest, Hash) { + // Verify CUDA Device can be hashed + Device cuda1(DeviceType::CUDA, 0); + Device cuda2(DeviceType::CUDA, 0); + Device cuda3(DeviceType::CUDA, 1); + Device cpu(DeviceType::CPU, 0); + + std::hash hasher; + EXPECT_EQ(hasher(cuda1), hasher(cuda2)); + EXPECT_NE(hasher(cuda1), hasher(cuda3)); + EXPECT_NE(hasher(cuda1), hasher(cpu)); +}