59 changes: 59 additions & 0 deletions backends/aoti/common_shims_slim.h
@@ -288,6 +288,65 @@ inline AOTITorchError aoti_torch_get_device_index(
return Error::Ok;
}

// ============================================================
// DType Constants - These return PyTorch ScalarType enum values
// ============================================================

inline int32_t aoti_torch_dtype_float32() {
return 6; // ScalarType::Float
}

inline int32_t aoti_torch_dtype_bfloat16() {
return 15; // ScalarType::BFloat16
}

inline int32_t aoti_torch_dtype_int64() {
return 4; // ScalarType::Long
}

inline int32_t aoti_torch_dtype_int32() {
return 3; // ScalarType::Int
}

inline int32_t aoti_torch_dtype_int16() {
return 2; // ScalarType::Short
}

inline int32_t aoti_torch_dtype_int8() {
return 1; // ScalarType::Char
}

inline int32_t aoti_torch_dtype_bool() {
return 11; // ScalarType::Bool
}

// ============================================================
// Device Type Constants
// ============================================================

inline int32_t aoti_torch_device_type_cpu() {
return 0; // DeviceType::CPU
}

inline int32_t aoti_torch_device_type_cuda() {
return 1; // DeviceType::CUDA
}

// ============================================================
// Grad Mode Functions (not supported in ExecuTorch)
// ============================================================

inline bool aoti_torch_grad_mode_is_enabled() {
return false; // ExecuTorch doesn't support autograd
}

inline AOTITorchError aoti_torch_grad_mode_set_enabled(bool enabled) {
if (enabled) {
return Error::NotSupported; // Grad mode not supported in ExecuTorch
}
return Error::Ok;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
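
A minimal sketch (not part of this diff) of how a caller such as AOTI-generated wrapper code might consume the new shim constants. The `TensorDescriptor` struct and `make_float32_cpu_descriptor` helper are hypothetical, added only for illustration; the include path and the shim functions are taken from the patched header.

#include <cstdint>
#include "backends/aoti/common_shims_slim.h" // path as it appears in this PR

// Hypothetical descriptor used only for this illustration.
struct TensorDescriptor {
  int32_t dtype;
  int32_t device_type;
};

// Build a descriptor for a float32 CPU tensor using the new shim constants.
inline TensorDescriptor make_float32_cpu_descriptor() {
  TensorDescriptor desc;
  desc.dtype = executorch::backends::aoti::aoti_torch_dtype_float32(); // 6 == ScalarType::Float
  desc.device_type = executorch::backends::aoti::aoti_torch_device_type_cpu(); // 0 == DeviceType::CPU
  return desc;
}
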
41 changes: 41 additions & 0 deletions backends/aoti/tests/test_common_shims_slim.cpp
@@ -589,3 +589,44 @@ TEST_F(CommonShimsSlimTest, ConsistentPointerReturn) {
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr2), Error::Ok);
EXPECT_EQ(strides_ptr1, strides_ptr2);
}

// ============================================================================
// DType Constants Tests
// ============================================================================

TEST_F(CommonShimsSlimTest, DTypeConstants) {
// Verify dtype constants match expected PyTorch ScalarType values
EXPECT_EQ(aoti_torch_dtype_float32(), 6); // ScalarType::Float
EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // ScalarType::BFloat16
EXPECT_EQ(aoti_torch_dtype_int64(), 4); // ScalarType::Long
EXPECT_EQ(aoti_torch_dtype_int32(), 3); // ScalarType::Int
EXPECT_EQ(aoti_torch_dtype_int16(), 2); // ScalarType::Short
EXPECT_EQ(aoti_torch_dtype_int8(), 1); // ScalarType::Char
EXPECT_EQ(aoti_torch_dtype_bool(), 11); // ScalarType::Bool
}

// ============================================================================
// Device Type Constants Tests
// ============================================================================

TEST_F(CommonShimsSlimTest, DeviceTypeConstants) {
EXPECT_EQ(aoti_torch_device_type_cpu(), 0); // DeviceType::CPU
EXPECT_EQ(aoti_torch_device_type_cuda(), 1); // DeviceType::CUDA
}

// ============================================================================
// Grad Mode Tests
// ============================================================================

TEST_F(CommonShimsSlimTest, GradModeIsEnabled) {
// ExecuTorch doesn't support autograd, so should always return false
EXPECT_EQ(aoti_torch_grad_mode_is_enabled(), false);
}

TEST_F(CommonShimsSlimTest, GradModeSetEnabled) {
// Setting to false should succeed
EXPECT_EQ(aoti_torch_grad_mode_set_enabled(false), Error::Ok);

// Setting to true should fail (not supported in ExecuTorch)
EXPECT_EQ(aoti_torch_grad_mode_set_enabled(true), Error::NotSupported);
}