@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
+namespace {
+
+// Calculate linear offset from strides and indices
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
   int64_t nbytes = numel * element_size;
 
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
@@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  (void)non_blocking;
+
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check if tensors have the same schema (sizes, strides, dtype) for fast path
+  bool same_schema = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_schema = false;
+      break;
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
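For reference, below is a minimal standalone sketch (not part of the diff) of the stride-aware indexing that the pointwise_copy_generic fallback relies on. It copies a transposed, non-contiguous 2x3 source view into a contiguous 2x3 destination on the host. The two helpers mirror the anonymous-namespace functions added above; main() and the toy data are illustrative assumptions only.

// toy_strided_copy.cpp - illustrative sketch, not part of the change above
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {

// Same role as calculate_linear_offset in the diff: dot product of
// per-dimension indices with strides gives the storage offset.
int64_t calculate_linear_offset(
    const int64_t* indices,
    const int64_t* strides,
    int64_t ndim) {
  int64_t offset = 0;
  for (int64_t i = 0; i < ndim; ++i) {
    offset += indices[i] * strides[i];
  }
  return offset;
}

// Same role as linear_to_indices in the diff: unflatten a row-major
// linear index into per-dimension indices.
void linear_to_indices(
    int64_t linear_idx,
    const int64_t* sizes,
    int64_t ndim,
    int64_t* indices) {
  for (int64_t i = ndim - 1; i >= 0; --i) {
    indices[i] = linear_idx % sizes[i];
    linear_idx /= sizes[i];
  }
}

} // namespace

int main() {
  // Source: a 2x3 logical view over a 3x2 contiguous buffer (a transpose),
  // so its strides are {1, 2} rather than the contiguous {3, 1}.
  float src_storage[6] = {0, 1, 2, 3, 4, 5};
  int64_t src_sizes[2] = {2, 3};
  int64_t src_strides[2] = {1, 2};

  // Destination: contiguous 2x3.
  float dst_storage[6] = {};
  int64_t dst_sizes[2] = {2, 3};
  int64_t dst_strides[2] = {3, 1};

  std::vector<int64_t> src_idx(2), dst_idx(2);
  for (int64_t linear = 0; linear < 6; ++linear) {
    linear_to_indices(linear, dst_sizes, 2, dst_idx.data());
    linear_to_indices(linear, src_sizes, 2, src_idx.data());
    int64_t s = calculate_linear_offset(src_idx.data(), src_strides, 2);
    int64_t d = calculate_linear_offset(dst_idx.data(), dst_strides, 2);
    dst_storage[d] = src_storage[s];
  }

  // Prints: 0 2 4 1 3 5 (the transposed view, now materialized contiguously)
  for (float v : dst_storage) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}

Because both tensors are walked with the same logical index but each uses its own strides, the element-by-element copy works for any layout as long as the element counts match, which is exactly the numel check aoti_torch_copy_ performs before taking this slow path; the memcpy fast path above it is used only when the stride layouts already agree.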