File tree Expand file tree Collapse file tree 6 files changed +35
-32
lines changed Expand file tree Collapse file tree 6 files changed +35
-32
lines changed Original file line number Diff line number Diff line change @@ -204,16 +204,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr(
204204 void ** ret_data_ptr // returns borrowed reference
205205);
206206
207- AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_mutable_data_ptr (
208- AtenTensorHandle tensor,
209- void ** ret_data_ptr // returns borrowed reference
210- );
211-
212- AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_const_data_ptr (
213- AtenTensorHandle tensor,
214- const void ** ret_data_ptr // returns borrowed reference
215- );
216-
217207// Get the nbytes of the underlying storage
218208AOTI_TORCH_EXPORT AOTITorchError
219209aoti_torch_get_storage_size (AtenTensorHandle tensor, int64_t * ret_size);
Original file line number Diff line number Diff line change @@ -281,24 +281,6 @@ AOTITorchError aoti_torch_get_data_ptr(
281281 });
282282}
283283
284- AOTITorchError aoti_torch_get_const_data_ptr (
285- AtenTensorHandle tensor,
286- const void ** ret_data_ptr) {
287- AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE ({
288- at::Tensor* t = tensor_handle_to_tensor_pointer (tensor);
289- *ret_data_ptr = t->const_data_ptr ();
290- });
291- }
292-
293- AOTITorchError aoti_torch_get_mutable_data_ptr (
294- AtenTensorHandle tensor,
295- void ** ret_data_ptr) {
296- AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE ({
297- at::Tensor* t = tensor_handle_to_tensor_pointer (tensor);
298- *ret_data_ptr = t->mutable_data_ptr ();
299- });
300- }
301-
302284AOTITorchError aoti_torch_get_storage_size (
303285 AtenTensorHandle tensor,
304286 int64_t * ret_size) {
Original file line number Diff line number Diff line change @@ -560,3 +560,21 @@ torch_get_num_threads(uint32_t* out_num_threads) {
560560 AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE (
561561 { *out_num_threads = static_cast <uint32_t >(at::get_num_threads ()); });
562562}
563+
564+ AOTI_TORCH_EXPORT AOTITorchError
565+ torch_get_const_data_ptr (AtenTensorHandle tensor, const void ** ret_data_ptr) {
566+ AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE ({
567+ at::Tensor* t =
568+ torch::aot_inductor::tensor_handle_to_tensor_pointer (tensor);
569+ *ret_data_ptr = t->const_data_ptr ();
570+ });
571+ }
572+
573+ AOTI_TORCH_EXPORT AOTITorchError
574+ torch_get_mutable_data_ptr (AtenTensorHandle tensor, void ** ret_data_ptr) {
575+ AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE ({
576+ at::Tensor* t =
577+ torch::aot_inductor::tensor_handle_to_tensor_pointer (tensor);
578+ *ret_data_ptr = t->mutable_data_ptr ();
579+ });
580+ }
Original file line number Diff line number Diff line change @@ -92,6 +92,17 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_thread_idx(uint32_t* out_thread_idx);
9292AOTI_TORCH_EXPORT AOTITorchError
9393torch_get_num_threads (uint32_t * out_num_threads );
9494
95+ // Get a pointer to the underlying storage data
96+ AOTI_TORCH_EXPORT AOTITorchError torch_get_mutable_data_ptr (
97+ AtenTensorHandle tensor ,
98+ void * * ret_data_ptr // returns borrowed reference
99+ );
100+
101+ AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr (
102+ AtenTensorHandle tensor ,
103+ const void * * ret_data_ptr // returns borrowed reference
104+ );
105+
95106#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
96107
97108#ifdef __cplusplus
Original file line number Diff line number Diff line change @@ -33,6 +33,7 @@ inline Device Tensor::device() const {
3333 return Device (extension_device_type, static_cast <DeviceIndex>(device_index));
3434}
3535
36+ #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
3637// The following data ptr cast methods mirror the methods defined in
3738// aten/src/ATen/templates/TensorMethods.cpp
3839#define DEFINE_DATA_PTR_CAST (T, name, PRED ) \
@@ -63,5 +64,6 @@ DEFINE_CAST(uint32_t, UInt32)
6364DEFINE_CAST(uint64_t , UInt64)
6465#undef DEFINE_CAST
6566#undef _PRED
67+ #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
6668
6769HIDDEN_NAMESPACE_END (torch, stable)
Original file line number Diff line number Diff line change @@ -86,17 +86,16 @@ class Tensor {
8686 return data_ptr;
8787 }
8888
89+ #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
8990 void * mutable_data_ptr () const {
9091 void * data_ptr{};
91- TORCH_ERROR_CODE_CHECK (
92- aoti_torch_get_mutable_data_ptr (ath_.get (), &data_ptr));
92+ TORCH_ERROR_CODE_CHECK (torch_get_mutable_data_ptr (ath_.get (), &data_ptr));
9393 return data_ptr;
9494 }
9595
9696 const void * const_data_ptr () const {
9797 const void * data_ptr{};
98- TORCH_ERROR_CODE_CHECK (
99- aoti_torch_get_const_data_ptr (ath_.get (), &data_ptr));
98+ TORCH_ERROR_CODE_CHECK (torch_get_const_data_ptr (ath_.get (), &data_ptr));
10099 return data_ptr;
101100 }
102101
@@ -105,6 +104,7 @@ class Tensor {
105104
106105 template <typename T, std::enable_if_t <!std::is_const_v<T>, int > = 0 >
107106 const T* const_data_ptr () const ;
107+ #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
108108
109109 int64_t dim () const {
110110 int64_t dim;
You can’t perform that action at this time.
0 commit comments