@@ -582,6 +582,96 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
   return Error::Ok;
 }

+AOTITorchError aoti_torch_new_tensor_handle(
+    Tensor* orig_handle,
+    Tensor** new_handle) {
+  // Validate input parameters
+  ET_CHECK_OR_RETURN_ERROR(
+      orig_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: orig_handle is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      new_handle != nullptr,
+      InvalidArgument,
+      "aoti_torch_new_tensor_handle failed: new_handle is null");
+
+  // Get metadata from the original tensor
+  int64_t* sizes_ptr;
+  int64_t* strides_ptr;
+  int32_t dtype;
+  int32_t device_type;
+  int32_t device_index;
+
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_strides(orig_handle, &strides_ptr));
+  ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_type(orig_handle, &device_type));
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      aoti_torch_get_device_index(orig_handle, &device_index));
+
+  int64_t ndim = orig_handle->dim();
+
+  // Validate dtype
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
+
+  // Ensure device_index is always 0
+  ET_CHECK_OR_RETURN_ERROR(
+      device_index == 0,
+      InvalidArgument,
+      "device_index must be 0, got: %d",
+      device_index);
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = orig_handle->mutable_data_ptr();
+  ET_CHECK_OR_RETURN_ERROR(
+      data_ptr != nullptr,
+      InvalidArgument,
+      "Source tensor has null data pointer");
+
+  // Check if the given memory is in the map
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  ET_CHECK_OR_RETURN_ERROR(
+      memory_it != memory_to_n_tensor.end(),
+      InvalidArgument,
+      "Memory address %p is not being tracked by reference counting system",
+      data_ptr);
+
+  // Convert sizes and strides to vectors
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create a new tensor that shares the same memory as the original.
+  // This is similar to PyTorch's Tensor copy constructor - it creates a new
+  // tensor object that shares the same underlying storage.
+  std::shared_ptr<Tensor> tensor = make_tensor(
+      sizes, // Same sizes as original
+      data_ptr, // Share the same memory from source tensor
+      {}, // dim_order (empty, will be auto-generated)
+      strides, // Same strides as original
+      dtype_to_scalar_type(dtype) // Same dtype as original
+  );
+
+  ET_CHECK_OR_RETURN_ERROR(
+      tensor != nullptr, InvalidArgument, "Failed to create new tensor handle");
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *new_handle = tensor.get();
+
+  // Increment the reference count for this memory address only if the memory
+  // is owned by the tensors (i.e. not marked NOT_OWN)
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
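For context, here is a minimal usage sketch of the new entry point (not part of this diff). It assumes a valid handle `orig` previously obtained from this shim, and that the shim's companion `aoti_torch_delete_tensor_object()` is the call used to release each handle independently.

// Hypothetical caller-side sketch; `orig` is an existing shim-owned handle.
Tensor* alias = nullptr;
AOTITorchError err = aoti_torch_new_tensor_handle(orig, &alias);
if (err != Error::Ok) {
  // Handle the failure (e.g. null input or untracked memory address).
}
// `alias` and `orig` now share the same storage, so writes through one are
// visible through the other. The reference count kept in memory_to_n_tensor
// for the shared data pointer is bumped unless the memory is NOT_OWN.
// Each handle is released on its own; the underlying buffer is freed only
// once the count for its data pointer drops to zero.
aoti_torch_delete_tensor_object(alias);
aoti_torch_delete_tensor_object(orig);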