diff --git a/backends/aoti/common_shims_slim.h b/backends/aoti/common_shims_slim.h index b76b0d164d1..b8948c0a311 100644 --- a/backends/aoti/common_shims_slim.h +++ b/backends/aoti/common_shims_slim.h @@ -288,6 +288,65 @@ inline AOTITorchError aoti_torch_get_device_index( return Error::Ok; } +// ============================================================ +// DType Constants - These return PyTorch ScalarType enum values +// ============================================================ + +inline int32_t aoti_torch_dtype_float32() { + return 6; // ScalarType::Float +} + +inline int32_t aoti_torch_dtype_bfloat16() { + return 15; // ScalarType::BFloat16 +} + +inline int32_t aoti_torch_dtype_int64() { + return 4; // ScalarType::Long +} + +inline int32_t aoti_torch_dtype_int32() { + return 3; // ScalarType::Int +} + +inline int32_t aoti_torch_dtype_int16() { + return 2; // ScalarType::Short +} + +inline int32_t aoti_torch_dtype_int8() { + return 1; // ScalarType::Char +} + +inline int32_t aoti_torch_dtype_bool() { + return 11; // ScalarType::Bool +} + +// ============================================================ +// Device Type Constants +// ============================================================ + +inline int32_t aoti_torch_device_type_cpu() { + return 0; // DeviceType::CPU +} + +inline int32_t aoti_torch_device_type_cuda() { + return 1; // DeviceType::CUDA +} + +// ============================================================ +// Grad Mode Functions (not supported in ExecuTorch) +// ============================================================ + +inline bool aoti_torch_grad_mode_is_enabled() { + return false; // ExecuTorch doesn't support autograd +} + +inline AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) { + if (enabled) { + return Error::NotSupported; // Grad mode not supported in ExecuTorch + } + return Error::Ok; +} + } // namespace aoti } // namespace backends } // namespace executorch diff --git a/backends/aoti/tests/test_common_shims_slim.cpp b/backends/aoti/tests/test_common_shims_slim.cpp index 728bcc6a34f..ca744565955 100644 --- a/backends/aoti/tests/test_common_shims_slim.cpp +++ b/backends/aoti/tests/test_common_shims_slim.cpp @@ -589,3 +589,44 @@ TEST_F(CommonShimsSlimTest, ConsistentPointerReturn) { EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok); EXPECT_EQ(strides_ptr1, strides_ptr2); } + +// ============================================================================ +// DType Constants Tests +// ============================================================================ + +TEST_F(CommonShimsSlimTest, DTypeConstants) { + // Verify dtype constants match expected PyTorch ScalarType values + EXPECT_EQ(aoti_torch_dtype_float32(), 6); // ScalarType::Float + EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // ScalarType::BFloat16 + EXPECT_EQ(aoti_torch_dtype_int64(), 4); // ScalarType::Long + EXPECT_EQ(aoti_torch_dtype_int32(), 3); // ScalarType::Int + EXPECT_EQ(aoti_torch_dtype_int16(), 2); // ScalarType::Short + EXPECT_EQ(aoti_torch_dtype_int8(), 1); // ScalarType::Char + EXPECT_EQ(aoti_torch_dtype_bool(), 11); // ScalarType::Bool +} + +// ============================================================================ +// Device Type Constants Tests +// ============================================================================ + +TEST_F(CommonShimsSlimTest, DeviceTypeConstants) { + EXPECT_EQ(aoti_torch_device_type_cpu(), 0); // DeviceType::CPU + EXPECT_EQ(aoti_torch_device_type_cuda(), 1); // DeviceType::CUDA +} + +// ============================================================================ +// Grad Mode Tests +// ============================================================================ + +TEST_F(CommonShimsSlimTest, GradModeIsEnabled) { + // ExecuTorch doesn't support autograd, so should always return false + EXPECT_EQ(aoti_torch_grad_mode_is_enabled(), false); +} + +TEST_F(CommonShimsSlimTest, GradModeSetEnabled) { + // Setting to false should succeed + EXPECT_EQ(aoti_torch_grad_mode_set_enabled(false), Error::Ok); + + // Setting to true should fail (not supported in ExecuTorch) + EXPECT_EQ(aoti_torch_grad_mode_set_enabled(true), Error::NotSupported); +}