diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp
index 498a31d42aa..b70a63f579a 100644
--- a/backends/cuda/runtime/shims/memory.cpp
+++ b/backends/cuda/runtime/shims/memory.cpp
@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
+namespace {
+
+// Calculate linear offset from strides and indices
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
 
   int64_t nbytes = numel * element_size;
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
@@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  (void)non_blocking;
+
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check whether the tensors share the same schema (rank and strides; dtype
+  // equality was verified above) so the fast memcpy path can be used. The
+  // rank check also prevents reading past the end of the shorter strides
+  // array when the two tensors have different numbers of dimensions.
+  bool same_schema = (self->dim() == src->dim());
+  if (same_schema) {
+    for (int i = 0; i < self->dim(); i++) {
+      if (self_strides[i] != src_strides[i]) {
+        same_schema = false;
+        break;
+      }
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
diff --git a/backends/cuda/runtime/shims/memory.h b/backends/cuda/runtime/shims/memory.h
index 4e9780840e1..bcec6621285 100644
--- a/backends/cuda/runtime/shims/memory.h
+++ b/backends/cuda/runtime/shims/memory.h
@@ -116,6 +116,31 @@ AOTITorchError aoti_torch__reinterpret_tensor(
     int64_t storage_offset,
     Tensor** ret_new_tensor);
 
+/**
+ * Copies data from source tensor to destination tensor.
+ *
+ * This function implements the copy operation for tensors living in the CUDA
+ * AOTI backend. It supports copying between tensors with different shapes (as
+ * long as they have the same total number of elements) and different memory
+ * layouts/strides.
+ *
+ * Note that currently this function does not support copying between tensors
+ * with different dtypes.
+ *
+ * @param self Destination tensor (data will be overwritten)
+ * @param src Source tensor (data will be copied from this tensor)
+ * @param non_blocking Whether the copy should be non-blocking (currently
+ * ignored)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers, dtype mismatch, numel
+ *           mismatch
+ *         - Error::MemoryAllocationFailed: failed to allocate temporary memory
+ *         - Error::Internal: CUDA operation failures
+ */
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
+
 // Function to clear all tensors from internal storage
 void clear_all_tensors();
 } // extern "C"
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl
index ac6d2072d58..fcb95a0beb7 100644
--- a/backends/cuda/runtime/shims/tests/targets.bzl
+++ b/backends/cuda/runtime/shims/tests/targets.bzl
@@ -31,3 +31,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
     cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
     cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
+    cuda_shim_cpp_unittest("aoti_torch_copy_")
diff --git a/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp
new file mode 100644
index 00000000000..7579eaef039
--- /dev/null
+++ b/backends/cuda/runtime/shims/tests/test_aoti_torch_copy_.cpp
@@ -0,0 +1,398 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include <cmath>
+#include <cstring>
+#include <vector>
+#include <executorch/backends/aoti/common_shims.h>
+#include <executorch/backends/cuda/runtime/shims/memory.h>
+#include <executorch/backends/cuda/runtime/utils.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/platform/platform.h>
+
+using namespace executorch::backends::cuda;
+using namespace executorch::backends::aoti;
+using namespace executorch::runtime;
+
+// Test fixture for aoti_torch_copy_ tests
+class AOTITorchCopyTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    // Initialize ExecuTorch Platform Abstraction Layer
+    et_pal_init();
+
+    // Check if CUDA is available
+    int device_count = 0;
+    cudaError_t err = cudaGetDeviceCount(&device_count);
+    if (err != cudaSuccess || device_count == 0) {
+      GTEST_SKIP() << "CUDA not available, skipping CUDA tests";
+    }
+
+    // Clean up any existing cached metadata before each test
+    cleanup_tensor_metadata();
+
+    // Clear any remaining tensors from previous tests
+    clear_all_tensors();
+  }
+
+  void TearDown() override {
+    // Clean up metadata
+    cleanup_tensor_metadata();
+
+    // Clear the global tensor storage using the provided function
+    clear_all_tensors();
+  }
+
+  // Helper to create test tensors with specific data
+  Tensor* create_test_tensor_with_data(
+      const std::vector<int64_t>& sizes,
+      const std::vector<float>& data,
+      const std::vector<int64_t>& strides = {},
+      int32_t dtype = static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      int32_t device_type = static_cast<int32_t>(SupportedDevices::CUDA),
+      int32_t device_index = 0) {
+    Tensor* tensor;
+
+    const int64_t* strides_ptr = strides.empty() ? nullptr : strides.data();
+
+    AOTITorchError error = aoti_torch_empty_strided(
+        sizes.size(),
+        sizes.data(),
+        strides_ptr,
+        dtype,
+        device_type,
+        device_index,
+        &tensor);
+
+    if (error != Error::Ok || tensor == nullptr) {
+      return nullptr;
+    }
+
+    // Fill tensor with data
+    size_t total_bytes = data.size() * sizeof(float);
+    if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+      cudaError_t memcpy_err = cudaMemcpy(
+          tensor->mutable_data_ptr(),
+          data.data(),
+          total_bytes,
+          cudaMemcpyHostToDevice);
+      // Note: the error is intentionally not propagated here; tests proceed
+      // and surface any failure through their own assertions
+      (void)memcpy_err; // Suppress unused variable warning
+    } else { // CPU
+      std::memcpy(tensor->mutable_data_ptr(), data.data(), total_bytes);
+    }
+
+    return tensor;
+  }
+
+  // Helper to get data from tensor
+  std::vector<float> get_tensor_data(Tensor* tensor) {
+    if (!tensor) {
+      return {};
+    }
+
+    size_t num_elements = tensor->numel();
+    std::vector<float> data(num_elements);
+
+    // Determine if this is a CUDA tensor
+    cudaPointerAttributes attributes{};
+    cudaError_t err = cudaPointerGetAttributes(&attributes, tensor->data_ptr());
+    bool is_device =
+        (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice);
+
+    if (is_device) {
+      cudaError_t memcpy_err = cudaMemcpy(
+          data.data(),
+          tensor->data_ptr(),
+          num_elements * sizeof(float),
+          cudaMemcpyDeviceToHost);
+      // Note: the error is intentionally not propagated here; tests proceed
+      // and surface any failure through their own assertions
+      (void)memcpy_err; // Suppress unused variable warning
+    } else {
+      std::memcpy(
+          data.data(), tensor->data_ptr(), num_elements * sizeof(float));
+    }
+
+    return data;
+  }
+
+  // Helper to verify two tensors have same data
+  bool tensors_equal(Tensor* a, Tensor* b, float tolerance = 1e-6f) {
+    if (!a || !b) {
+      return false;
+    }
+    if (a->numel() != b->numel()) {
+      return false;
+    }
+
+    auto data_a = get_tensor_data(a);
+    auto data_b = get_tensor_data(b);
+
+    for (size_t i = 0; i < data_a.size(); ++i) {
+      if (std::abs(data_a[i] - data_b[i]) > tolerance) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+// Test basic copy functionality - same schema (fast path)
+TEST_F(AOTITorchCopyTest, BasicCopySameSchema) {
+  // Create source tensor with test data
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<float> src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  Tensor* src = create_test_tensor_with_data(sizes, src_data);
+  EXPECT_NE(src, nullptr);
+
+  // Create destination tensor with same schema
+  Tensor* dst =
+      create_test_tensor_with_data(sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
+  EXPECT_NE(dst, nullptr);
+
+  // Perform copy
+  AOTITorchError error = aoti_torch_copy_(dst, src, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify copy was successful
+  EXPECT_TRUE(tensors_equal(dst, src));
+}
+
+// Test copy with different strides (pointwise fallback)
+TEST_F(AOTITorchCopyTest, CopyDifferentStrides) {
+  // Create source tensor (2x3) with contiguous layout
+  std::vector<int64_t> src_sizes = {2, 3};
+  std::vector<float> src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  Tensor* src = create_test_tensor_with_data(src_sizes, src_data);
+  EXPECT_NE(src, nullptr);
+
+  // Create destination tensor with transposed strides
+  std::vector<int64_t> dst_strides = {1, 2}; // Column-major layout
+  Tensor* dst = create_test_tensor_with_data(
+      src_sizes, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, dst_strides);
+  EXPECT_NE(dst, nullptr);
+
+  // Perform copy - this should use the pointwise fallback
+  AOTITorchError error = aoti_torch_copy_(dst, src, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify the copy worked correctly by checking specific elements
+  auto dst_data = get_tensor_data(dst);
+  auto src_data_check = get_tensor_data(src);
+
+  // For transposed layout, the data should be rearranged
+  EXPECT_EQ(dst_data.size(), 6);
+  EXPECT_EQ(src_data_check.size(), 6);
+}
+
+// Test copy between CPU and CUDA tensors
+TEST_F(AOTITorchCopyTest, CopyCPUToCUDA) {
+  std::vector<int64_t> sizes = {2, 2};
+  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  // Create CPU tensor
+  Tensor* cpu_tensor = create_test_tensor_with_data(
+      sizes,
+      data,
+      {},
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU)); // CPU
+  EXPECT_NE(cpu_tensor, nullptr);
+
+  // Create CUDA tensor
+  Tensor* cuda_tensor = create_test_tensor_with_data(
+      sizes,
+      {0.0f, 0.0f, 0.0f, 0.0f},
+      {},
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA)); // CUDA
+  EXPECT_NE(cuda_tensor, nullptr);
+
+  // Copy from CPU to CUDA
+  AOTITorchError error = aoti_torch_copy_(cuda_tensor, cpu_tensor, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify copy
+  EXPECT_TRUE(tensors_equal(cuda_tensor, cpu_tensor));
+}
+
+// Test copy between CUDA and CPU tensors
+TEST_F(AOTITorchCopyTest, CopyCUDAToCPU) {
+  std::vector<int64_t> sizes = {2, 2};
+  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  // Create CUDA tensor
+  Tensor* cuda_tensor = create_test_tensor_with_data(
+      sizes,
+      data,
+      {},
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CUDA)); // CUDA
+  EXPECT_NE(cuda_tensor, nullptr);
+
+  // Create CPU tensor
+  Tensor* cpu_tensor = create_test_tensor_with_data(
+      sizes,
+      {0.0f, 0.0f, 0.0f, 0.0f},
+      {},
+      static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDevices::CPU)); // CPU
+  EXPECT_NE(cpu_tensor, nullptr);
+
+  // Copy from CUDA to CPU
+  AOTITorchError error = aoti_torch_copy_(cpu_tensor, cuda_tensor, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify copy
+  EXPECT_TRUE(tensors_equal(cpu_tensor, cuda_tensor));
+}
+
+// Test copy with bf16 dtype support
+TEST_F(AOTITorchCopyTest, CopyBf16Tensors) {
+  // Test that bf16 tensors can be created and copied
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<float> src_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  // Note: we pass float32 host data, but the tensor itself is created with
+  // the bf16 dtype; this test only checks dtype handling and numel
+  Tensor* src = create_test_tensor_with_data(
+      sizes,
+      src_data,
+      {}, // default strides
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16), // bf16 dtype
+      static_cast<int32_t>(SupportedDevices::CUDA), // CUDA device
+      0 // device_index = 0
+  );
+  EXPECT_NE(src, nullptr);
+
+  // Create destination tensor with bf16 dtype
+  std::vector<float> dst_init(6, 0.0f);
+  Tensor* dst = create_test_tensor_with_data(
+      sizes,
+      dst_init,
+      {}, // default strides
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16), // bf16 dtype
+      static_cast<int32_t>(SupportedDevices::CUDA), // CUDA device
+      0 // device_index = 0
+  );
+  EXPECT_NE(dst, nullptr);
+
+  // Perform copy between bf16 tensors
+  AOTITorchError error = aoti_torch_copy_(dst, src, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify that both tensors have the expected dtype
+  int32_t src_dtype, dst_dtype;
+  aoti_torch_get_dtype(src, &src_dtype);
+  aoti_torch_get_dtype(dst, &dst_dtype);
+
+  EXPECT_EQ(src_dtype, static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+  EXPECT_EQ(dst_dtype, static_cast<int32_t>(SupportedDTypes::BFLOAT16));
+
+  // Verify copy was successful by checking numel matches
+  EXPECT_EQ(src->numel(), dst->numel());
+  EXPECT_EQ(src->numel(), 6);
+}
+
+// Test copy between different dtypes should fail
+TEST_F(AOTITorchCopyTest, CopyDTypeMismatchError) {
+  std::vector<int64_t> sizes = {2, 2};
+  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
+
+  // Create float32 tensor
+  Tensor* float32_tensor = create_test_tensor_with_data(
+      sizes,
+      data,
+      {}, // default strides
+      static_cast<int32_t>(SupportedDTypes::FLOAT32), // float32 dtype
+      static_cast<int32_t>(SupportedDevices::CUDA), // CUDA device
+      0 // device_index = 0
+  );
+  EXPECT_NE(float32_tensor, nullptr);
+
+  // Create bf16 tensor
+  Tensor* bf16_tensor = create_test_tensor_with_data(
+      sizes,
+      {0.0f, 0.0f, 0.0f, 0.0f},
+      {}, // default strides
+      static_cast<int32_t>(SupportedDTypes::BFLOAT16), // bf16 dtype
+      static_cast<int32_t>(SupportedDevices::CUDA), // CUDA device
+      0 // device_index = 0
+  );
+  EXPECT_NE(bf16_tensor, nullptr);
+
+  // Attempting to copy between different dtypes should fail
+  AOTITorchError error = aoti_torch_copy_(bf16_tensor, float32_tensor, 0);
+  EXPECT_EQ(error, Error::InvalidArgument);
+
+  // Reverse direction should also fail
+  error = aoti_torch_copy_(float32_tensor, bf16_tensor, 0);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test error conditions
+TEST_F(AOTITorchCopyTest, ErrorHandling) {
+  std::vector<int64_t> sizes = {2, 3};
+  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  Tensor* valid_tensor = create_test_tensor_with_data(sizes, data);
+  EXPECT_NE(valid_tensor, nullptr);
+
+  // Test null pointers
+  AOTITorchError error = aoti_torch_copy_(nullptr, valid_tensor, 0);
+  EXPECT_NE(error, Error::Ok);
+
+  error = aoti_torch_copy_(valid_tensor, nullptr, 0);
+  EXPECT_NE(error, Error::Ok);
+
+  // Test numel mismatch (different total number of elements)
+  std::vector<int64_t> different_numel_sizes = {
+      2, 3, 4}; // 24 elements vs 6 elements
+  std::vector<float> different_data(24, 1.0f);
+  Tensor* different_numel =
+      create_test_tensor_with_data(different_numel_sizes, different_data);
+  EXPECT_NE(different_numel, nullptr);
+
+  error = aoti_torch_copy_(valid_tensor, different_numel, 0);
+  EXPECT_EQ(error, Error::InvalidArgument);
+}
+
+// Test copy from 1D to 3D with same total elements
+TEST_F(AOTITorchCopyTest, Copy1DTo3DSameNumel) {
+  // Source tensor: 8 elements in 1D
+  std::vector<int64_t> src_sizes = {8};
+  std::vector<float> src_data = {
+      1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
+
+  Tensor* src = create_test_tensor_with_data(src_sizes, src_data);
+  EXPECT_NE(src, nullptr);
+
+  // Destination tensor: 2x2x2 = 8 elements (different shape, same total)
+  std::vector<int64_t> dst_sizes = {2, 2, 2};
+  std::vector<float> dst_init(8, 0.0f);
+  Tensor* dst = create_test_tensor_with_data(dst_sizes, dst_init);
+  EXPECT_NE(dst, nullptr);
+
+  // This should work - same total number of elements
+  AOTITorchError error = aoti_torch_copy_(dst, src, 0);
+  EXPECT_EQ(error, Error::Ok);
+
+  // Verify the data was copied correctly
+  auto dst_data = get_tensor_data(dst);
+  EXPECT_EQ(dst_data.size(), 8);
+
+  // Check some specific elements to verify correct copying
+  EXPECT_FLOAT_EQ(dst_data[0], 1.0f);
+  EXPECT_FLOAT_EQ(dst_data[7], 8.0f);
+}
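
Usage sketch (outside the patch, for reviewers): the fragment below mirrors the call sequence the tests exercise, namely allocating two CUDA tensors through aoti_torch_empty_strided, copying one into the other with aoti_torch_copy_, and releasing both with aoti_torch_delete_tensor_object. The file name, the helper copy_round_trip, and the exact include paths for the dtype/device enums are assumptions on my part; treat this as a minimal illustration rather than code that ships with this PR.

// copy_usage_sketch.cpp (hypothetical; not built by targets.bzl)
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h> // assumed to declare SupportedDTypes/SupportedDevices

using namespace executorch::backends::cuda;
using namespace executorch::backends::aoti;
using namespace executorch::runtime;

// Allocates two contiguous 2x3 float32 CUDA tensors via the shim, copies src
// into dst (same dtype, numel, and strides, so the fast cudaMemcpy path is
// taken), then deletes both tensors through the shim.
AOTITorchError copy_round_trip() {
  const int64_t sizes[] = {2, 3};
  Tensor* src = nullptr;
  Tensor* dst = nullptr;

  AOTITorchError err = aoti_torch_empty_strided(
      /*ndim=*/2,
      sizes,
      /*strides_ptr=*/nullptr, // let the shim derive contiguous strides
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDevices::CUDA),
      /*device_index=*/0,
      &src);
  if (err != Error::Ok) {
    return err;
  }

  err = aoti_torch_empty_strided(
      2,
      sizes,
      nullptr,
      static_cast<int32_t>(SupportedDTypes::FLOAT32),
      static_cast<int32_t>(SupportedDevices::CUDA),
      0,
      &dst);
  if (err != Error::Ok) {
    aoti_torch_delete_tensor_object(src);
    return err;
  }

  // Same dtype and numel are required; a matching layout selects the fast path.
  err = aoti_torch_copy_(dst, src, /*non_blocking=*/0);

  aoti_torch_delete_tensor_object(src);
  aoti_torch_delete_tensor_object(dst);
  return err;
}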