@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
+namespace {
+
+// Calculate linear offset from strides and indices
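+// (e.g. indices {1, 2} with strides {3, 1} give offset 1*3 + 2*1 = 5)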
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
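+// (e.g. linear index 5 with sizes {2, 3} decomposes to indices {1, 2})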
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
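+// Both tensors are walked in logical (row-major) element order: each linear
+// index is decomposed into per-tensor indices via linear_to_indices and then
+// mapped to a physical offset via calculate_linear_offset. For example, when
+// copying a row-major 2x3 source (strides {3, 1}) into a column-major 2x3
+// destination (strides {1, 2}), linear index 4 reads src offset 1*3 + 1*1 = 4
+// and writes dst offset 1*1 + 1*2 = 3.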
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
   int64_t nbytes = numel * element_size;
 
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
@@ -312,6 +376,205 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
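+  // (cudaPointerGetAttributes classifies each pointer as host, device, or
+  // managed memory; managed allocations report cudaMemoryTypeManaged and are
+  // host-accessible, so they take the host-side branches below)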
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check if tensors have the same schema (sizes, strides, dtype) for fast path
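+  // (dtype and numel equality were already enforced above; the loop below
+  // compares strides to decide whether a flat memcpy is used)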
+  bool same_schema = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_schema = false;
+      break;
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
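+    // (strategy: stage any device-resident data on the host, perform the
+    // strided element-by-element copy on the CPU, then write the result back
+    // to the device if the destination lives there)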
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
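+    // (note: the staging buffers below are typed as float and sized as
+    // total_elements * sizeof(float), so this path assumes 4-byte elements
+    // such as float32)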
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,