271 changes: 268 additions & 3 deletions backends/cuda/runtime/shims/memory.cpp
@@ -27,6 +27,8 @@ using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::aoti_torch_get_device_index;
using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::aoti_torch_get_sizes;
using executorch::backends::aoti::aoti_torch_get_strides;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;
@@ -40,6 +42,67 @@ std::unordered_set<std::shared_ptr<Tensor>> tensors;
constexpr int32_t NOT_OWN = -1;
std::unordered_map<void*, int32_t> memory_to_n_tensor;

namespace {

// Calculate linear offset from strides and indices
int64_t calculate_linear_offset(
const int64_t* indices,
const int64_t* strides,
int64_t ndim) {
int64_t offset = 0;
for (int64_t i = 0; i < ndim; ++i) {
offset += indices[i] * strides[i];
}
return offset;
}

// Convert linear index to multi-dimensional indices based on sizes
void linear_to_indices(
int64_t linear_idx,
const int64_t* sizes,
int64_t ndim,
int64_t* indices) {
for (int64_t i = ndim - 1; i >= 0; --i) {
indices[i] = linear_idx % sizes[i];
linear_idx /= sizes[i];
}
}

// Generic pointwise copy function that handles arbitrary strides
template <typename T>
AOTITorchError pointwise_copy_generic(
T* dst_data,
const T* src_data,
const int64_t* dst_sizes,
const int64_t* dst_strides,
const int64_t* src_sizes,
const int64_t* src_strides,
int64_t dst_ndim,
int64_t src_ndim,
int64_t total_elements) {
std::vector<int64_t> dst_indices(dst_ndim);
std::vector<int64_t> src_indices(src_ndim);

for (int64_t linear_idx = 0; linear_idx < total_elements; ++linear_idx) {
// Convert linear index to multi-dimensional indices for both tensors
linear_to_indices(linear_idx, dst_sizes, dst_ndim, dst_indices.data());
linear_to_indices(linear_idx, src_sizes, src_ndim, src_indices.data());

// Calculate offsets for both source and destination
int64_t src_offset =
calculate_linear_offset(src_indices.data(), src_strides, src_ndim);
int64_t dst_offset =
calculate_linear_offset(dst_indices.data(), dst_strides, dst_ndim);

// Copy element
dst_data[dst_offset] = src_data[src_offset];
}

return Error::Ok;
}

} // anonymous namespace
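
To make the stride arithmetic concrete, here is a minimal sketch (illustration only, not part of the change) of how the two index helpers cooperate when compiled in this translation unit: a 2x3 row-major buffer is rewritten into a column-major layout of the same shape.

// Hypothetical example: copy a 2x3 row-major buffer into a column-major one.
static void example_strided_copy() {
  const int64_t sizes[2] = {2, 3};
  const int64_t src_strides[2] = {3, 1}; // row-major layout
  const int64_t dst_strides[2] = {1, 2}; // column-major layout
  const float src[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
  float dst[6] = {};

  int64_t indices[2];
  for (int64_t linear = 0; linear < 6; ++linear) {
    linear_to_indices(linear, sizes, 2, indices);
    dst[calculate_linear_offset(indices, dst_strides, 2)] =
        src[calculate_linear_offset(indices, src_strides, 2)];
  }
  // dst now holds {0, 3, 1, 4, 2, 5}: element (i, j) moved from offset
  // i*3 + j to offset i + j*2, which is exactly the per-element relocation
  // pointwise_copy_generic performs.
}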

extern "C" {

AOTITorchError aoti_torch_create_tensor_from_blob_v2(
@@ -178,9 +241,10 @@ AOTITorchError aoti_torch_empty_strided(
}
int64_t nbytes = numel * element_size;

-  if (device_type == 1) { // cuda
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, nbytes));
-  } else if (device_type == 0) { // cpu
+  if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(
+        cudaMallocManaged(&ptr, static_cast<size_t>(nbytes)));
+  } else if (device_type == static_cast<int32_t>(SupportedDevices::CPU)) {
// Ensure 16-byte alignment for CPU memory to match CUDA requirements
int result = posix_memalign(&ptr, 16, nbytes);
if (result != 0) {
@@ -312,6 +376,207 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
return Error::Internal;
}

AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
(void)non_blocking;

// Check for null pointers first
if (self == nullptr) {
ET_LOG(Error, "aoti_torch_copy_ failed: self tensor is null");
return Error::InvalidArgument;
}

if (src == nullptr) {
ET_LOG(Error, "aoti_torch_copy_ failed: src tensor is null");
return Error::InvalidArgument;
}

// Get dtype information and validate compatibility
int32_t self_dtype, src_dtype;
aoti_torch_get_dtype(self, &self_dtype);
aoti_torch_get_dtype(src, &src_dtype);

AOTITorchError self_dtype_error = validate_dtype(self_dtype);
if (self_dtype_error != Error::Ok) {
return self_dtype_error;
}

AOTITorchError src_dtype_error = validate_dtype(src_dtype);
if (src_dtype_error != Error::Ok) {
return src_dtype_error;
}

// Check dtype compatibility - both tensors must have the same dtype
if (self_dtype != src_dtype) {
ET_LOG(
Error,
"dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_ requires same dtypes",
self_dtype,
src_dtype);
return Error::InvalidArgument;
}

// Check total number of elements compatibility (PyTorch copy_ behavior)
int64_t self_numel = self->numel();
int64_t src_numel = src->numel();

if (self_numel != src_numel) {
ET_LOG(
Error,
"numel mismatch. self.numel()=%ld, src.numel()=%ld",
self_numel,
src_numel);
return Error::InvalidArgument;
}

// Get tensor metadata
int64_t* self_strides;
int64_t* src_strides;
aoti_torch_get_strides(self, &self_strides);
aoti_torch_get_strides(src, &src_strides);

int64_t* self_sizes;
int64_t* src_sizes;
aoti_torch_get_sizes(self, &self_sizes);
aoti_torch_get_sizes(src, &src_sizes);

// Determine device locations
cudaPointerAttributes srcAttributes{};
cudaPointerAttributes dstAttributes{};

ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));

ET_CUDA_CHECK_OR_RETURN_ERROR(
cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));

bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;

  // Check whether the tensors share the same memory layout (same ndim, sizes,
  // and strides); dtype equality was verified above. Matching layouts allow a
  // single bulk copy on the fast path below.
  bool same_schema = self->dim() == src->dim();
  if (same_schema) {
    for (int i = 0; i < self->dim(); i++) {
      if (self_sizes[i] != src_sizes[i] || self_strides[i] != src_strides[i]) {
        same_schema = false;
        break;
      }
    }
  }

size_t total_bytes = src->nbytes();
int64_t total_elements = self->numel();

if (same_schema) {
// Fast path: Direct memory copy since layouts match exactly
if (srcIsDevice && dstIsDevice) {
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
self->mutable_data_ptr(),
src->data_ptr(),
total_bytes,
cudaMemcpyDeviceToDevice));
} else if (srcIsDevice && !dstIsDevice) {
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
self->mutable_data_ptr(),
src->data_ptr(),
total_bytes,
cudaMemcpyDeviceToHost));
} else if (!srcIsDevice && dstIsDevice) {
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
self->mutable_data_ptr(),
src->data_ptr(),
total_bytes,
cudaMemcpyHostToDevice));
} else {
std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
}
} else {
// Fallback path: Pointwise copy with stride-aware indexing
// This handles arbitrary tensor layouts and strides

    size_t element_size = dtype_to_element_size(self_dtype);
    if (element_size == 0) {
      ET_LOG(Error, "Invalid element size for dtype: %d", self_dtype);
      return Error::InvalidArgument;
    }

    // The pointwise fallback below is instantiated for float, so it can only
    // relocate 4-byte elements. Reject other element sizes explicitly rather
    // than reinterpreting their bytes incorrectly.
    if (element_size != sizeof(float)) {
      ET_LOG(
          Error,
          "Pointwise copy fallback only supports 4-byte elements; dtype: %d",
          self_dtype);
      return Error::NotSupported;
    }

    // Stage device-resident data in temporary host buffers, sized to match
    // the cudaMemcpy calls that fill and drain them.
    float* src_host_data = nullptr;
    float* dst_host_data = nullptr;
    bool need_free_src = false;
    bool need_free_dst = false;

    if (srcIsDevice) {
      src_host_data = static_cast<float*>(malloc(total_bytes));
if (src_host_data == nullptr) {
ET_LOG(Error, "Failed to allocate memory for src_host_data");
return Error::MemoryAllocationFailed;
}
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
src_host_data, src->data_ptr(), total_bytes, cudaMemcpyDeviceToHost));
need_free_src = true;
} else {
src_host_data = static_cast<float*>(src->data_ptr());
}

    if (dstIsDevice) {
      dst_host_data = static_cast<float*>(malloc(total_bytes));
if (dst_host_data == nullptr) {
ET_LOG(Error, "Failed to allocate memory for dst_host_data");
if (need_free_src) {
free(src_host_data);
}
return Error::MemoryAllocationFailed;
}
need_free_dst = true;
} else {
dst_host_data = static_cast<float*>(self->mutable_data_ptr());
}

// Perform pointwise copy with stride calculation
AOTITorchError copy_err = pointwise_copy_generic(
dst_host_data,
src_host_data,
self_sizes,
self_strides,
src_sizes,
src_strides,
self->dim(),
src->dim(),
total_elements);

if (copy_err != Error::Ok) {
// Clean up temporary buffers before returning
if (need_free_src) {
free(src_host_data);
}
if (need_free_dst) {
free(dst_host_data);
}
return copy_err;
}

// Copy result back to device if needed
if (dstIsDevice) {
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpy(
self->mutable_data_ptr(),
dst_host_data,
total_bytes,
cudaMemcpyHostToDevice));
}

// Clean up temporary buffers
if (need_free_src) {
free(src_host_data);
}
if (need_free_dst) {
free(dst_host_data);
}
}

return Error::Ok;
}

AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
25 changes: 25 additions & 0 deletions backends/cuda/runtime/shims/memory.h
@@ -116,6 +116,31 @@ AOTITorchError aoti_torch__reinterpret_tensor(
int64_t storage_offset,
Tensor** ret_new_tensor);

/**
* Copies data from source tensor to destination tensor.
*
 * This function implements the copy operation for tensors managed by the
 * CUDA AOTI backend. It supports copying between tensors with different
 * shapes (as long as they have the same total number of elements) and
 * different memory layouts/strides.
 *
 * Note that this function does not currently support copying between
 * tensors with different dtypes.
*
* @param self Destination tensor (data will be overwritten)
* @param src Source tensor (data will be copied from this tensor)
* @param non_blocking Whether the copy should be non-blocking (currently
* ignored)
*
* @return Error::Ok on success, appropriate error code on failure:
* - Error::InvalidArgument: null pointers, dtype mismatch, numel
* mismatch
* - Error::MemoryAllocationFailed: failed to allocate temporary memory
* - Error::Internal: CUDA operation failures
*/
AOTITorchError
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
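
A minimal usage sketch (hypothetical; make_tensor stands in for whatever tensor-creation shim the caller uses, such as aoti_torch_empty_strided):

// Hypothetical setup: two tensors with the same dtype and numel (6).
Tensor* dst = make_tensor({2, 3}); // hypothetical helper
Tensor* src = make_tensor({3, 2}); // different shape, same element count

// Overwrite dst's elements with src's; non_blocking is currently ignored.
AOTITorchError err = aoti_torch_copy_(dst, src, /*non_blocking=*/0);
if (err != Error::Ok) {
  // Handle InvalidArgument / MemoryAllocationFailed / Internal.
}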

// Function to clear all tensors from internal storage
void clear_all_tensors();
} // extern "C"
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
@@ -31,3 +31,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
cuda_shim_cpp_unittest("aoti_torch_copy_")