Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions backends/aoti/common_shims_slim.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,84 @@ inline AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
return Error::Ok;
}

// ============================================================
// Storage & Device Property Getters - Inline implementations
// ============================================================

inline AOTITorchError aoti_torch_get_storage_offset(
Tensor* tensor,
int64_t* ret_storage_offset) {
if (tensor == nullptr) {
return Error::InvalidArgument;
}
if (ret_storage_offset == nullptr) {
return Error::InvalidArgument;
}

#ifdef CUDA_AVAILABLE
// SlimTensor supports real storage offset
*ret_storage_offset = tensor->storage_offset();
#else
// ETensor doesn't support storage_offset, return 0
*ret_storage_offset = 0;
#endif
return Error::Ok;
}

inline AOTITorchError aoti_torch_get_storage_size(
Tensor* tensor,
int64_t* ret_size) {
if (tensor == nullptr) {
return Error::InvalidArgument;
}
if (ret_size == nullptr) {
return Error::InvalidArgument;
}

*ret_size = static_cast<int64_t>(tensor->nbytes());
return Error::Ok;
}

inline AOTITorchError aoti_torch_get_device_type(
Tensor* tensor,
int32_t* ret_device_type) {
if (tensor == nullptr) {
return Error::InvalidArgument;
}
if (ret_device_type == nullptr) {
return Error::InvalidArgument;
}

#ifdef CUDA_AVAILABLE
// SlimTensor supports real device type
*ret_device_type = static_cast<int32_t>(tensor->device_type());
#else
// ETensor is always CPU in default mode
*ret_device_type = 0; // CPU
#endif
return Error::Ok;
}

inline AOTITorchError aoti_torch_get_device_index(
Tensor* tensor,
int32_t* ret_device_index) {
if (tensor == nullptr) {
return Error::InvalidArgument;
}
if (ret_device_index == nullptr) {
return Error::InvalidArgument;
}

#ifdef CUDA_AVAILABLE
// SlimTensor supports real device index
*ret_device_index = static_cast<int32_t>(tensor->device_index());
#else
// ETensor doesn't support multi-device, return 0
*ret_device_index = 0;
#endif
return Error::Ok;
}

} // namespace aoti
} // namespace backends
} // namespace executorch
131 changes: 131 additions & 0 deletions backends/aoti/tests/test_common_shims_slim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,93 @@ void runGetDimTest(slim_c10::DeviceType device_type) {
}
}

// ============================================================================
// Storage & Device Property Tests
// ============================================================================

void runGetStorageOffsetTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int64_t ret_storage_offset = -1;
AOTITorchError error =
aoti_torch_get_storage_offset(tensor, &ret_storage_offset);

EXPECT_EQ(error, Error::Ok);
// Default storage offset for newly created tensor is 0
EXPECT_EQ(ret_storage_offset, 0);

delete tensor;
}

void runGetStorageSizeTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int64_t ret_size = -1;
AOTITorchError error = aoti_torch_get_storage_size(tensor, &ret_size);

EXPECT_EQ(error, Error::Ok);
// 2 * 3 * sizeof(float) = 6 * 4 = 24 bytes
EXPECT_EQ(ret_size, 24);

delete tensor;
}

void runGetDeviceTypeTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int32_t ret_device_type = -1;
AOTITorchError error = aoti_torch_get_device_type(tensor, &ret_device_type);

EXPECT_EQ(error, Error::Ok);
EXPECT_EQ(ret_device_type, static_cast<int32_t>(device_type));

delete tensor;
}

void runGetDeviceIndexTest(slim_c10::DeviceType device_type) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = calculateContiguousStrides(sizes);
slim_c10::Device device(device_type, 0);

Tensor* tensor = new Tensor(slim::empty_strided(
slim::makeArrayRef(sizes),
slim::makeArrayRef(strides),
slim_c10::ScalarType::Float,
device));

int32_t ret_device_index = -1;
AOTITorchError error = aoti_torch_get_device_index(tensor, &ret_device_index);

EXPECT_EQ(error, Error::Ok);
EXPECT_EQ(ret_device_index, 0);

delete tensor;
}

// ============================================================================
// CPU Tests
// ============================================================================
Expand All @@ -313,6 +400,22 @@ TEST_F(CommonShimsSlimTest, GetDim_CPU) {
runGetDimTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetStorageOffset_CPU) {
runGetStorageOffsetTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetStorageSize_CPU) {
runGetStorageSizeTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetDeviceType_CPU) {
runGetDeviceTypeTest(slim_c10::DeviceType::CPU);
}

TEST_F(CommonShimsSlimTest, GetDeviceIndex_CPU) {
runGetDeviceIndexTest(slim_c10::DeviceType::CPU);
}

// ============================================================================
// CUDA Tests
// ============================================================================
Expand Down Expand Up @@ -352,6 +455,34 @@ TEST_F(CommonShimsSlimTest, GetDim_CUDA) {
}
runGetDimTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetStorageOffset_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetStorageOffsetTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetStorageSize_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetStorageSizeTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetDeviceType_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetDeviceTypeTest(slim_c10::DeviceType::CUDA);
}

TEST_F(CommonShimsSlimTest, GetDeviceIndex_CUDA) {
if (!isCudaAvailable()) {
GTEST_SKIP() << "CUDA not available";
}
runGetDeviceIndexTest(slim_c10::DeviceType::CUDA);
}
#endif

// ============================================================================
Expand Down
Loading