Commit 65100f6

aoti_torch_copy_ (#14689)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #14690
* __->__ #14689
* #14688
* #14687
* #14686

Summary: This diff introduces `aoti_torch_copy_`, the function for copying tensors inside the CUDA backend. Right now it only supports copies between tensors with the same dtype.

Reviewed By:

Differential Revision:
1 parent aeed916 commit 65100f6
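To illustrate the calling convention, here is a minimal hypothetical call site (namespace qualifiers omitted), assuming two already-constructed `Tensor*` handles; `dst`, `src`, and the `executorch/`-prefixed include path are illustrative assumptions, while the signature and the `Error::Ok` check follow the declaration added in `memory.h`:

```cpp
#include <executorch/backends/cuda/runtime/shims/memory.h>

// dst and src are existing tensors managed by the CUDA AOTI shim layer,
// already created with the same dtype and the same number of elements.
AOTITorchError copy_tensor_example(Tensor* dst, Tensor* src) {
  // non_blocking is currently ignored by the implementation, so pass 0.
  AOTITorchError err = aoti_torch_copy_(dst, src, /*non_blocking=*/0);
  if (err != Error::Ok) {
    // Possible failures: null tensors, dtype mismatch, numel mismatch,
    // temporary-buffer allocation failure, or CUDA errors.
    return err;
  }
  return Error::Ok;
}
```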

4 files changed: +692 -3 lines changed


backends/cuda/runtime/shims/memory.cpp

Lines changed: 268 additions & 3 deletions
@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
+namespace {
+
+// Calculate linear offset from strides and indices
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
   int64_t nbytes = numel * element_size;
 
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
@@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  (void)non_blocking;
+
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check if tensors have the same schema (sizes, strides, dtype) for fast path
+  bool same_schema = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_schema = false;
+      break;
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
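The fallback path above relies on `linear_to_indices` and `calculate_linear_offset` to pair up elements of two tensors that share a shape but not a memory layout: each linear index is decomposed into per-dimension indices against the tensor's sizes, then folded back into a storage offset through that tensor's own strides. The following standalone sketch (not part of the commit, just the same arithmetic) runs a 2x3 copy from a column-major source into a row-major destination:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Same arithmetic as linear_to_indices / calculate_linear_offset above.
int main() {
  const int64_t sizes[2] = {2, 3};
  const int64_t dst_strides[2] = {3, 1}; // row-major (contiguous)
  const int64_t src_strides[2] = {1, 2}; // column-major layout, same shape

  std::vector<float> src = {0, 1, 2, 3, 4, 5}; // column-major storage
  std::vector<float> dst(6, 0.0f);

  for (int64_t linear = 0; linear < 6; ++linear) {
    // linear index -> (i, j) under the logical sizes
    int64_t idx[2], rem = linear;
    for (int64_t d = 1; d >= 0; --d) {
      idx[d] = rem % sizes[d];
      rem /= sizes[d];
    }
    // (i, j) -> storage offsets via each tensor's own strides
    int64_t src_off = idx[0] * src_strides[0] + idx[1] * src_strides[1];
    int64_t dst_off = idx[0] * dst_strides[0] + idx[1] * dst_strides[1];
    dst[dst_off] = src[src_off];
  }

  // dst now holds the row-major rendering of the column-major source:
  // row 0 -> 0 2 4, row 1 -> 1 3 5
  for (float v : dst) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}
```

Because each side uses its own strides, the destination receives the logical values in its own layout, which is what lets the shim copy element by element between differently strided tensors.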

backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
@@ -116,6 +116,31 @@ AOTITorchError aoti_torch__reinterpret_tensor(
     int64_t storage_offset,
     Tensor** ret_new_tensor);
 
+/**
+ * Copies data from source tensor to destination tensor.
+ *
+ * This function implements copy function for tensors living in CUDA AOTI
+ * backend. It supports copying between tensors with different shapes (as long
+ * as they have the same total number of elements) and different memory
+ * layouts/strides.
+ *
+ * Note that currently this function does not support copying between tensors
+ * with different dtypes.
+ *
+ * @param self Destination tensor (data will be overwritten)
+ * @param src Source tensor (data will be copied from this tensor)
+ * @param non_blocking Whether the copy should be non-blocking (currently
+ * ignored)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers, dtype mismatch, numel
+ *           mismatch
+ *         - Error::MemoryAllocationFailed: failed to allocate temporary memory
+ *         - Error::Internal: CUDA operation failures
+ */
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
+
 // Function to clear all tensors from internal storage
 void clear_all_tensors();
 } // extern "C"
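As the documentation above notes, only the dtype and the element count have to match, so a copy can also serve as a reshape-style fill. A hypothetical sketch (`matrix_dst` and `flat_src` are illustrative names for tensors created elsewhere with the same dtype):

```cpp
// flat_src: 1-D tensor with 6 elements; matrix_dst: 2x3 tensor, same dtype.
// Elements are paired by linear index, so matrix_dst receives the six values
// of flat_src laid out according to matrix_dst's own sizes and strides.
AOTITorchError fill_matrix_from_flat(Tensor* matrix_dst, Tensor* flat_src) {
  return aoti_torch_copy_(matrix_dst, flat_src, /*non_blocking=*/0);
}
```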

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -31,3 +31,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
     cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
     cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
+    cuda_shim_cpp_unittest("aoti_torch_copy_")
