Skip to content

Commit 34db573

Browse files
committed
[slimtensor] Add aoti_torch_assign_tensors_out for SlimTensor
Add SlimTensor-based `aoti_torch_assign_tensors_out()` - Creates a new tensor handle that moves the ownership of storage from the source tensor to ret tensor via SharedPtr. Differential Revision: [D90126243](https://our.internmc.facebook.com/intern/diff/D90126243/) [ghstack-poisoned]
1 parent 1e2ee90 commit 34db573

File tree

4 files changed

+471
-0
lines changed

4 files changed

+471
-0
lines changed

backends/cuda/runtime/shims/memory_slim.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,25 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
206206
return Error::Ok;
207207
}
208208

209+
AOTITorchError aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst) {
210+
ET_CHECK_OR_RETURN_ERROR(
211+
src != nullptr,
212+
InvalidArgument,
213+
"aoti_torch_assign_tensors_out: src is null");
214+
215+
ET_CHECK_OR_RETURN_ERROR(
216+
ret_dst != nullptr,
217+
InvalidArgument,
218+
"aoti_torch_assign_tensors_out: ret_dst is null");
219+
220+
// Move the source tensor into the destination. After this operation,
221+
// the source tensor will be left in an undefined state (reset).
222+
// This differs from aoti_torch_new_tensor_handle which copies the tensor.
223+
*ret_dst = new Tensor(std::move(*src));
224+
225+
return Error::Ok;
226+
}
227+
209228
} // extern "C"
210229

211230
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/memory_slim.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,20 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
143143
AOTI_SHIM_EXPORT AOTITorchError
144144
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
145145

146+
/**
147+
* Moves a tensor into a new handle and assigns it to the output parameter.
148+
*
149+
* Unlike aoti_torch_new_tensor_handle which copies, this function moves the
150+
* source tensor into the destination. After this operation, the source tensor
151+
* is left in an undefined/reset state and should not be used.
152+
*
153+
* @param src Source tensor to move from (must not be null, will be reset)
154+
* @param ret_dst Output parameter for the new tensor handle
155+
* @return AOTITorchError error code (Error::Ok on success)
156+
*/
157+
AOTI_SHIM_EXPORT AOTITorchError
158+
aoti_torch_assign_tensors_out(Tensor* src, Tensor** ret_dst);
159+
146160
} // extern "C"
147161

148162
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,4 @@ def define_common_targets():
7777
cuda_shim_slim_cpp_unittest("aoti_torch_new_tensor_handle")
7878
cuda_shim_slim_cpp_unittest("aoti_torch__reinterpret_tensor")
7979
cuda_shim_slim_cpp_unittest("aoti_torch_copy_")
80+
cuda_shim_slim_cpp_unittest("aoti_torch_assign_tensors_out")

0 commit comments

Comments
 (0)