Skip to content

Commit d2ccb5b

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Follow up on pytorch#161891 move additions to stable shim and use version guards (pytorch#168025)
Address pytorch#161891 (comment) Pull Request resolved: pytorch#168025 Approved by: https://github.com/janeyx99
1 parent 8cb8b6c commit d2ccb5b

File tree

6 files changed

+35
-32
lines changed

6 files changed

+35
-32
lines changed

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -204,16 +204,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr(
204204
void** ret_data_ptr // returns borrowed reference
205205
);
206206

207-
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_mutable_data_ptr(
208-
AtenTensorHandle tensor,
209-
void** ret_data_ptr // returns borrowed reference
210-
);
211-
212-
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_const_data_ptr(
213-
AtenTensorHandle tensor,
214-
const void** ret_data_ptr // returns borrowed reference
215-
);
216-
217207
// Get the nbytes of the underlying storage
218208
AOTI_TORCH_EXPORT AOTITorchError
219209
aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t* ret_size);

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -281,24 +281,6 @@ AOTITorchError aoti_torch_get_data_ptr(
281281
});
282282
}
283283

284-
AOTITorchError aoti_torch_get_const_data_ptr(
285-
AtenTensorHandle tensor,
286-
const void** ret_data_ptr) {
287-
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
288-
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
289-
*ret_data_ptr = t->const_data_ptr();
290-
});
291-
}
292-
293-
AOTITorchError aoti_torch_get_mutable_data_ptr(
294-
AtenTensorHandle tensor,
295-
void** ret_data_ptr) {
296-
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
297-
at::Tensor* t = tensor_handle_to_tensor_pointer(tensor);
298-
*ret_data_ptr = t->mutable_data_ptr();
299-
});
300-
}
301-
302284
AOTITorchError aoti_torch_get_storage_size(
303285
AtenTensorHandle tensor,
304286
int64_t* ret_size) {

torch/csrc/shim_common.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,3 +560,21 @@ torch_get_num_threads(uint32_t* out_num_threads) {
560560
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
561561
{ *out_num_threads = static_cast<uint32_t>(at::get_num_threads()); });
562562
}
563+
564+
AOTI_TORCH_EXPORT AOTITorchError
565+
torch_get_const_data_ptr(AtenTensorHandle tensor, const void** ret_data_ptr) {
566+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
567+
at::Tensor* t =
568+
torch::aot_inductor::tensor_handle_to_tensor_pointer(tensor);
569+
*ret_data_ptr = t->const_data_ptr();
570+
});
571+
}
572+
573+
AOTI_TORCH_EXPORT AOTITorchError
574+
torch_get_mutable_data_ptr(AtenTensorHandle tensor, void** ret_data_ptr) {
575+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
576+
at::Tensor* t =
577+
torch::aot_inductor::tensor_handle_to_tensor_pointer(tensor);
578+
*ret_data_ptr = t->mutable_data_ptr();
579+
});
580+
}

torch/csrc/stable/c/shim.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_thread_idx(uint32_t* out_thread_idx);
9292
AOTI_TORCH_EXPORT AOTITorchError
9393
torch_get_num_threads(uint32_t* out_num_threads);
9494

95+
// Get a pointer to the underlying storage data
96+
AOTI_TORCH_EXPORT AOTITorchError torch_get_mutable_data_ptr(
97+
AtenTensorHandle tensor,
98+
void** ret_data_ptr // returns borrowed reference
99+
);
100+
101+
AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr(
102+
AtenTensorHandle tensor,
103+
const void** ret_data_ptr // returns borrowed reference
104+
);
105+
95106
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
96107

97108
#ifdef __cplusplus

torch/csrc/stable/tensor_inl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ inline Device Tensor::device() const {
3333
return Device(extension_device_type, static_cast<DeviceIndex>(device_index));
3434
}
3535

36+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
3637
// The following data ptr cast methods mirror the methods defined in
3738
// aten/src/ATen/templates/TensorMethods.cpp
3839
#define DEFINE_DATA_PTR_CAST(T, name, PRED) \
@@ -63,5 +64,6 @@ DEFINE_CAST(uint32_t, UInt32)
6364
DEFINE_CAST(uint64_t, UInt64)
6465
#undef DEFINE_CAST
6566
#undef _PRED
67+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
6668

6769
HIDDEN_NAMESPACE_END(torch, stable)

torch/csrc/stable/tensor_struct.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,16 @@ class Tensor {
8686
return data_ptr;
8787
}
8888

89+
#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
8990
void* mutable_data_ptr() const {
9091
void* data_ptr{};
91-
TORCH_ERROR_CODE_CHECK(
92-
aoti_torch_get_mutable_data_ptr(ath_.get(), &data_ptr));
92+
TORCH_ERROR_CODE_CHECK(torch_get_mutable_data_ptr(ath_.get(), &data_ptr));
9393
return data_ptr;
9494
}
9595

9696
const void* const_data_ptr() const {
9797
const void* data_ptr{};
98-
TORCH_ERROR_CODE_CHECK(
99-
aoti_torch_get_const_data_ptr(ath_.get(), &data_ptr));
98+
TORCH_ERROR_CODE_CHECK(torch_get_const_data_ptr(ath_.get(), &data_ptr));
10099
return data_ptr;
101100
}
102101

@@ -105,6 +104,7 @@ class Tensor {
105104

106105
template <typename T, std::enable_if_t<!std::is_const_v<T>, int> = 0>
107106
const T* const_data_ptr() const;
107+
#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0
108108

109109
int64_t dim() const {
110110
int64_t dim;

0 commit comments

Comments
 (0)