Skip to content

Commit 092b946

Browse files
committed
[slimtensor] Add aoti_torch__reinterpret_tensor for SlimTensor
Add SlimTensor-based `aoti_torch__reinterpret_tensor()` - Creates a reinterpreted view of a tensor with new sizes, strides, and storage offset using SlimTensor's `as_strided()` method. The view shares the same underlying storage. Differential Revision: [D90126249](https://our.internmc.facebook.com/intern/diff/D90126249/) [ghstack-poisoned]
1 parent 9e10873 commit 092b946

File tree

4 files changed

+758
-3
lines changed

4 files changed

+758
-3
lines changed

backends/cuda/runtime/shims/memory_slim.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,47 @@ AOTITorchError aoti_torch_new_tensor_handle(
145145
return Error::Ok;
146146
}
147147

148+
AOTITorchError aoti_torch__reinterpret_tensor(
149+
Tensor* self,
150+
int64_t ndim,
151+
const int64_t* sizes_ptr,
152+
const int64_t* strides_ptr,
153+
int64_t storage_offset,
154+
Tensor** ret_new_tensor) {
155+
ET_CHECK_OR_RETURN_ERROR(
156+
self != nullptr,
157+
InvalidArgument,
158+
"aoti_torch__reinterpret_tensor: self is null");
159+
160+
ET_CHECK_OR_RETURN_ERROR(
161+
ret_new_tensor != nullptr,
162+
InvalidArgument,
163+
"aoti_torch__reinterpret_tensor: ret_new_tensor is null");
164+
165+
ET_CHECK_OR_RETURN_ERROR(
166+
ndim >= 0,
167+
InvalidArgument,
168+
"aoti_torch__reinterpret_tensor: ndim must be non-negative, got %lld",
169+
static_cast<long long>(ndim));
170+
171+
ET_CHECK_OR_RETURN_ERROR(
172+
!(sizes_ptr == nullptr && ndim > 0),
173+
InvalidArgument,
174+
"aoti_torch__reinterpret_tensor: sizes_ptr is null but ndim > 0");
175+
176+
IntArrayRef sizes(sizes_ptr, static_cast<size_t>(ndim));
177+
IntArrayRef strides(strides_ptr, static_cast<size_t>(ndim));
178+
179+
// Create a new tensor view using as_strided. This creates a tensor that
180+
// shares the same underlying storage but with different sizes, strides,
181+
// and storage offset. SlimTensor::as_strided() handles this via copy
182+
// constructor which shares the SharedPtr<Storage>.
183+
*ret_new_tensor =
184+
new Tensor(self->as_strided(sizes, strides, storage_offset));
185+
186+
return Error::Ok;
187+
}
188+
148189
} // extern "C"
149190

150191
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/memory_slim.h

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,30 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
103103
* @param new_handle Output parameter for the new tensor handle
104104
* @return AOTITorchError error code (Error::Ok on success)
105105
*/
106-
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
107-
Tensor* orig_handle,
108-
Tensor** new_handle);
106+
AOTI_SHIM_EXPORT AOTITorchError
107+
aoti_torch_new_tensor_handle(Tensor* orig_handle, Tensor** new_handle);
108+
109+
/**
110+
* Creates a reinterpreted view of a tensor with new sizes, strides, and offset.
111+
*
112+
* This is equivalent to torch.as_strided() - it creates a new tensor that
113+
* shares the same underlying storage but with different view parameters.
114+
*
115+
* @param self Original tensor to reinterpret (must not be null)
116+
* @param ndim Number of dimensions for the new view
117+
* @param sizes_ptr Pointer to array of dimension sizes
118+
* @param strides_ptr Pointer to array of strides for each dimension
119+
* @param storage_offset Storage offset in number of elements
120+
* @param ret_new_tensor Output parameter for the reinterpreted tensor view
121+
* @return AOTITorchError error code (Error::Ok on success)
122+
*/
123+
AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
124+
Tensor* self,
125+
int64_t ndim,
126+
const int64_t* sizes_ptr,
127+
const int64_t* strides_ptr,
128+
int64_t storage_offset,
129+
Tensor** ret_new_tensor);
109130

110131
} // extern "C"
111132

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,4 @@ def define_common_targets():
7575
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
7676
cuda_shim_slim_cpp_unittest("aoti_torch_delete_tensor_object")
7777
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
78+
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")

0 commit comments

Comments
 (0)