Skip to content

Commit b0b94fe

Browse files
Gasoonjiafacebook-github-bot
authored andcommitted
aoti_torch__reinterpret_tensor (#14614)
Summary: Pull Request resolved: #14614 Introduced aoti_torch__reinterpret_tensor, which creates a new tensor view that reinterprets the same underlying memory with custom shape and strides. Differential Revision: D83094603
1 parent 6a00b66 commit b0b94fe

File tree

4 files changed

+945
-0
lines changed

4 files changed

+945
-0
lines changed

backends/cuda/runtime/shims/memory.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ namespace cuda {
2525

2626
using executorch::aten::SizesType;
2727
using executorch::aten::StridesType;
28+
using executorch::backends::aoti::aoti_torch_get_device_index;
29+
using executorch::backends::aoti::aoti_torch_get_dtype;
2830
using executorch::backends::aoti::dtype_to_element_size;
2931
using executorch::backends::aoti::dtype_to_scalar_type;
3032
using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,115 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
310312
return Error::Internal;
311313
}
312314

315+
AOTITorchError aoti_torch__reinterpret_tensor(
316+
Tensor* self,
317+
int64_t ndim,
318+
const int64_t* sizes_ptr,
319+
const int64_t* strides_ptr,
320+
int64_t storage_offset,
321+
Tensor** ret_new_tensor) {
322+
// Validate input parameters first
323+
if (self == nullptr) {
324+
ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
325+
return Error::InvalidArgument;
326+
}
327+
328+
if (sizes_ptr == nullptr && ndim > 0) {
329+
ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
330+
return Error::InvalidArgument;
331+
}
332+
333+
if (ret_new_tensor == nullptr) {
334+
ET_LOG(
335+
Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
336+
return Error::InvalidArgument;
337+
}
338+
339+
// Check if storage_offset is not 0 - return error if not
340+
AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
341+
if (storage_offset_error != Error::Ok) {
342+
return storage_offset_error;
343+
}
344+
345+
// Get the device info from the source tensor to perform device_index
346+
// validation
347+
int32_t device_type = 0;
348+
int32_t device_index = 0;
349+
AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
350+
if (device_error != Error::Ok) {
351+
return device_error;
352+
}
353+
354+
device_error = aoti_torch_get_device_index(self, &device_index);
355+
if (device_error != Error::Ok) {
356+
return device_error;
357+
}
358+
359+
// Ensure device_index is always 0
360+
if (device_index != 0) {
361+
ET_LOG(Error, "device_index must be 0, got: %d", device_index);
362+
return Error::InvalidArgument;
363+
}
364+
365+
// Get the dtype from the source tensor
366+
int32_t dtype = 0;
367+
AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
368+
if (dtype_error != Error::Ok) {
369+
return dtype_error;
370+
}
371+
372+
// Validate dtype using SupportedDTypes
373+
dtype_error = validate_dtype(dtype);
374+
if (dtype_error != Error::Ok) {
375+
return dtype_error;
376+
}
377+
378+
// Get the original data pointer from the source tensor
379+
void* data_ptr = self->mutable_data_ptr();
380+
if (data_ptr == nullptr) {
381+
ET_LOG(Error, "Source tensor has null data pointer");
382+
return Error::InvalidArgument;
383+
}
384+
385+
// Check if the given memory is in the map, if not return error
386+
auto memory_it = memory_to_n_tensor.find(data_ptr);
387+
if (memory_it == memory_to_n_tensor.end()) {
388+
ET_LOG(Error, "Memory address %p is not being tracked by reference counting system", data_ptr);
389+
return Error::InvalidArgument;
390+
}
391+
392+
// Convert sizes using utility function from utils.h
393+
std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
394+
395+
// Convert strides using utility function from utils.h
396+
std::vector<StridesType> strides =
397+
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
398+
399+
// Create new tensor view that reinterprets the same memory with different
400+
// shape/strides This creates a view, not a copy - the data pointer is shared
401+
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
402+
data_ptr, // Reuse the same memory from source tensor
403+
sizes, // New sizes with explicit SizesType
404+
strides, // New strides with explicit StridesType
405+
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
406+
);
407+
408+
if (!tensor) {
409+
ET_LOG(Error, "Failed to create reinterpreted tensor view");
410+
return Error::InvalidArgument;
411+
}
412+
413+
// Store the tensor so it doesn't get destroyed
414+
tensors.insert(tensor);
415+
416+
*ret_new_tensor = tensor.get();
417+
418+
// Increment the reference count for this memory address only if it is owned by tensor
419+
memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN ? NOT_OWN : memory_to_n_tensor[data_ptr] + 1;
420+
421+
return Error::Ok;
422+
}
423+
313424
} // extern "C"
314425

315426
} // namespace cuda

backends/cuda/runtime/shims/memory.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,30 @@ AOTITorchError aoti_torch_empty_strided(
9191
*/
9292
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
9393

94+
/**
95+
* Creates a tensor view that reinterprets the same underlying memory with
96+
* different shape and strides without copying data.
97+
*
98+
* Note that the new tensor will not have the ownership of the underlying memory.
99+
*
100+
* @param self Input tensor whose memory will be reinterpreted
101+
* @param ndim Number of dimensions for the new tensor view
102+
* @param sizes_ptr Array of sizes for each dimension
103+
* @param strides_ptr Array of strides for each dimension (or nullptr for
104+
* contiguous)
105+
* @param storage_offset Storage offset (must be 0)
106+
* @param ret_new_tensor Output pointer to store the new tensor view
107+
*
108+
* @return Error::Ok on success, appropriate error code on failure
109+
*/
110+
AOTITorchError aoti_torch__reinterpret_tensor(
111+
Tensor* self,
112+
int64_t ndim,
113+
const int64_t* sizes_ptr,
114+
const int64_t* strides_ptr,
115+
int64_t storage_offset,
116+
Tensor** ret_new_tensor);
117+
94118
// Function to clear all tensors from internal storage
95119
void clear_all_tensors();
96120
} // extern "C"

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ def define_common_targets():
3030
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
3131
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
3232
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
33+
cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")

0 commit comments

Comments
 (0)