38 changes: 36 additions & 2 deletions backends/cuda/CMakeLists.txt
@@ -34,6 +34,39 @@ find_package(CUDAToolkit REQUIRED)
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# CUDA tensor maker for backends that support non-contiguous tensors
set(_tensor_maker_sources runtime/tensor/tensor_maker.cpp)
add_library(cuda_tensor_maker STATIC ${_tensor_maker_sources})
target_include_directories(
cuda_tensor_maker
PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
)
target_compile_options(
cuda_tensor_maker
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
)
# Ensure symbols are exported properly
if(APPLE)
target_link_options(cuda_tensor_maker PUBLIC -Wl,-export_dynamic)
else()
target_link_options(
cuda_tensor_maker PUBLIC
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
)
endif()

# Link against ExecuTorch core libraries
target_link_libraries(cuda_tensor_maker PUBLIC executorch ${CMAKE_DL_LIBS})
executorch_target_link_options_shared_lib(cuda_tensor_maker)

install(
TARGETS cuda_tensor_maker
EXPORT ExecuTorchTargets
DESTINATION lib
)

# CUDA-specific AOTI functionality
set(_aoti_cuda_sources
runtime/cuda_backend.cpp
@@ -62,9 +95,10 @@ target_link_options(
aoti_cuda PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
)

-# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and PyTorch
# CUDA libraries
target_link_libraries(
-  aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
aoti_cuda PUBLIC aoti_common cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
)
# If you need other CUDA libraries, link them similarly:
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
21 changes: 20 additions & 1 deletion backends/cuda/runtime/TARGETS
@@ -27,6 +27,25 @@ runtime.cxx_library(
],
)

runtime.cxx_library(
name = "tensor_maker",
srcs = [
"tensor/tensor_maker.cpp",
],
headers = [
"tensor/tensor_maker.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/util:tensor_util",
],
)

runtime.cxx_library(
name = "runtime_shims",
srcs = [
@@ -52,8 +71,8 @@ runtime.cxx_library(
compiler_flags = ["-Wno-global-constructors"],
visibility = ["@EXECUTORCH_CLIENTS"],
deps = [
":tensor_maker",
"//executorch/backends/aoti:common_shims",
"//executorch/extension/tensor:tensor",
"//executorch/runtime/core:core",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/platform:platform",
50 changes: 38 additions & 12 deletions backends/cuda/runtime/shims/memory.cpp
@@ -11,6 +11,7 @@
#include <executorch/backends/cuda/runtime/platform/platform.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/tensor/tensor_maker.h>
#include <executorch/backends/cuda/runtime/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>
@@ -163,9 +164,11 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(

// Create ExecuTorch tensor that wraps the existing memory
// Note: We're NOT copying the data, just wrapping it
-auto tensor = executorch::extension::from_blob(
-data, // existing memory (don't copy!)
// Using the CUDA-specific tensor maker, which supports non-contiguous tensors
auto tensor = make_tensor(
sizes, // tensor dimensions
data, // existing memory (don't copy!)
{}, // dim_order (empty, will be auto-generated)
strides, // tensor strides (allows different strides)
dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
);
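A note on the empty {} dim_order argument above: the tensor maker derives it from the strides. The sketch below shows one plausible derivation (dimensions sorted from largest to smallest stride, ties kept in declaration order); it illustrates the convention under that assumption, with a hypothetical helper name, and is not the actual tensor_maker.cpp logic.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (not the tensor_maker API): derive a dim order from
// strides by sorting dimensions from largest to smallest stride, so the
// outermost dimension in memory comes first.
std::vector<uint8_t> dim_order_from_strides(const std::vector<int64_t>& strides) {
  std::vector<uint8_t> order(strides.size());
  std::iota(order.begin(), order.end(), 0); // start with [0, 1, ..., ndim-1]
  std::stable_sort(order.begin(), order.end(), [&](uint8_t a, uint8_t b) {
    return strides[a] > strides[b];
  });
  return order;
}

int main() {
  const std::vector<uint8_t> row_major = dim_order_from_strides({4, 1});
  const std::vector<uint8_t> col_major = dim_order_from_strides({1, 3});
  // Contiguous 3x4 keeps the natural order {0, 1}; column-major yields {1, 0}.
  assert(row_major == (std::vector<uint8_t>{0, 1}));
  assert(col_major == (std::vector<uint8_t>{1, 0}));
  return 0;
}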
@@ -210,10 +213,6 @@ AOTITorchError aoti_torch_empty_strided(

// This requires us to reserve CUDA memory and put it into a ETensor
void* ptr;
-int64_t numel = 1;
-for (int64_t i = 0; i < ndim; i++) {
-numel *= sizes_ptr[i];
-}

ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));

@@ -223,7 +222,28 @@
InvalidArgument,
"Invalid element size for dtype: %d",
dtype);
-int64_t nbytes = numel * element_size;

// Calculate storage size based on strides, matching PyTorch's behavior.
// This is critical when sizes and strides don't match the expected contiguous
// layout. Reference: PyTorch's computeStorageNbytes in EmptyTensor.cpp
int64_t storage_size = 1; // storage offset (0) + 1
for (int64_t i = 0; i < ndim; i++) {
if (sizes_ptr[i] == 0) {
storage_size = 0;
break;
}
// For each dimension, add stride[i] * (size[i] - 1)
// This gives us the maximum offset in that dimension
int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 1;
if (strides_ptr == nullptr) {
// Calculate contiguous stride if not provided
for (int64_t j = i + 1; j < ndim; j++) {
stride_i *= sizes_ptr[j];
}
}
storage_size += stride_i * (sizes_ptr[i] - 1);
}
int64_t nbytes = storage_size * element_size;

if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
ET_CUDA_CHECK_OR_RETURN_ERROR(
@@ -250,16 +270,20 @@
auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// ETensor creation with dynamic shape support for edge cases
-auto tensor = executorch::extension::from_blob(
-ptr, sizes, strides, dtype_to_scalar_type(dtype));
// Using the CUDA-specific tensor maker, which supports non-contiguous tensors
auto tensor = make_tensor(
sizes,
ptr,
{}, // dim_order (empty, will be auto-generated)
strides,
dtype_to_scalar_type(dtype));

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);
*ret_new_tensor = tensor.get();

// This tensor owns the memory it allocated, set reference count to 1
memory_to_n_tensor[ptr] = 1;

return Error::Ok;
}
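As a sanity check on the stride-based sizing added above, here is a standalone re-derivation of the same formula (modeled on PyTorch's computeStorageNbytes). The helper below mirrors the loop in aoti_torch_empty_strided rather than calling any backend code, and is evaluated on the layouts the new tests exercise.

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone sketch of the stride-based storage sizing (hypothetical helper,
// not part of the backend): 1 + sum_i strides[i] * (sizes[i] - 1), and 0 if
// any dimension has size 0.
int64_t storage_elems(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int64_t storage = 1; // storage offset (0) + 1
  for (size_t i = 0; i < sizes.size(); i++) {
    if (sizes[i] == 0) {
      return 0; // a zero-sized dimension needs no storage at all
    }
    storage += strides[i] * (sizes[i] - 1); // furthest reachable element
  }
  return storage;
}

int main() {
  // Column-major 3x4 (strides {1, 3}): 1 + 1*2 + 3*3 = 12 elements.
  assert(storage_elems({3, 4}, {1, 3}) == 12);
  // Expanded 2x3x4 (strides {0, 4, 1}): 1 + 0 + 4*2 + 1*3 = 12 elements,
  // even though the logical numel is 24.
  assert(storage_elems({2, 3, 4}, {0, 4, 1}) == 12);
  return 0;
}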

@@ -630,9 +654,11 @@ AOTITorchError aoti_torch__reinterpret_tensor(

// Create new tensor view that reinterprets the same memory with different
// shape/strides. This creates a view, not a copy - the data pointer is shared
-std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
-data_ptr, // Reuse the same memory from source tensor
// Using the CUDA-specific tensor maker, which supports non-contiguous tensors
std::shared_ptr<Tensor> tensor = make_tensor(
sizes, // New sizes with explicit SizesType
data_ptr, // Reuse the same memory from source tensor
{}, // dim_order (empty, will be auto-generated)
strides, // New strides with explicit StridesType
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
);
2 changes: 1 addition & 1 deletion backends/cuda/runtime/shims/tensor_attribute.h
@@ -8,8 +8,8 @@

#pragma once

-#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <cstdint>

namespace executorch::backends::cuda {
@@ -12,6 +12,7 @@
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/utils.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/platform.h>
#include <gtest/gtest.h>
135 changes: 107 additions & 28 deletions backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
@@ -278,30 +278,6 @@ TEST_F(AOTITorchEmptyStridedTest, LargeTensor) {
EXPECT_EQ(tensor->size(2), 50);
}

-// Test error handling with memory allocation failures
-TEST_F(AOTITorchEmptyStridedTest, MemoryAllocationStress) {
-// Try to create a very large tensor that might cause allocation failure
-// (This test may pass or fail depending on available memory)
-std::vector<int64_t> huge_sizes = {10000, 10000, 100}; // ~38GB for float32
-Tensor* tensor;
-
-AOTITorchError error = aoti_torch_empty_strided(
-huge_sizes.size(),
-huge_sizes.data(),
-nullptr,
-6, // float32
-1, // CUDA device
-0, // device index
-&tensor);
-
-// Either succeed or fail with memory allocation error
-if (error == Error::Ok) {
-EXPECT_NE(tensor, nullptr);
-} else {
-EXPECT_EQ(error, Error::MemoryAllocationFailed);
-}
-}

// Test aoti_torch_empty_strided with bfloat16 dtype
TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) {
// Test creating bfloat16 tensor on CUDA
@@ -509,11 +485,11 @@ TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
EXPECT_EQ(sizes_ptr[2], 3);
}

-// Test different data types (only float32 is currently supported)
// Test different data types (currently bf16, fp32, and int32 are supported)
TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
std::vector<int64_t> sizes = {2, 3};

-// Test float32 (dtype 6) - currently the only supported type
// Test float32 (dtype 6) - one of the supported types
Tensor* tensor_float32;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
@@ -527,7 +503,7 @@
EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor_float32, nullptr);

-// Test unsupported data types should return error
// Test int32 (dtype 3) - one of the supported types
Tensor* tensor_int32;
error = aoti_torch_empty_strided(
sizes.size(),
@@ -538,7 +514,8 @@
0, // device index
&tensor_int32);

-EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor_int32, nullptr);

// Test another unsupported data type
Tensor* tensor_float64;
@@ -586,3 +563,105 @@ TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
EXPECT_EQ(tensor_5d->size(3), 4);
EXPECT_EQ(tensor_5d->size(4), 5);
}

// Test non-contiguous tensor creation - transpose-like layout
TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) {
// Create a tensor with transpose-like strides (column-major).
// For a 3x4 tensor in column-major order, strides should be [1, 3]:
// stepping along dim 0 moves 1 element, stepping along dim 1 moves 3
std::vector<int64_t> sizes = {3, 4};
std::vector<int64_t> strides = {1, 3}; // Column-major (non-contiguous)

Tensor* tensor;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
sizes.data(),
strides.data(),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDevices::CUDA),
0, // device index
&tensor);

EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor, nullptr);

// Verify tensor properties
EXPECT_EQ(tensor->dim(), 2);
EXPECT_EQ(tensor->size(0), 3);
EXPECT_EQ(tensor->size(1), 4);

// Verify the strides are what we specified
int64_t* strides_ptr;
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0
EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1

// Verify that memory was allocated correctly for the non-contiguous layout.
// Storage size should be:
//   stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) + 1
//     = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 2 + 9 + 1 = 12 elements
// Total bytes = 12 * 4 = 48 bytes (for float32)
EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for logical shape

// The tensor should be accessible and writable
void* data_ptr = tensor->mutable_data_ptr();
EXPECT_NE(data_ptr, nullptr);

// Verify we can use CUDA to write to the memory
std::vector<float> test_data(12, 1.0f);
cudaError_t cuda_err = cudaMemcpy(
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
EXPECT_EQ(cuda_err, cudaSuccess);
}

// Test non-contiguous tensor creation - expanded/broadcasted stride pattern
TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) {
// Create a tensor with expanded strides (simulating broadcasting)
// A 2x3x4 tensor where the first dimension has stride 0 (expanded)
// This creates a tensor where the first dimension is "broadcasted"
std::vector<int64_t> sizes = {2, 3, 4};
std::vector<int64_t> strides = {0, 4, 1}; // First dimension has stride 0

Tensor* tensor;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
sizes.data(),
strides.data(),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDevices::CUDA),
0, // device index
&tensor);

EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor, nullptr);

// Verify tensor properties
EXPECT_EQ(tensor->dim(), 3);
EXPECT_EQ(tensor->size(0), 2);
EXPECT_EQ(tensor->size(1), 3);
EXPECT_EQ(tensor->size(2), 4);

// Verify the strides are what we specified
int64_t* strides_ptr;
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride
EXPECT_EQ(strides_ptr[1], 4);
EXPECT_EQ(strides_ptr[2], 1);

// Verify that memory was allocated correctly for this non-contiguous layout.
// Storage size should be:
//   stride[0]*(size[0]-1) + stride[1]*(size[1]-1) + stride[2]*(size[2]-1) + 1
//     = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 - 1) + 1 = 0 + 8 + 3 + 1 = 12 elements
// Note: numel() returns the logical number of elements (2*3*4 = 24), not the
// storage size
EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24

// The tensor should be accessible and writable
void* data_ptr = tensor->mutable_data_ptr();
EXPECT_NE(data_ptr, nullptr);

// Verify we can use CUDA to write to the allocated memory
// We only need to allocate 12 elements (storage size), not 24
std::vector<float> test_data(12, 2.0f);
cudaError_t cuda_err = cudaMemcpy(
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
EXPECT_EQ(cuda_err, cudaSuccess);
}
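Why do 12 storage elements suffice for the logical 2x3x4 tensor in the last test? An element's storage offset is the dot product of its index with the strides, so a stride of 0 makes every index along that dimension alias the same memory. A small self-contained illustration (hypothetical helper, not part of the test suite):

#include <cassert>
#include <cstdint>
#include <vector>

// Map a logical index to a storage offset: offset = sum_i idx[i] * strides[i].
int64_t storage_offset(
    const std::vector<int64_t>& idx,
    const std::vector<int64_t>& strides) {
  int64_t off = 0;
  for (size_t i = 0; i < idx.size(); i++) {
    off += idx[i] * strides[i];
  }
  return off;
}

int main() {
  const std::vector<int64_t> strides = {0, 4, 1}; // for sizes {2, 3, 4}
  // Both "slices" along dim 0 alias the same storage location.
  assert(storage_offset({0, 2, 3}, strides) == storage_offset({1, 2, 3}, strides));
  // The largest reachable offset is 11, so 12 elements of storage suffice.
  assert(storage_offset({1, 2, 3}, strides) == 11);
  return 0;
}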