diff --git a/backends/cuda/runtime/shims/memory_slim.cpp b/backends/cuda/runtime/shims/memory_slim.cpp
index 500cd41308e..5526fa57125 100644
--- a/backends/cuda/runtime/shims/memory_slim.cpp
+++ b/backends/cuda/runtime/shims/memory_slim.cpp
@@ -76,6 +76,39 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
   return Error::Ok;
 }
 
+AOTITorchError aoti_torch_empty_strided(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor) {
+  ET_CHECK_OR_RETURN_ERROR(
+      ret_new_tensor != nullptr,
+      InvalidArgument,
+      "aoti_torch_empty_strided: ret_new_tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      !(sizes_ptr == nullptr && ndim > 0),
+      InvalidArgument,
+      "aoti_torch_empty_strided: sizes_ptr is null but ndim > 0");
+
+  IntArrayRef sizes(sizes_ptr, static_cast<size_t>(ndim));
+  IntArrayRef strides(strides_ptr, static_cast<size_t>(ndim));
+
+  // Create the SlimTensor using empty_strided (owning)
+  *ret_new_tensor = new Tensor(empty_strided(
+      sizes,
+      strides,
+      static_cast<ScalarType>(dtype),
+      Device(
+          static_cast<DeviceType>(device_type),
+          static_cast<DeviceIndex>(device_index))));
+
+  return Error::Ok;
+}
+
 } // extern "C"
 
 } // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/memory_slim.h b/backends/cuda/runtime/shims/memory_slim.h
index 7650c4de4b6..109fbdcb08b 100644
--- a/backends/cuda/runtime/shims/memory_slim.h
+++ b/backends/cuda/runtime/shims/memory_slim.h
@@ -57,6 +57,28 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
     const uint8_t* opaque_metadata,
     int64_t opaque_metadata_size);
 
+/**
+ * Creates an uninitialized tensor with the specified dimensions, strides, and
+ * dtype on either a CPU or CUDA device.
+ *
+ * @param ndim Number of dimensions in the tensor
+ * @param sizes_ptr Pointer to array of dimension sizes
+ * @param strides_ptr Pointer to array of strides for each dimension
+ * @param dtype Data type identifier (matches PyTorch scalar types)
+ * @param device_type Device type (0=CPU, 1=CUDA)
+ * @param device_index Device index
+ * @param ret_new_tensor Output parameter for the created tensor
+ * @return AOTITorchError error code (Error::Ok on success)
+ */
+AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided(
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor);
+
 } // extern "C"
 
 } // namespace executorch::backends::cuda
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index 78f8dea20ce..003bb2acece 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -71,4 +71,5 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
 
     # SlimTensor-based shim tests
+    cuda_shim_slim_cpp_unittest("aoti_torch_empty_strided")
     cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided_slim.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided_slim.cpp
new file mode 100644
index 00000000000..d563eea98bc
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided_slim.cpp
@@ -0,0 +1,467 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <vector>
+
+#include <executorch/backends/cuda/runtime/shims/memory_slim.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+#include
+#include
+
+using namespace executorch::backends::cuda;
+using executorch::runtime::Error;
+
+namespace slim_c10 = executorch::backends::aoti::slim::c10;
+
+namespace {
+
+// Helper to check if CUDA is available
+bool isCudaAvailable() {
+  int device_count = 0;
+  cudaError_t err = cudaGetDeviceCount(&device_count);
+  return (err == cudaSuccess && device_count > 0);
+}
+
+// Helper to calculate contiguous strides from sizes
+std::vector<int64_t> calculateContiguousStrides(
+    const std::vector<int64_t>& sizes) {
+  std::vector<int64_t> strides(sizes.size());
+  if (sizes.empty()) {
+    return strides;
+  }
+  strides[sizes.size() - 1] = 1;
+  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; i--) {
+    strides[i] = strides[i + 1] * sizes[i + 1];
+  }
+  return strides;
+}
+
+} // namespace
+
+// Test fixture for SlimTensor-based aoti_torch_empty_strided tests
+class AOTITorchEmptyStridedSlimTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    et_pal_init();
+  }
+
+  void TearDown() override {
+    // Tensors are cleaned up via their destructors
+    for (Tensor* t : tensors_) {
+      delete t;
+    }
+    tensors_.clear();
+  }
+
+  // Track tensors for cleanup
+  void trackTensor(Tensor* t) {
+    if (t != nullptr) {
+      tensors_.push_back(t);
+    }
+  }
+
+ private:
+  std::vector<Tensor*> tensors_;
+};
+
+// ============================================================================
+// Common test body - parameterized by device type
+// ============================================================================
+
+void runBasicEmptyStridedTest(int32_t device_type, int32_t device_index) {
+  // Test 1D tensor
+  std::vector<int64_t> sizes_1d = {5};
+  std::vector<int64_t> strides_1d = calculateContiguousStrides(sizes_1d);
+
+  Tensor* tensor_1d = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes_1d.size(),
+      sizes_1d.data(),
+      strides_1d.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float), // dtype = 6
+      device_type,
+      device_index,
+      &tensor_1d);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor_1d, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor_1d->dim(), 1);
+  EXPECT_EQ(tensor_1d->size(0), 5);
+  EXPECT_EQ(tensor_1d->numel(), 5);
+  EXPECT_EQ(
+      static_cast<int32_t>(tensor_1d->dtype()),
+      static_cast<int32_t>(slim_c10::ScalarType::Float));
+  EXPECT_NE(tensor_1d->data_ptr(), nullptr);
+
+  // Cleanup
+  delete tensor_1d;
+}
+
+void runMultiDimensionalEmptyStridedTest(
+    int32_t device_type,
+    int32_t device_index) {
+  // Test 3D tensor
+  std::vector<int64_t> sizes = {2, 3, 4};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      device_type,
+      device_index,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  // Check tensor properties
+  EXPECT_EQ(tensor->dim(), 3);
+  EXPECT_EQ(tensor->size(0), 2);
+  EXPECT_EQ(tensor->size(1), 3);
+  EXPECT_EQ(tensor->size(2), 4);
+  EXPECT_EQ(tensor->numel(), 24);
+
+  // Check strides
+  EXPECT_EQ(tensor->stride(0), 12);
+  EXPECT_EQ(tensor->stride(1), 4);
+  EXPECT_EQ(tensor->stride(2), 1);
+
+  delete tensor;
+}
+
+void runScalarTensorEmptyStridedTest(
+    int32_t device_type,
+    int32_t device_index) {
+  std::vector<int64_t> sizes = {}; // 0D tensor
+  std::vector<int64_t> strides = {};
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      device_type,
+      device_index,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  EXPECT_EQ(tensor->dim(), 0);
+  EXPECT_EQ(tensor->numel(), 1);
+  EXPECT_NE(tensor->data_ptr(), nullptr);
+
+  delete tensor;
+}
+
+void runZeroSizedTensorEmptyStridedTest(
+    int32_t device_type,
+    int32_t device_index) {
+  std::vector<int64_t> sizes = {0, 5};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      device_type,
+      device_index,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 0);
+  EXPECT_EQ(tensor->size(1), 5);
+  EXPECT_EQ(tensor->numel(), 0);
+
+  delete tensor;
+}
+
+void runCustomStridesEmptyStridedTest(
+    int32_t device_type,
+    int32_t device_index) {
+  // Create a transposed (column-major) tensor
+  std::vector<int64_t> sizes = {3, 4};
+  std::vector<int64_t> strides = {1, 3}; // Column-major
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      device_type,
+      device_index,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  EXPECT_EQ(tensor->dim(), 2);
+  EXPECT_EQ(tensor->size(0), 3);
+  EXPECT_EQ(tensor->size(1), 4);
+  EXPECT_EQ(tensor->stride(0), 1);
+  EXPECT_EQ(tensor->stride(1), 3);
+
+  // Non-contiguous due to custom strides
+  EXPECT_FALSE(tensor->is_contiguous());
+
+  delete tensor;
+}
+
+void runDifferentDtypesEmptyStridedTest(
+    int32_t device_type,
+    int32_t device_index) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  // Test Float32
+  {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(slim_c10::ScalarType::Float),
+        device_type,
+        device_index,
+        &tensor);
+    EXPECT_EQ(error, Error::Ok);
+    ASSERT_NE(tensor, nullptr);
+    EXPECT_EQ(tensor->dtype(), slim_c10::ScalarType::Float);
+    EXPECT_EQ(tensor->itemsize(), 4);
+    delete tensor;
+  }
+
+  // Test BFloat16
+  {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(slim_c10::ScalarType::BFloat16),
+        device_type,
+        device_index,
+        &tensor);
+    EXPECT_EQ(error, Error::Ok);
+    ASSERT_NE(tensor, nullptr);
+    EXPECT_EQ(tensor->dtype(), slim_c10::ScalarType::BFloat16);
+    EXPECT_EQ(tensor->itemsize(), 2);
+    delete tensor;
+  }
+
+  // Test Int64
+  {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(slim_c10::ScalarType::Long),
+        device_type,
+        device_index,
+        &tensor);
+    EXPECT_EQ(error, Error::Ok);
+    ASSERT_NE(tensor, nullptr);
+    EXPECT_EQ(tensor->dtype(), slim_c10::ScalarType::Long);
+    EXPECT_EQ(tensor->itemsize(), 8);
+    delete tensor;
+  }
+
+  // Test Bool
+  {
+    Tensor* tensor = nullptr;
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides.data(),
+        static_cast<int32_t>(slim_c10::ScalarType::Bool),
+        device_type,
+        device_index,
+        &tensor);
+    EXPECT_EQ(error, Error::Ok);
+    ASSERT_NE(tensor, nullptr);
+    EXPECT_EQ(tensor->dtype(), slim_c10::ScalarType::Bool);
+    EXPECT_EQ(tensor->itemsize(), 1);
+    delete tensor;
+  }
+}
+
+// ============================================================================
+// CPU Tests
+// ============================================================================
+
+TEST_F(AOTITorchEmptyStridedSlimTest, BasicFunctionality_CPU) {
+  runBasicEmptyStridedTest(static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, MultiDimensional_CPU) {
+  runMultiDimensionalEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, ScalarTensor_CPU) {
+  runScalarTensorEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, ZeroSizedTensor_CPU) {
+  runZeroSizedTensorEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, CustomStrides_CPU) {
+  runCustomStridesEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, DifferentDtypes_CPU) {
+  runDifferentDtypesEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CPU), 0);
+}
+
+// ============================================================================
+// CUDA Tests
+// ============================================================================
+
+TEST_F(AOTITorchEmptyStridedSlimTest, BasicFunctionality_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runBasicEmptyStridedTest(static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, MultiDimensional_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runMultiDimensionalEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, ScalarTensor_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runScalarTensorEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, ZeroSizedTensor_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runZeroSizedTensorEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, CustomStrides_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runCustomStridesEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, DifferentDtypes_CUDA) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+  runDifferentDtypesEmptyStridedTest(
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA), 0);
+}
+
+// ============================================================================
+// Verify Device Properties
+// ============================================================================
+
+TEST_F(AOTITorchEmptyStridedSlimTest, VerifyCPUDevice) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      0,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  EXPECT_TRUE(tensor->is_cpu());
+  EXPECT_FALSE(tensor->is_cuda());
+  EXPECT_EQ(tensor->device_type(), slim_c10::DeviceType::CPU);
+
+  delete tensor;
+}
+
+TEST_F(AOTITorchEmptyStridedSlimTest, VerifyCUDADevice) {
+  if (!isCudaAvailable()) {
+    GTEST_SKIP() << "CUDA not available";
+  }
+
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  Tensor* tensor = nullptr;
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CUDA),
+      0,
+      &tensor);
+
+  EXPECT_EQ(error, Error::Ok);
+  ASSERT_NE(tensor, nullptr);
+
+  EXPECT_FALSE(tensor->is_cpu());
+  EXPECT_TRUE(tensor->is_cuda());
+  EXPECT_EQ(tensor->device_type(), slim_c10::DeviceType::CUDA);
+
+  delete tensor;
+}
+
+// ============================================================================
+// Error Cases
+// ============================================================================
+
+TEST_F(AOTITorchEmptyStridedSlimTest, NullReturnPointer) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = calculateContiguousStrides(sizes);
+
+  AOTITorchError error = aoti_torch_empty_strided(
+      sizes.size(),
+      sizes.data(),
+      strides.data(),
+      static_cast<int32_t>(slim_c10::ScalarType::Float),
+      static_cast<int32_t>(slim_c10::DeviceType::CPU),
+      0,
+      nullptr); // null return pointer
+
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
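
Reviewer note: the sketch below is not part of the patch. It is a minimal usage example of the new shim, assuming only what the diff itself shows (the declarations in memory_slim.h, dtype numbering that matches PyTorch where Float is 6, and device_type 0=CPU / 1=CUDA per the header comment). The helper name demo_empty_strided_cuda and the include path prefix are illustrative assumptions, not repository conventions being asserted.

// Illustrative sketch: allocate an uninitialized 2x3 float tensor on CUDA
// device 0 through the new shim, check the result, and free it.
#include <cstdint>
#include <executorch/backends/cuda/runtime/shims/memory_slim.h> // path assumed

using namespace executorch::backends::cuda;
using executorch::runtime::Error;

bool demo_empty_strided_cuda() { // hypothetical helper, not in the diff
  const int64_t sizes[] = {2, 3};
  const int64_t strides[] = {3, 1}; // contiguous row-major strides for 2x3
  Tensor* out = nullptr;
  AOTITorchError err = aoti_torch_empty_strided(
      /*ndim=*/2,
      sizes,
      strides,
      /*dtype=*/6, // 6 == ScalarType::Float, matching PyTorch numbering
      /*device_type=*/1, // 1 == CUDA per the header comment
      /*device_index=*/0,
      &out);
  if (err != Error::Ok || out == nullptr) {
    return false;
  }
  delete out; // the shim returns an owning SlimTensor; the caller frees it
  return true;
}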