@@ -25,6 +25,8 @@ namespace cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_get_device_index;
+using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError aoti_torch__reinterpret_tensor(
+    Tensor* self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    Tensor** ret_new_tensor) {
+  // Validate input parameters first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Reject a non-zero storage_offset; only an offset of 0 is supported
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  device_error = aoti_torch_get_device_index(self, &device_index);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  // Ensure device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Get the dtype from the source tensor
+  int32_t dtype = 0;
+  AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Validate dtype using SupportedDTypes
+  dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  if (data_ptr == nullptr) {
+    ET_LOG(Error, "Source tensor has null data pointer");
+    return Error::InvalidArgument;
+  }
+
+  // Check that the given memory is in the map; if not, return an error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  if (memory_it == memory_to_n_tensor.end()) {
+    ET_LOG(
+        Error,
+        "Memory address %p is not being tracked by reference counting system",
+        data_ptr);
+    return Error::InvalidArgument;
+  }
+
+  // Convert sizes using utility function from utils.h
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using utility function from utils.h
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create a new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy - the data pointer is shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create reinterpreted tensor view");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by a tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
 } // extern "C"
 
 } // namespace cuda
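
As a sanity check on the new entry point, the sketch below shows one way it could be called. It is not part of the change above: `reinterpret_as_3x4` is a hypothetical helper name, and it assumes `base` is a 12-element tensor that was created through this same shim, so its data pointer is already tracked in `memory_to_n_tensor`.

```cpp
// Hypothetical usage sketch: view a 12-element tensor owned by this shim
// as a {3, 4} row-major tensor without copying. `base` is assumed to have
// been created via this shim so its data pointer is reference-counted.
AOTITorchError reinterpret_as_3x4(Tensor* base, Tensor** out) {
  const int64_t sizes[] = {3, 4};
  const int64_t strides[] = {4, 1}; // contiguous strides for a {3, 4} view
  AOTITorchError err = aoti_torch__reinterpret_tensor(
      base, /*ndim=*/2, sizes, strides, /*storage_offset=*/0, out);
  if (err != Error::Ok) {
    ET_LOG(Error, "reinterpret_as_3x4: view creation failed");
  }
  return err;
}
```

The returned view shares `base`'s memory; since the function bumps the per-address count in `memory_to_n_tensor` (unless the address is marked `NOT_OWN`), the view would presumably be released through `aoti_torch_delete_tensor_object` like any other tensor handed out by this shim.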