117 changes: 117 additions & 0 deletions backends/cuda/runtime/shims/memory.cpp
@@ -25,6 +25,8 @@ namespace cuda {

using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::aoti_torch_get_device_index;
using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
return Error::Internal;
}

AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
Tensor** ret_new_tensor) {
// Validate input parameters first
if (self == nullptr) {
ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
return Error::InvalidArgument;
}

if (sizes_ptr == nullptr && ndim > 0) {
ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
return Error::InvalidArgument;
}

if (ret_new_tensor == nullptr) {
ET_LOG(
Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
return Error::InvalidArgument;
}

// Validate that storage_offset is 0; return an error for any other value
AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
if (storage_offset_error != Error::Ok) {
return storage_offset_error;
}

// Get the device info from the source tensor to perform device_index
// validation
int32_t device_type = 0;
int32_t device_index = 0;
AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
if (device_error != Error::Ok) {
return device_error;
}

device_error = aoti_torch_get_device_index(self, &device_index);
if (device_error != Error::Ok) {
return device_error;
}

// Ensure device_index is always 0
if (device_index != 0) {
ET_LOG(Error, "device_index must be 0, got: %d", device_index);
return Error::InvalidArgument;
}

// Get the dtype from the source tensor
int32_t dtype = 0;
AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
if (dtype_error != Error::Ok) {
return dtype_error;
}

// Validate dtype using SupportedDTypes
dtype_error = validate_dtype(dtype);
if (dtype_error != Error::Ok) {
return dtype_error;
}

// Get the original data pointer from the source tensor
void* data_ptr = self->mutable_data_ptr();
if (data_ptr == nullptr) {
ET_LOG(Error, "Source tensor has null data pointer");
return Error::InvalidArgument;
}

// Check that the given memory is tracked in the map; return an error if not
auto memory_it = memory_to_n_tensor.find(data_ptr);
if (memory_it == memory_to_n_tensor.end()) {
ET_LOG(
Error,
"Memory address %p is not being tracked by reference counting system",
data_ptr);
return Error::InvalidArgument;
}

// Convert sizes using utility function from utils.h
std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);

// Convert strides using utility function from utils.h
std::vector<StridesType> strides =
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

// Create a new tensor view that reinterprets the same memory with different
// shape/strides. This creates a view, not a copy: the data pointer is shared
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
data_ptr, // Reuse the same memory from source tensor
sizes, // New sizes with explicit SizesType
strides, // New strides with explicit StridesType
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
);

if (!tensor) {
ET_LOG(Error, "Failed to create reinterpreted tensor view");
return Error::InvalidArgument;
}

// Store the tensor so it doesn't get destroyed
tensors.insert(tensor);

*ret_new_tensor = tensor.get();

// Increment the reference count for this memory address only if the memory
// is owned by the runtime; NOT_OWN (user-owned) memory is never counted
memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
? NOT_OWN
: memory_to_n_tensor[data_ptr] + 1;

return Error::Ok;
}
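
Reviewer note: the reference-count update at the end of the function can be read as the following equivalent form. This is a restatement for clarity, not part of the patch; `memory_to_n_tensor`, `data_ptr`, and `NOT_OWN` are the names used above.

```cpp
// Equivalent reading of the ternary update above: memory marked NOT_OWN
// (borrowed from the user) is never reference-counted; memory owned by the
// runtime gains one more tensor view referencing it.
if (memory_to_n_tensor[data_ptr] != NOT_OWN) {
  memory_to_n_tensor[data_ptr] += 1;
}
```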

} // extern "C"

} // namespace cuda
25 changes: 25 additions & 0 deletions backends/cuda/runtime/shims/memory.h
@@ -91,6 +91,31 @@ AOTITorchError aoti_torch_empty_strided(
*/
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);

/**
* Creates a tensor view that reinterprets the same underlying memory with
* different shape and strides without copying data.
*
* Note that the new tensor does not take ownership of the underlying
* memory.
*
* @param self Input tensor whose memory will be reinterpreted
* @param ndim Number of dimensions for the new tensor view
* @param sizes_ptr Array of sizes for each dimension
* @param strides_ptr Array of strides for each dimension (or nullptr for
* contiguous)
* @param storage_offset Storage offset (must be 0)
* @param ret_new_tensor Output pointer to store the new tensor view
*
* @return Error::Ok on success, appropriate error code on failure
*/
AOTITorchError aoti_torch__reinterpret_tensor(
Tensor* self,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
Tensor** ret_new_tensor);
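
Reviewer note: a minimal caller-side sketch of the shim declared above. The names `src` and `view` and the surrounding setup are assumptions for illustration; in practice the caller is AOTInductor-generated code.

```cpp
// Assumes `src` is a valid Tensor* with at least 6 contiguous float elements,
// already tracked by the runtime's reference-counting map.
int64_t sizes[] = {2, 3};
int64_t strides[] = {3, 1}; // row-major; nullptr also works for contiguous
Tensor* view = nullptr;

AOTITorchError err = aoti_torch__reinterpret_tensor(
    src, /*ndim=*/2, sizes, strides, /*storage_offset=*/0, &view);
if (err != Error::Ok) {
  // handle the error
}
// `view` shares src's data pointer and does not own the memory, so it must
// not be used after the original allocation is freed.
```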

// Function to clear all tensors from internal storage
void clear_all_tensors();
} // extern "C"
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Expand Up @@ -30,3 +30,4 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")