From 1e2ee90e4aaccfd2107de78c546914634499d37c Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 5 Jan 2026 11:03:38 -0800 Subject: [PATCH] [slimtensor] Add aoti_torch_copy_ for SlimTensor Add SlimTensor-based `aoti_torch_copy_()` - Copies data from source tensor to destination tensor. Delegates to SlimTensor's `copy_()` which handles all device combinations (CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA). Differential Revision: [D90126246](https://our.internmc.facebook.com/intern/diff/D90126246/) [ghstack-poisoned] --- backends/cuda/runtime/shims/memory_slim.cpp | 20 + backends/cuda/runtime/shims/memory_slim.h | 15 + backends/cuda/runtime/shims/tests/targets.bzl | 1 + .../tests/test_aoti_torch_copy__slim.cpp | 487 ++++++++++++++++++ 4 files changed, 523 insertions(+) create mode 100644 backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp diff --git a/backends/cuda/runtime/shims/memory_slim.cpp b/backends/cuda/runtime/shims/memory_slim.cpp index eb9f485ca07..fbf2bd8ed9c 100644 --- a/backends/cuda/runtime/shims/memory_slim.cpp +++ b/backends/cuda/runtime/shims/memory_slim.cpp @@ -186,6 +186,26 @@ AOTITorchError aoti_torch__reinterpret_tensor( return Error::Ok; } +AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) { + (void)non_blocking; // SlimTensor::copy_() is always synchronous for now + + ET_CHECK_OR_RETURN_ERROR( + self != nullptr, InvalidArgument, "aoti_torch_copy_: self is null"); + + ET_CHECK_OR_RETURN_ERROR( + src != nullptr, InvalidArgument, "aoti_torch_copy_: src is null"); + + // SlimTensor::copy_() handles: + // - Same numel validation + // - Same dtype validation + // - CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA copies + // - Contiguous fast path and non-contiguous element-wise copy + self->copy_(*src); + + return Error::Ok; +} + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/memory_slim.h b/backends/cuda/runtime/shims/memory_slim.h index 64a7a561141..3c6a58fb783 100644 --- a/backends/cuda/runtime/shims/memory_slim.h +++ b/backends/cuda/runtime/shims/memory_slim.h @@ -128,6 +128,21 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( int64_t storage_offset, Tensor** ret_new_tensor); +/** + * Copies data from source tensor to destination tensor. + * + * Handles all device combinations (CPU-CPU, CPU-CUDA, CUDA-CPU, CUDA-CUDA) + * and supports tensors with different strides. The destination tensor must + * already be allocated with sufficient storage. + * + * @param self Destination tensor (must not be null) + * @param src Source tensor to copy from (must not be null) + * @param non_blocking If true, the copy may be asynchronous (currently ignored) + * @return AOTITorchError error code (Error::Ok on success) + */ +AOTI_SHIM_EXPORT AOTITorchError +aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking); + } // extern "C" } // namespace executorch::backends::cuda diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index ce9f8fcc647..099759d0649 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -76,3 +76,4 @@ def define_common_targets(): cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object") cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle") cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor") + cuda_shim_slim_cpp_unittest("aoti_torch_copy_") diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp new file mode 100644 index 00000000000..c2e67732b41 --- /dev/null +++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy__slim.cpp @@ -0,0 +1,487 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using executorch::runtime::Error; + +namespace slim_c10 = executorch::backends::aoti::slim::c10; + +namespace { + +bool isCudaAvailable() { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + return (err == cudaSuccess && device_count > 0); +} + +std::vector calculateContiguousStrides( + const std::vector& sizes) { + std::vector strides(sizes.size()); + if (sizes.empty()) { + return strides; + } + strides[sizes.size() - 1] = 1; + for (int64_t i = static_cast(sizes.size()) - 2; i >= 0; i--) { + strides[i] = strides[i + 1] * sizes[i + 1]; + } + return strides; +} + +} // namespace + +class AOTITorchCopySlimTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + } + + Tensor* createTestTensor( + const std::vector& sizes, + const std::vector& strides = {}, + int32_t dtype = static_cast(slim_c10::ScalarType::Float), + int32_t device_type = static_cast(slim_c10::DeviceType::CPU), + int32_t device_index = 0) { + Tensor* tensor = nullptr; + + std::vector effective_strides = strides; + if (strides.empty()) { + effective_strides = calculateContiguousStrides(sizes); + } + + AOTITorchError error = aoti_torch_empty_strided( + sizes.size(), + sizes.data(), + effective_strides.data(), + dtype, + device_type, + device_index, + &tensor); + + return (error == Error::Ok) ? tensor : nullptr; + } +}; + +// ============================================================================ +// Basic Functionality Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, BasicCopy_CPU) { + std::vector sizes = {3, 4}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i + 1); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i + 1)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, NullSelf) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + AOTITorchError error = aoti_torch_copy_(nullptr, src, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, NullSrc) { + std::vector sizes = {2, 3}; + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, nullptr, 0); + EXPECT_EQ(error, Error::InvalidArgument); + + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Different Dtype Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, Int64Copy_CPU) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + int64_t* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = i * 100; + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Long), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + int64_t* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_EQ(dst_data[i], i * 100); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, BoolCopy_CPU) { + std::vector sizes = {4}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Bool), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + bool* src_data = static_cast(src->data_ptr()); + src_data[0] = true; + src_data[1] = false; + src_data[2] = true; + src_data[3] = false; + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Bool), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + bool* dst_data = static_cast(dst->data_ptr()); + EXPECT_EQ(dst_data[0], true); + EXPECT_EQ(dst_data[1], false); + EXPECT_EQ(dst_data[2], true); + EXPECT_EQ(dst_data[3], false); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Tensor Shape Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, ScalarTensorCopy_CPU) { + std::vector sizes = {}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + EXPECT_EQ(src->dim(), 0); + EXPECT_EQ(src->numel(), 1); + + float* src_data = static_cast(src->data_ptr()); + *src_data = 42.0f; + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + EXPECT_FLOAT_EQ(*dst_data, 42.0f); + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, LargeTensorCopy_CPU) { + std::vector sizes = {100, 100}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// CUDA Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, CudaToCuda) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {3, 4}; + + std::vector host_src_data(12); + for (size_t i = 0; i < host_src_data.size(); i++) { + host_src_data[i] = static_cast(i + 1); + } + + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(src, nullptr); + EXPECT_TRUE(src->is_cuda()); + + cudaMemcpy( + src->data_ptr(), + host_src_data.data(), + host_src_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cuda()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + std::vector host_dst_data(12); + cudaMemcpy( + host_dst_data.data(), + dst->data_ptr(), + host_dst_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < host_dst_data.size(); i++) { + EXPECT_FLOAT_EQ(host_dst_data[i], static_cast(i + 1)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, CpuToCuda) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + EXPECT_TRUE(src->is_cpu()); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i * 10); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cuda()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + std::vector host_dst_data(6); + cudaMemcpy( + host_dst_data.data(), + dst->data_ptr(), + host_dst_data.size() * sizeof(float), + cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < host_dst_data.size(); i++) { + EXPECT_FLOAT_EQ(host_dst_data[i], static_cast(i * 10)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +TEST_F(AOTITorchCopySlimTest, CudaToCpu) { + if (!isCudaAvailable()) { + GTEST_SKIP() << "CUDA not available"; + } + + std::vector sizes = {2, 3}; + + std::vector host_src_data(6); + for (size_t i = 0; i < host_src_data.size(); i++) { + host_src_data[i] = static_cast(i * 5); + } + + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CUDA), + 0); + ASSERT_NE(src, nullptr); + + cudaMemcpy( + src->data_ptr(), + host_src_data.data(), + host_src_data.size() * sizeof(float), + cudaMemcpyHostToDevice); + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + EXPECT_TRUE(dst->is_cpu()); + + AOTITorchError error = aoti_torch_copy_(dst, src, 0); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i * 5)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +} + +// ============================================================================ +// Non-blocking Tests +// ============================================================================ + +TEST_F(AOTITorchCopySlimTest, NonBlockingFlag_CPU) { + std::vector sizes = {2, 3}; + Tensor* src = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(src, nullptr); + + float* src_data = static_cast(src->data_ptr()); + for (int64_t i = 0; i < src->numel(); i++) { + src_data[i] = static_cast(i); + } + + Tensor* dst = createTestTensor( + sizes, + {}, + static_cast(slim_c10::ScalarType::Float), + static_cast(slim_c10::DeviceType::CPU), + 0); + ASSERT_NE(dst, nullptr); + + AOTITorchError error = aoti_torch_copy_(dst, src, 1); + EXPECT_EQ(error, Error::Ok); + + float* dst_data = static_cast(dst->data_ptr()); + for (int64_t i = 0; i < dst->numel(); i++) { + EXPECT_FLOAT_EQ(dst_data[i], static_cast(i)); + } + + EXPECT_EQ(aoti_torch_delete_tensor_object(src), Error::Ok); + EXPECT_EQ(aoti_torch_delete_tensor_object(dst), Error::Ok); +}