
Commit e684bb0

Gasoonjia authored and facebook-github-bot committed
aoti_torch_copy_
Summary: This diff introduces `aoti_torch_copy_`, the function for copying tensors inside the CUDA backend. Right now it only supports copies between tensors with the same dtype. Differential Revision: D83094604
1 parent e8a05e4 commit e684bb0

4 files changed: +690 -3 lines

backends/cuda/runtime/shims/memory.cpp

Lines changed: 266 additions & 3 deletions
@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::aoti_torch_get_device_index;
 using executorch::backends::aoti::aoti_torch_get_dtype;
+using executorch::backends::aoti::aoti_torch_get_sizes;
+using executorch::backends::aoti::aoti_torch_get_strides;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
 constexpr int32_t NOT_OWN = -1;
 std::unordered_map<void*, int32_t> memory_to_n_tensor;
 
+namespace {
+
+// Calculate linear offset from strides and indices
+int64_t calculate_linear_offset(
+    const int64_t* indices,
+    const int64_t* strides,
+    int64_t ndim) {
+  int64_t offset = 0;
+  for (int64_t i = 0; i < ndim; ++i) {
+    offset += indices[i] * strides[i];
+  }
+  return offset;
+}
+
+// Convert linear index to multi-dimensional indices based on sizes
+void linear_to_indices(
+    int64_t linear_idx,
+    const int64_t* sizes,
+    int64_t ndim,
+    int64_t* indices) {
+  for (int64_t i = ndim - 1; i >= 0; --i) {
+    indices[i] = linear_idx % sizes[i];
+    linear_idx /= sizes[i];
+  }
+}
+
+// Generic pointwise copy function that handles arbitrary strides
+template <typename T>
+AOTITorchError pointwise_copy_generic(
+    T* dst_data,
+    const T* src_data,
+    const int64_t* dst_sizes,
+    const int64_t* dst_strides,
+    const int64_t* src_sizes,
+    const int64_t* src_strides,
+    int64_t dst_ndim,
+    int64_t src_ndim,
+    int64_t total_elements) {
+  std::vector<int64_t> dst_indices(dst_ndim);
+  std::vector<int64_t> src_indices(src_ndim);
+
+  for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
+    // Convert linear index to multi-dimensional indices for both tensors
+    linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
+    linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());
+
+    // Calculate offsets for both source and destination
+    int64_t src_offset =
+        calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
+    int64_t dst_offset =
+        calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);
+
+    // Copy element
+    dst_data[dst_offset] = src_data[src_offset];
+  }
+
+  return Error::Ok;
+}
+
+} // anonymous namespace
+
 extern "C" {
 
 AOTITorchError aoti_torch_create_tensor_from_blob_v2(
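
The two helpers above implement the standard unravel-index / strided-offset arithmetic: a flat iteration index is first decomposed into per-dimension indices against the sizes, and those indices are then dotted with the strides to get the element's offset in storage. Below is a minimal standalone sketch of the same arithmetic (not part of this commit; the toy 2x3 shapes and the main driver are illustrative only), copying a row-major buffer into a destination whose strides describe a column-major layout:

#include <cstdint>
#include <cstdio>
#include <vector>

// Same arithmetic as calculate_linear_offset above: dot indices with strides.
int64_t offset_of(const int64_t* idx, const int64_t* strides, int64_t ndim) {
  int64_t off = 0;
  for (int64_t i = 0; i < ndim; ++i) {
    off += idx[i] * strides[i];
  }
  return off;
}

// Same arithmetic as linear_to_indices above: peel off the innermost dim first.
void unravel(int64_t linear, const int64_t* sizes, int64_t ndim, int64_t* idx) {
  for (int64_t i = ndim - 1; i >= 0; --i) {
    idx[i] = linear % sizes[i];
    linear /= sizes[i];
  }
}

int main() {
  // Logical 2x3 tensor; the source is row-major (strides {3, 1}), while the
  // destination uses strides {1, 2}, i.e. the layout of a column-major view.
  const int64_t sizes[2] = {2, 3};
  const int64_t src_strides[2] = {3, 1};
  const int64_t dst_strides[2] = {1, 2};
  std::vector<float> src = {0, 1, 2, 3, 4, 5};
  std::vector<float> dst(6, -1.0f);

  int64_t idx[2];
  for (int64_t linear = 0; linear < 6; ++linear) {
    unravel(linear, sizes, 2, idx);
    dst[offset_of(idx, dst_strides, 2)] = src[offset_of(idx, src_strides, 2)];
  }

  // Prints "0 3 1 4 2 5": the same logical values, stored column-major.
  for (float v : dst) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}

This is exactly what pointwise_copy_generic does per element, except that it derives destination and source indices from their own (possibly different) sizes.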
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
   }
   int64_t nbytes = numel * element_size;
 
-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
     // Ensure 16-byte alignment for CPU memory to match CUDA requirements
     int result = posix_memalign(&ptr, 16, nbytes);
     if (result != 0) {
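
The removed branch compared device_type against the raw literals 1 (CUDA) and 0 (CPU); the new code routes the comparison through a SupportedDevices enum that is defined elsewhere in the CUDA shim and does not appear in this diff. A hedged sketch of what such a mapping plausibly looks like, only to make the static_cast readable (the actual definition may differ):

// Illustrative only: mirrors the literals removed above (CPU == 0, CUDA == 1).
// The real enum lives in the CUDA shim headers, outside this diff.
enum class SupportedDevices : int32_t {
  CPU = 0,
  CUDA = 1,
};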
@@ -312,6 +376,205 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
+  // Check for null pointers first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (src == nullptr) {
+    ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  AOTITorchError self_dtype_error = validate_dtype(self_dtype);
+  if (self_dtype_error != Error::Ok) {
+    return self_dtype_error;
+  }
+
+  AOTITorchError src_dtype_error = validate_dtype(src_dtype);
+  if (src_dtype_error != Error::Ok) {
+    return src_dtype_error;
+  }
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  if (self_dtype != src_dtype) {
+    ET_LOG(
+        Error,
+        "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
+        self_dtype,
+        src_dtype);
+    return Error::InvalidArgument;
+  }
+
+  // Check total number of elements compatibility (PyTorch copy_ behavior)
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  if (self_numel != src_numel) {
+    ET_LOG(
+        Error,
+        "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+        self_numel,
+        src_numel);
+    return Error::InvalidArgument;
+  }
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  int64_t* self_sizes;
+  int64_t* src_sizes;
+  aoti_torch_get_sizes(self, &self_sizes);
+  aoti_torch_get_sizes(src, &src_sizes);
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  // Check if tensors have the same schema (sizes, strides, dtype) for fast path
+  bool same_schema = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_schema = false;
+      break;
+    }
+  }
+
+  size_t total_bytes = src->nbytes();
+  int64_t total_elements = self->numel();
+
+  if (same_schema) {
+    // Fast path: Direct memory copy since layouts match exactly
+    if (srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToDevice));
+    } else if (srcIsDevice && !dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyDeviceToHost));
+    } else if (!srcIsDevice && dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          src->data_ptr(),
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    } else {
+      std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+    }
+  } else {
+    // Fallback path: Pointwise copy with stride-aware indexing
+    // This handles arbitrary tensor layouts and strides
+
+    size_t element_size = dtype_to_element_size(self_dtype);
+    if (element_size == 0) {
+      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
+      return Error::InvalidArgument;
+    }
+
+    // Allocate temporary host memory for GPU tensors
+    float* src_host_data = nullptr;
+    float* dst_host_data = nullptr;
+    bool need_free_src = false;
+    bool need_free_dst = false;
+
+    if (srcIsDevice) {
+      src_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (src_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for src_host_data");
+        return Error::MemoryAllocationFailed;
+      }
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
+      need_free_src = true;
+    } else {
+      src_host_data = static_cast<float*>(src->data_ptr());
+    }
+
+    if (dstIsDevice) {
+      dst_host_data =
+          static_cast<float*>(malloc(total_elements * sizeof(float)));
+      if (dst_host_data == nullptr) {
+        ET_LOG(Error, "Failed to allocate memory for dst_host_data");
+        if (need_free_src) {
+          free(src_host_data);
+        }
+        return Error::MemoryAllocationFailed;
+      }
+      need_free_dst = true;
+    } else {
+      dst_host_data = static_cast<float*>(self->mutable_data_ptr());
+    }
+
+    // Perform pointwise copy with stride calculation
+    AOTITorchError copy_err = pointwise_copy_generic(
+        dst_host_data,
+        src_host_data,
+        self_sizes,
+        self_strides,
+        src_sizes,
+        src_strides,
+        self->dim(),
+        src->dim(),
+        total_elements);
+
+    if (copy_err != Error::Ok) {
+      // Clean up temporary buffers before returning
+      if (need_free_src) {
+        free(src_host_data);
+      }
+      if (need_free_dst) {
+        free(dst_host_data);
+      }
+      return copy_err;
+    }
+
+    // Copy result back to device if needed
+    if (dstIsDevice) {
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
+          self->mutable_data_ptr(),
+          dst_host_data,
+          total_bytes,
+          cudaMemcpyHostToDevice));
+    }
+
+    // Clean up temporary buffers
+    if (need_free_src) {
+      free(src_host_data);
+    }
+    if (need_free_dst) {
+      free(dst_host_data);
+    }
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
@@ -116,6 +116,31 @@ AOTITorchError aoti_torch__reinterpret_tensor(
     int64_t storage_offset,
     Tensor** ret_new_tensor);
 
+/**
+ * Copies data from the source tensor to the destination tensor.
+ *
+ * This function implements the copy function for tensors living in the CUDA
+ * AOTI backend. It supports copying between tensors with different shapes (as
+ * long as they have the same total number of elements) and different memory
+ * layouts/strides.
+ *
+ * Note that currently this function does not support copying between tensors
+ * with different dtypes.
+ *
+ * @param self Destination tensor (data will be overwritten)
+ * @param src Source tensor (data will be copied from this tensor)
+ * @param non_blocking Whether the copy should be non-blocking (currently
+ *     ignored)
+ *
+ * @return Error::Ok on success, appropriate error code on failure:
+ *         - Error::InvalidArgument: null pointers, dtype mismatch, numel
+ *           mismatch
+ *         - Error::MemoryAllocationFailed: failed to allocate temporary memory
+ *         - Error::Internal: CUDA operation failures
+ */
+AOTITorchError
+aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
+
 // Function to clear all tensors from internal storage
 void clear_all_tensors();
 } // extern "C"
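
For callers, the new entry point behaves like an in-place copy_. A hedged usage sketch follows (not part of this commit; it assumes dst and src are valid Tensor* handles created through the existing shims, that the header is reachable at the include path shown, and that namespace qualifications are handled by the surrounding code):

#include <executorch/backends/cuda/runtime/shims/memory.h>

// Copy src into dst, propagating the shim's error code to the caller.
AOTITorchError copy_into(Tensor* dst, Tensor* src) {
  // Third argument is non_blocking; per the doc comment it is currently ignored.
  return aoti_torch_copy_(dst, src, /*non_blocking=*/0);
}

Per the documentation above, a non-Ok return means invalid arguments (null tensors, dtype or numel mismatch), a failed temporary allocation, or a CUDA failure surfaced as Error::Internal.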

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -31,3 +31,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
     cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
     cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
+    cuda_shim_cpp_unittest("aoti_torch_copy_")
