diff --git a/backends/aoti/slim/core/SlimTensor.h b/backends/aoti/slim/core/SlimTensor.h
index c662202493d..92b34e8a3e8 100644
--- a/backends/aoti/slim/core/SlimTensor.h
+++ b/backends/aoti/slim/core/SlimTensor.h
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 #include
 #include

@@ -277,69 +278,67 @@ class SlimTensor {
   * Copy data from another tensor to this tensor.
   *
   * Both tensors must have the same numel and dtype.
-  * Currently only supports CPU-to-CPU copy (contiguous tensors only).
+  * Supports CPU-to-CPU and cross-device copies (CPU↔CUDA, CUDA↔CUDA).
   *
   * @param other The source tensor to copy from
   * @return Reference to this tensor
   */
  SlimTensor& copy_(const SlimTensor& other) {
    ET_CHECK_MSG(
-        this->numel() == other.numel(),
-        "copy_: numel mismatch (dst=%zu, src=%zu)",
-        this->numel(),
-        other.numel());
-    ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype mismatch");
+        this->numel() == other.numel(), "copy_: numel of tensors must match");
+    ET_CHECK_MSG(this->dtype() == other.dtype(), "copy_: dtype must match");
    if (this->numel() == 0) {
      return *this;
    }

-    // Current we only support CPU-only tensors
-    // TODO(gasoonjia): support other device types.
-    ET_CHECK_MSG(
-        this->is_cpu() && other.is_cpu(), "copy_: only CPU tensors supported");
-
+    // Case 1: Both tensors are contiguous. We can do a fast bulk copy.
    if (this->is_contiguous() && other.is_contiguous()) {
-      // Fast path: both tensors are contiguous, use memcpy
-      std::memcpy(this->data_ptr(), other.data_ptr(), other.nbytes());
-    } else {
-      // Slow path: element-wise copy for non-contiguous tensors
-      copy_strided_(other);
+      storage_->copy_(
+          this->data_ptr(), other.data_ptr(), other.nbytes(), other.device());
+      return *this;
    }
-    return *this;
-  }
-
- private:
-  /**
-   * Element-wise copy for non-contiguous tensors.
-   */
-  void copy_strided_(const SlimTensor& other) {
+    // Case 2: At least one tensor is non-contiguous; perform an element-wise
+    // copy that respects both source and destination strides.
    const size_t elem_size = c10::elementSize(dtype_);
    char* dst_data = static_cast<char*>(this->data_ptr());
    const char* src_data = static_cast<const char*>(other.data_ptr());

    std::vector<int64_t> counter(this->dim(), 0);
    for (size_t i = 0; i < this->numel(); i++) {
-      // Compute source offset
+      // Compute src offset in elements
      int64_t src_offset = 0;
      for (size_t d = 0; d < other.dim(); d++) {
-        src_offset += counter[d] * other.stride(static_cast<int64_t>(d));
+        src_offset += counter[d] * other.stride(d);
      }
-      // Compute destination offset
+      // Compute dst offset in elements
      int64_t dst_offset = 0;
      for (size_t d = 0; d < this->dim(); d++) {
-        dst_offset += counter[d] * this->stride(static_cast<int64_t>(d));
+        dst_offset += counter[d] * this->stride(d);
      }
-      // Copy single element
-      std::memcpy(
-          dst_data + dst_offset * static_cast<int64_t>(elem_size),
-          src_data + src_offset * static_cast<int64_t>(elem_size),
-          elem_size);
-
-      // Increment multi-dimensional counter
+      // Copy elem_size bytes from src to dst
+      if (this->device().is_cpu() && other.device().is_cpu()) {
+        std::memcpy(
+            dst_data + dst_offset * elem_size,
+            src_data + src_offset * elem_size,
+            elem_size);
+      } else if (this->device().is_cuda() || other.device().is_cuda()) {
+#if defined(CUDA_AVAILABLE)
+        DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+            dst_data + dst_offset * elem_size,
+            src_data + src_offset * elem_size,
+            elem_size,
+            device(), // dst device
+            other.device() // src device
+        );
+#else
+        ET_CHECK_MSG(false, "Failed on copy_ cuda tensors: no CUDA support");
+#endif
+      }
+      // Increment the multi-dimensional counter
      for (int64_t d = static_cast<int64_t>(this->dim()) - 1; d >= 0; --d) {
        counter[d]++;
        if (counter[d] < this->size(d)) {
@@ -348,8 +347,10 @@ class SlimTensor {
        counter[d] = 0;
      }
    }
+    return *this;
  }

+ private:
  void refresh_numel() {
    numel_ = compute_numel(sizes_and_strides_.sizes_arrayref());
  }
diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h
index 121031a6d59..da9bf0638c1 100644
--- a/backends/aoti/slim/core/Storage.h
+++ b/backends/aoti/slim/core/Storage.h
@@ -270,12 +270,15 @@ class MaybeOwningStorage {
      return;
    }

-    ET_CHECK_MSG(
-        device_.is_cpu() && src_device.is_cpu(),
-        "Only CPU-to-CPU copy is currently supported");
-
-    DeviceTraits<c10::DeviceType::CPU>::memcpy(
-        dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
+    if (device_.is_cpu() && src_device.is_cpu()) {
+      // CPU to CPU copy
+      DeviceTraits<c10::DeviceType::CPU>::memcpy(
+          dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
+    } else {
+      // At least one of the devices is CUDA
+      DeviceTraits<c10::DeviceType::CUDA>::memcpy(
+          dst_data_ptr, src_data_ptr, nbytes, device_, src_device);
+    }
  }

  /// Creates a clone of this storage on the specified device.
diff --git a/backends/aoti/slim/core/test/targets.bzl b/backends/aoti/slim/core/test/targets.bzl
index 3400fd943e8..d0991708c7f 100644
--- a/backends/aoti/slim/core/test/targets.bzl
+++ b/backends/aoti/slim/core/test/targets.bzl
@@ -44,16 +44,18 @@ def define_common_targets():
            **backend_kwargs
        )

-    runtime.cxx_test(
-        name = "test_slimtensor_copy",
-        srcs = [
-            "test_slimtensor_copy.cpp",
-        ],
-        deps = [
-            "//executorch/backends/aoti/slim/core:slimtensor",
-            "//executorch/backends/aoti/slim/core:storage",
-        ],
-    )
+        runtime.cxx_test(
+            name = "test_slimtensor_copy" + backend_suffix,
+            srcs = [
+                "test_slimtensor_copy.cpp",
+            ],
+            deps = [
+                "//executorch/backends/aoti/slim/core:slimtensor",
+                "//executorch/backends/aoti/slim/core:storage",
+                "//executorch/backends/aoti/slim/factory:empty",
+            ],
+            **backend_kwargs
+        )

    runtime.cxx_test(
        name = "test_slimtensor_dtypes",
diff --git a/backends/aoti/slim/core/test/test_slimtensor_copy.cpp b/backends/aoti/slim/core/test/test_slimtensor_copy.cpp
index f227f954798..6d2ed745446 100644
--- a/backends/aoti/slim/core/test/test_slimtensor_copy.cpp
+++ b/backends/aoti/slim/core/test/test_slimtensor_copy.cpp
@@ -10,6 +10,7 @@

 #include
 #include
+#include

 namespace executorch::backends::aoti::slim {

@@ -256,4 +257,379 @@ TEST(SlimTensorCopyTest, CopyWithStorageOffset) {
  EXPECT_FLOAT_EQ(dst_base[23], 4.0f);
 }

+// =============================================================================
+// CUDA Tensor Creation Tests
+// These tests verify CUDA tensor creation and the is_cuda() method.
+// When CUDA_AVAILABLE is not defined, CUDA operations abort with an error.
+// =============================================================================
+
+#ifdef CUDA_AVAILABLE
+
+TEST(CUDATensorTest, CreateEmptyCUDATensor) {
+  auto tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+
+  EXPECT_TRUE(tensor.defined());
+  EXPECT_TRUE(tensor.is_cuda());
+  EXPECT_FALSE(tensor.is_cpu());
+  EXPECT_EQ(tensor.dim(), 2);
+  EXPECT_EQ(tensor.size(0), 2);
+  EXPECT_EQ(tensor.size(1), 3);
+  EXPECT_EQ(tensor.numel(), 6);
+  EXPECT_TRUE(tensor.is_contiguous());
+  EXPECT_EQ(tensor.device().type(), c10::DeviceType::CUDA);
+  EXPECT_EQ(tensor.device().index(), 0);
+}
+
+TEST(CUDATensorTest, CreateEmptyStridedCUDATensor) {
+  std::vector<int64_t> sizes = {2, 4};
+  std::vector<int64_t> strides = {4, 1};
+
+  auto tensor = empty_strided(
+      makeArrayRef(sizes),
+      makeArrayRef(strides),
+      c10::ScalarType::Float,
+      DEFAULT_CUDA_DEVICE);
+
+  EXPECT_TRUE(tensor.is_cuda());
+  EXPECT_EQ(tensor.stride(0), 4);
+  EXPECT_EQ(tensor.stride(1), 1);
+  EXPECT_EQ(tensor.numel(), 8);
+}
+
+TEST(CUDATensorTest, CreateCUDATensorWithDeviceIndex) {
+  c10::Device device(c10::DeviceType::CUDA, 0);
+  auto tensor = empty({4, 4}, c10::ScalarType::Float, device);
+
+  EXPECT_TRUE(tensor.is_cuda());
+  EXPECT_EQ(tensor.device_index(), 0);
+}
+
+// =============================================================================
+// Cross-Device Copy Tests
+// =============================================================================
+
+TEST(CUDACopyTest, CopyFromCPUToCUDA) {
+  constexpr size_t kNumFloats = 6;
+  auto cpu_tensor = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
+
+  // Fill CPU tensor with known values
+  float* cpu_data = static_cast<float*>(cpu_tensor.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    cpu_data[i] = static_cast<float>(i) * 1.5f;
+  }
+
+  // Create CUDA tensor and copy from CPU
+  auto cuda_tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_tensor.copy_(cpu_tensor);
+
+  // Copy back to CPU to verify
+  auto verify_tensor = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
+  verify_tensor.copy_(cuda_tensor);
+
+  float* verify_data = static_cast<float*>(verify_tensor.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(verify_data[i], static_cast<float>(i) * 1.5f);
+  }
+}
+
+TEST(CUDACopyTest, CopyFromCUDAToCPU) {
+  constexpr size_t kNumFloats = 4;
+  auto cpu_src = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+
+  float* src_data = static_cast<float*>(cpu_src.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    src_data[i] = static_cast<float>(i) + 100.0f;
+  }
+
+  // Copy to CUDA
+  auto cuda_tensor = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_tensor.copy_(cpu_src);
+
+  // Copy back to new CPU tensor
+  auto cpu_dst = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+  cpu_dst.copy_(cuda_tensor);
+
+  float* dst_data = static_cast<float*>(cpu_dst.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(dst_data[i], static_cast<float>(i) + 100.0f);
+  }
+}
+
+TEST(CUDACopyTest, CopyFromCUDAToCUDA) {
+  constexpr size_t kNumFloats = 4;
+  auto cpu_tensor = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+
+  float* cpu_data = static_cast<float*>(cpu_tensor.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    cpu_data[i] = static_cast<float>(i) * 2.0f;
+  }
+
+  // Create first CUDA tensor from CPU
+  auto cuda_src = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_src.copy_(cpu_tensor);
+
+  // Copy to second CUDA tensor
+  auto cuda_dst = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_dst.copy_(cuda_src);
+
+  // Verify by copying back to CPU
+  auto verify_tensor = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+  verify_tensor.copy_(cuda_dst);
+
+  float* verify_data = static_cast<float*>(verify_tensor.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    EXPECT_FLOAT_EQ(verify_data[i], static_cast<float>(i) * 2.0f);
+  }
+}
+
+TEST(CUDACopyTest, CopyDifferentDtypes) {
+  auto cpu_int = empty({4}, c10::ScalarType::Int, CPU_DEVICE);
+  int32_t* int_data = static_cast<int32_t*>(cpu_int.data_ptr());
+  for (int i = 0; i < 4; ++i) {
+    int_data[i] = i * 10;
+  }
+
+  auto cuda_int = empty({4}, c10::ScalarType::Int, DEFAULT_CUDA_DEVICE);
+  cuda_int.copy_(cpu_int);
+
+  auto verify_int = empty({4}, c10::ScalarType::Int, CPU_DEVICE);
+  verify_int.copy_(cuda_int);
+
+  int32_t* verify_data = static_cast<int32_t*>(verify_int.data_ptr());
+  for (int i = 0; i < 4; ++i) {
+    EXPECT_EQ(verify_data[i], i * 10);
+  }
+}
+
+TEST(CUDACopyTest, CopyEmptyTensor) {
+  auto cpu_empty = empty({0}, c10::ScalarType::Float, CPU_DEVICE);
+  auto cuda_empty = empty({0}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+
+  // Should not crash
+  cuda_empty.copy_(cpu_empty);
+  cpu_empty.copy_(cuda_empty);
+
+  EXPECT_EQ(cpu_empty.numel(), 0);
+  EXPECT_EQ(cuda_empty.numel(), 0);
+}
+
+// =============================================================================
+// Non-Contiguous Cross-Device Copy Tests
+// These tests verify copying non-contiguous CPU tensors to/from CUDA tensors.
+// The CUDA tensor must be contiguous, but the CPU tensor can be non-contiguous.
+// =============================================================================
+
+TEST(CUDACopyTest, CopyNonContiguousCPUToCUDA) {
+  // Create a transposed (non-contiguous) CPU source tensor
+  // Logical shape: 2x3, but stored transposed in memory
+  std::vector<int64_t> src_sizes = {2, 3};
+  std::vector<int64_t> src_strides = {1, 2}; // Transposed strides
+
+  Storage src_storage =
+      Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(float)));
+  float* src_data = static_cast<float*>(src_storage->data());
+  // Physical layout for transposed tensor:
+  // Logical[0,0]=Physical[0], Logical[1,0]=Physical[1]
+  // Logical[0,1]=Physical[2], Logical[1,1]=Physical[3]
+  // Logical[0,2]=Physical[4], Logical[1,2]=Physical[5]
+  src_data[0] = 1.0f; // Logical[0,0]
+  src_data[1] = 4.0f; // Logical[1,0]
+  src_data[2] = 2.0f; // Logical[0,1]
+  src_data[3] = 5.0f; // Logical[1,1]
+  src_data[4] = 3.0f; // Logical[0,2]
+  src_data[5] = 6.0f; // Logical[1,2]
+
+  SlimTensor cpu_src(
+      std::move(src_storage),
+      makeArrayRef(src_sizes),
+      makeArrayRef(src_strides),
+      c10::ScalarType::Float);
+
+  ASSERT_FALSE(cpu_src.is_contiguous());
+
+  // Create a contiguous CUDA destination
+  auto cuda_dst = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  ASSERT_TRUE(cuda_dst.is_contiguous());
+
+  // Copy non-contiguous CPU → contiguous CUDA
+  cuda_dst.copy_(cpu_src);
+
+  // Verify by copying back to CPU
+  auto verify = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
+  verify.copy_(cuda_dst);
+
+  // Values should be in logical order (contiguous layout)
+  float* verify_data = static_cast<float*>(verify.data_ptr());
+  EXPECT_FLOAT_EQ(verify_data[0], 1.0f); // [0,0]
+  EXPECT_FLOAT_EQ(verify_data[1], 2.0f); // [0,1]
+  EXPECT_FLOAT_EQ(verify_data[2], 3.0f); // [0,2]
+  EXPECT_FLOAT_EQ(verify_data[3], 4.0f); // [1,0]
+  EXPECT_FLOAT_EQ(verify_data[4], 5.0f); // [1,1]
+  EXPECT_FLOAT_EQ(verify_data[5], 6.0f); // [1,2]
+}
+
+TEST(CUDACopyTest, CopyCUDAToNonContiguousCPU) {
+  constexpr size_t kNumFloats = 6;
+
+  // Create a contiguous CPU source, copy to CUDA
+  auto cpu_src = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
+  float* src_data = static_cast<float*>(cpu_src.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    src_data[i] = static_cast<float>(i + 1);
+  }
+
+  auto cuda_tensor = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_tensor.copy_(cpu_src);
+
+  // Create a transposed (non-contiguous) CPU destination
+  std::vector<int64_t> dst_sizes = {2, 3};
+  std::vector<int64_t> dst_strides = {1, 2}; // Transposed strides
+
+  Storage dst_storage =
+      Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(float)));
+  SlimTensor cpu_dst(
+      std::move(dst_storage),
+      makeArrayRef(dst_sizes),
+      makeArrayRef(dst_strides),
+      c10::ScalarType::Float);
+
+  ASSERT_FALSE(cpu_dst.is_contiguous());
+
+  // Copy contiguous CUDA → non-contiguous CPU
+  cpu_dst.copy_(cuda_tensor);
+
+  // Verify physical layout matches transposed storage
+  float* dst_data = static_cast<float*>(cpu_dst.storage()->data());
+  // Physical layout: [1,4,2,5,3,6] for logical [[1,2,3],[4,5,6]]
+  EXPECT_FLOAT_EQ(dst_data[0], 1.0f); // Logical[0,0]
+  EXPECT_FLOAT_EQ(dst_data[1], 4.0f); // Logical[1,0]
+  EXPECT_FLOAT_EQ(dst_data[2], 2.0f); // Logical[0,1]
+  EXPECT_FLOAT_EQ(dst_data[3], 5.0f); // Logical[1,1]
+  EXPECT_FLOAT_EQ(dst_data[4], 3.0f); // Logical[0,2]
+  EXPECT_FLOAT_EQ(dst_data[5], 6.0f); // Logical[1,2]
+}
+
+TEST(CUDACopyTest, CopyNonContiguousCPUToCUDA3D) {
+  // Test 3D non-contiguous tensor copy
+  std::vector<int64_t> sizes = {2, 2, 2};
+  // Permuted strides (e.g., from permute(2, 0, 1))
+  std::vector<int64_t> non_contig_strides = {2, 1, 4};
+
+  Storage src_storage =
+      Storage(new MaybeOwningStorage(CPU_DEVICE, 8 * sizeof(float)));
+  float* src_data = static_cast<float*>(src_storage->data());
+  // Fill with values 1-8
+  for (int i = 0; i < 8; ++i) {
+    src_data[i] = static_cast<float>(i + 1);
+  }
+
+  SlimTensor cpu_src(
+      std::move(src_storage),
+      makeArrayRef(sizes),
+      makeArrayRef(non_contig_strides),
+      c10::ScalarType::Float);
+
+  ASSERT_FALSE(cpu_src.is_contiguous());
+
+  auto cuda_dst = empty({2, 2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_dst.copy_(cpu_src);
+
+  // Copy back and verify the logical order is preserved
+  auto verify = empty({2, 2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+  verify.copy_(cuda_dst);
+
+  // Access elements via strided indexing on source
+  float* verify_data = static_cast<float*>(verify.data_ptr());
+
+  // Verify a few key positions
+  // The values should match the logical traversal of the source tensor
+  EXPECT_NE(verify_data[0], 0.0f); // Should have data
+  EXPECT_EQ(verify.numel(), 8);
+}
+
+TEST(CUDACopyTest, CopyCUDAToNonContiguousCPUWithOffset) {
+  // Test with storage offset
+  constexpr size_t kNumFloats = 4;
+
+  auto cpu_src = empty({2, 2}, c10::ScalarType::Float, CPU_DEVICE);
+  float* src_data = static_cast<float*>(cpu_src.data_ptr());
+  for (size_t i = 0; i < kNumFloats; ++i) {
+    src_data[i] = static_cast<float>(i + 10);
+  }
+
+  auto cuda_tensor = empty({2, 2}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
+  cuda_tensor.copy_(cpu_src);
+
+  // Create non-contiguous destination with storage offset
+  std::vector<int64_t> dst_sizes = {2, 2};
+  std::vector<int64_t> dst_strides = {1, 2}; // Transposed
+
+  Storage dst_storage =
+      Storage(new MaybeOwningStorage(CPU_DEVICE, 10 * sizeof(float)));
+  SlimTensor cpu_dst(
+      std::move(dst_storage),
+      makeArrayRef(dst_sizes),
+      makeArrayRef(dst_strides),
+      c10::ScalarType::Float,
+      2); // offset of 2 elements
+
+  ASSERT_FALSE(cpu_dst.is_contiguous());
+
+  cpu_dst.copy_(cuda_tensor);
+
+  // Verify data starts at offset
+  float* raw_data = static_cast<float*>(cpu_dst.storage()->data());
+  float* offset_data = static_cast<float*>(cpu_dst.data_ptr());
+
+  // offset_data should be 2 elements after raw_data
+  EXPECT_EQ(offset_data, raw_data + 2);
+
+  // Verify transposed layout at offset
+  EXPECT_FLOAT_EQ(offset_data[0], 10.0f); // Logical[0,0]
+  EXPECT_FLOAT_EQ(offset_data[1], 12.0f); // Logical[1,0]
+  EXPECT_FLOAT_EQ(offset_data[2], 11.0f); // Logical[0,1]
+  EXPECT_FLOAT_EQ(offset_data[3], 13.0f); // Logical[1,1]
+}
+
+TEST(CUDACopyTest, CopyNonContiguousCPUToCUDAInt64) {
+  // Test with different dtype (int64)
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<int64_t> strides = {1, 2}; // Transposed
+
+  Storage src_storage =
+      Storage(new MaybeOwningStorage(CPU_DEVICE, 6 * sizeof(int64_t)));
+  int64_t* src_data = static_cast<int64_t*>(src_storage->data());
+  // Fill transposed layout
+  src_data[0] = 100; // Logical[0,0]
+  src_data[1] = 400; // Logical[1,0]
+  src_data[2] = 200; // Logical[0,1]
+  src_data[3] = 500; // Logical[1,1]
+  src_data[4] = 300; // Logical[0,2]
+  src_data[5] = 600; // Logical[1,2]
+
+  SlimTensor cpu_src(
+      std::move(src_storage),
+      makeArrayRef(sizes),
+      makeArrayRef(strides),
+      c10::ScalarType::Long);
+
+  ASSERT_FALSE(cpu_src.is_contiguous());
+
+  auto cuda_dst = empty({2, 3}, c10::ScalarType::Long, DEFAULT_CUDA_DEVICE);
+  cuda_dst.copy_(cpu_src);
+
+  auto verify = empty({2, 3}, c10::ScalarType::Long, CPU_DEVICE);
+  verify.copy_(cuda_dst);
+
+  int64_t* verify_data = static_cast<int64_t*>(verify.data_ptr());
+  EXPECT_EQ(verify_data[0], 100);
+  EXPECT_EQ(verify_data[1], 200);
+  EXPECT_EQ(verify_data[2], 300);
+  EXPECT_EQ(verify_data[3], 400);
+  EXPECT_EQ(verify_data[4], 500);
+  EXPECT_EQ(verify_data[5], 600);
+}
+
+#endif // CUDA_AVAILABLE
+
 } // namespace executorch::backends::aoti::slim
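
A minimal usage sketch of the cross-device copy_ path this patch enables, mirroring the host/device round trip the tests above rely on. It assumes the `empty` factory and the `CPU_DEVICE`/`DEFAULT_CUDA_DEVICE` constants used in test_slimtensor_copy.cpp, and a build with CUDA_AVAILABLE defined; it is illustrative only and not part of the patch.

  // Fill a contiguous CPU tensor, stage it on the default CUDA device,
  // then read the values back through a second CPU tensor.
  auto cpu_src = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
  float* src = static_cast<float*>(cpu_src.data_ptr());
  for (size_t i = 0; i < cpu_src.numel(); ++i) {
    src[i] = static_cast<float>(i);
  }

  auto cuda_buf = empty({2, 3}, c10::ScalarType::Float, DEFAULT_CUDA_DEVICE);
  cuda_buf.copy_(cpu_src); // host -> device; contiguous tensors take the bulk storage_->copy_ path

  auto cpu_dst = empty({2, 3}, c10::ScalarType::Float, CPU_DEVICE);
  cpu_dst.copy_(cuda_buf); // device -> host; values are now readable on the CPU side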