@@ -25,6 +25,8 @@ namespace cuda {

 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_get_device_index;
+using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }

+AOTITorchError aoti_torch__reinterpret_tensor(
+    Tensor* self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    Tensor** ret_new_tensor) {
+  // Validate input parameters first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // storage_offset must be 0; return an error otherwise
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  device_error = aoti_torch_get_device_index(self, &device_index);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  // Ensure device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Get the dtype from the source tensor
+  int32_t dtype = 0;
+  AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Validate dtype using SupportedDTypes
+  dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  if (data_ptr == nullptr) {
+    ET_LOG(Error, "Source tensor has null data pointer");
+    return Error::InvalidArgument;
+  }
+
+  // The memory must already be tracked in the reference-counting map;
+  // if it is not, return an error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  if (memory_it == memory_to_n_tensor.end()) {
+    ET_LOG(
+        Error,
+        "Memory address %p is not being tracked by reference counting system",
+        data_ptr);
+    return Error::InvalidArgument;
+  }
+
+  // Convert sizes using utility function from utils.h
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using utility function from utils.h
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create a new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy - the data pointer is
+  // shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create reinterpreted tensor view");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+
+  // Increment the reference count for this memory address, unless the memory
+  // is externally owned (NOT_OWN), in which case it stays NOT_OWN
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
 } // extern "C"

 } // namespace cuda
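
For context, a minimal caller sketch of the new entry point (not part of this diff; it assumes `self` is a contiguous 4x4 Tensor* previously created through this shim, so that its data pointer is already tracked in memory_to_n_tensor):

// Illustrative usage: view an existing 4x4 tensor as 2x8.
const int64_t new_sizes[] = {2, 8};
const int64_t new_strides[] = {8, 1}; // contiguous row-major strides
Tensor* view = nullptr;
AOTITorchError err = aoti_torch__reinterpret_tensor(
    self, /*ndim=*/2, new_sizes, new_strides, /*storage_offset=*/0, &view);
if (err == Error::Ok) {
  // `view` aliases self's memory; release it with
  // aoti_torch_delete_tensor_object(view) when done.
}

Note that a non-zero storage_offset is rejected by validate_storage_offset, so callers must pass 0.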
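The closing ternary encodes the ownership convention of memory_to_n_tensor; a worked illustration of the two cases (NOT_OWN is the sentinel defined elsewhere in this file, values illustrative):

// before: memory_to_n_tensor[data_ptr] == 1       -> after: 2
//         (one more tensor now shares ownership of the allocation)
// before: memory_to_n_tensor[data_ptr] == NOT_OWN -> after: NOT_OWN
//         (externally owned memory is never reference-counted, so views
//          of it will never trigger a free)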