Skip to content

Commit f4cc1b3

Browse files
Gasoonjia authored and facebook-github-bot committed
aoti_torch_create_tensor_from_blob_v2 (#14604)
Summary: Pull Request resolved: #14604 This change introduces aoti_torch_create_tensor_from_blob_v2, a function that creates a tensor from a data blob with custom sizes and strides. It is worth noting that, unlike aoti_torch_empty_strided, a tensor created by aoti_torch_create_tensor_from_blob_v2 does not take ownership of the memory blob. Therefore, deleting the tensor will not free the underlying memory. Differential Revision: D83094602
1 parent 89034e5 commit f4cc1b3

File tree

6 files changed

+947
-21
lines changed

6 files changed

+947
-21
lines changed

backends/aoti/utils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@ inline AOTITorchError validate_storage_offset(int64_t storage_offset) {
7373
return Error::Ok;
7474
}
7575

76+
// Check if tensor is in contiguous memory format (NCHW for 4D tensors).
// Contiguous format means strides decrease from left to right:
// For NCHW: strides = [C*H*W, H*W, W, 1]
//
// Dimensions of size 1 are skipped: their stride never affects the memory
// layout, so any stride value is acceptable there (this matches PyTorch's
// contiguity semantics, e.g. sizes=[1,3] strides=[100,1] is contiguous).
//
// @param ndim    Number of dimensions (0 yields true: a scalar is contiguous)
// @param sizes   Array of `ndim` dimension sizes
// @param strides Array of `ndim` strides, in elements
// @return true iff the (size, stride) pair describes a row-major contiguous
//         layout
inline bool is_tensor_contiguous(
    int64_t ndim,
    const int64_t* sizes,
    const int64_t* strides) {
  int64_t expected_stride = 1;
  for (int64_t i = ndim - 1; i >= 0; i--) {
    // Stride of a size-1 dimension is irrelevant to the layout.
    if (sizes[i] == 1) {
      continue;
    }
    if (strides[i] != expected_stride) {
      return false;
    }
    expected_stride *= sizes[i];
  }
  return true;
}
92+
7693
} // extern "C"
7794

7895
} // namespace aoti

backends/cuda/runtime/shims/memory.cpp

Lines changed: 118 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,10 @@
1515
#include <cstdint>
1616
#include <cstdlib> // For posix_memalign
1717
#include <memory>
18+
#include <unordered_map>
1819
#include <unordered_set>
1920
#include <vector>
2021

21-
// CUDA error checking macro
22-
#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
23-
do { \
24-
const cudaError_t err = EXPR; \
25-
if (err == cudaSuccess) { \
26-
break; \
27-
} \
28-
ET_LOG( \
29-
Error, \
30-
"%s:%d CUDA error: %s", \
31-
__FILE__, \
32-
__LINE__, \
33-
cudaGetErrorString(err)); \
34-
return Error::Internal; \
35-
} while (0)
36-
37-
// Kernel launch check macro
38-
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
39-
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
40-
4122
namespace executorch {
4223
namespace backends {
4324
namespace cuda {
@@ -46,12 +27,105 @@ using executorch::aten::SizesType;
4627
using executorch::aten::StridesType;
4728
using executorch::backends::aoti::dtype_to_element_size;
4829
using executorch::backends::aoti::dtype_to_scalar_type;
30+
using executorch::backends::aoti::validate_storage_offset;
4931

5032
// Global storage for tensors and their metadata
5133
std::unordered_set<std::shared_ptr<Tensor>> tensors;
34+
// Global storage for tensor ownership information
35+
std::unordered_map<Tensor*, bool> is_tensor_own_memory;
5236

5337
extern "C" {
5438

39+
// Wraps an existing memory blob in an ExecuTorch tensor WITHOUT copying and
// WITHOUT taking ownership: deleting the returned tensor never frees `data`.
// Validation order (null args, device_index, dtype, storage_offset) decides
// which error a caller sees first, so it is kept fixed.
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size) {
  // TODO(gasoonjia): verify given data is on the target device
  (void)device_type;
  (void)opaque_metadata;
  (void)layout;
  (void)opaque_metadata_size;

  // Reject null pointers before touching anything else.
  if (data == nullptr) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
    return Error::InvalidArgument;
  }
  if (sizes_ptr == nullptr && ndim > 0) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
    return Error::InvalidArgument;
  }
  if (ret_new_tensor == nullptr) {
    ET_LOG(
        Error,
        "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
    return Error::InvalidArgument;
  }

  // Only device index 0 is supported for now.
  if (device_index != 0) {
    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
    return Error::InvalidArgument;
  }

  // Validate dtype against SupportedDTypes (utils.h), then the storage
  // offset: from_blob cannot express a nonzero offset, so it must be 0.
  AOTITorchError err = validate_dtype(dtype);
  if (err != Error::Ok) {
    return err;
  }
  err = validate_storage_offset(storage_offset);
  if (err != Error::Ok) {
    return err;
  }

  // Translate the raw int64 arrays into the SizesType / StridesType vectors
  // ExecuTorch expects; a null strides_ptr yields contiguous strides.
  std::vector<executorch::aten::SizesType> tensor_sizes =
      convert_sizes_to_vector(ndim, sizes_ptr);
  std::vector<executorch::aten::StridesType> tensor_strides =
      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // Build a view over the caller's memory — no copy is made.
  auto wrapped = executorch::extension::from_blob(
      data, // existing memory (don't copy!)
      tensor_sizes, // tensor dimensions
      tensor_strides, // tensor strides (allows different strides)
      dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
  );
  if (!wrapped) {
    ET_LOG(Error, "Failed to create tensor from blob");
    return Error::InvalidArgument;
  }

  // Register the tensor so it stays alive, hand the raw pointer back, and
  // record that this tensor does NOT own its memory (so deletion won't free
  // the caller's blob).
  Tensor* raw = wrapped.get();
  tensors.insert(std::move(wrapped));
  *ret_new_tensor = raw;
  is_tensor_own_memory[raw] = false;

  return Error::Ok;
}
128+
55129
AOTITorchError aoti_torch_empty_strided(
56130
int64_t ndim,
57131
const int64_t* sizes_ptr,
@@ -119,6 +193,7 @@ AOTITorchError aoti_torch_empty_strided(
119193
// Store the tensor so it doesn't get destroyed
120194
tensors.insert(tensor);
121195
*ret_new_tensor = tensor.get();
196+
is_tensor_own_memory[tensor.get()] = true;
122197

123198
return Error::Ok;
124199
}
@@ -156,9 +231,32 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
156231
// If tensor not found in our tracking, it's invalid
157232
if (!found_in_tensors) {
158233
ET_LOG(Error, "Didn't find tensor %p", tensor);
234+
// Clean up any stale ownership info
235+
is_tensor_own_memory.erase(tensor);
159236
return Error::InvalidArgument;
160237
}
161238

239+
// Check ownership before cleaning up metadata
240+
auto ownership_it = is_tensor_own_memory.find(tensor);
241+
bool owns_memory = (ownership_it != is_tensor_own_memory.end())
242+
? ownership_it->second
243+
: false;
244+
245+
// Clean up local metadata maps immediately to prevent use-after-free
246+
is_tensor_own_memory.erase(tensor);
247+
248+
if (!owns_memory) {
249+
// Don't free memory since the tensor doesn't own it, but still remove from
250+
// tracking
251+
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
252+
if (it->get() == tensor) {
253+
tensors.erase(it);
254+
break;
255+
}
256+
}
257+
return Error::Ok;
258+
}
259+
162260
// Find and delete the tensor
163261
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
164262
if (it->get() == tensor) {

backends/cuda/runtime/shims/memory.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,44 @@ using executorch::backends::aoti::Tensor;
2121

2222
extern "C" {
2323

24+
/**
 * Creates a tensor object from an existing memory blob without copying the
 * data. The tensor wraps the provided memory and does not take ownership of
 * it: when the tensor is deleted, the original memory remains valid and
 * must be freed by the caller.
 *
 * @param data Pointer to the memory blob to wrap (must not be null)
 * @param ndim Number of dimensions in the tensor
 * @param sizes_ptr Pointer to array of dimension sizes (using SizesType)
 * @param strides_ptr Pointer to array of strides for each dimension (using
 * StridesType, can be null for contiguous)
 * @param storage_offset Storage offset (must be 0 for current implementation)
 * @param dtype Data type identifier (supports FLOAT32 and BFLOAT16 from
 * SupportedDTypes)
 * @param device_type Device type (CPU=0, CUDA=1 from SupportedDevices)
 * @param device_index Device index (must be 0 for current implementation)
 * @param ret_new_tensor Output parameter for the created tensor (must not be
 * null)
 * @param layout Tensor layout identifier (0=strided)
 * @param opaque_metadata Optional metadata pointer (can be null)
 * @param opaque_metadata_size Size of opaque metadata in bytes
 * @return AOTITorchError error code (Error::Ok on success, or an error code on
 * failure)
 */
AOTITorchError aoti_torch_create_tensor_from_blob_v2(
    void* data,
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int64_t storage_offset,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor,
    int32_t layout,
    const uint8_t* opaque_metadata,
    int64_t opaque_metadata_size);
61+
2462
/**
2563
* Creates an uninitialized tensor with specified dimensions, strides, and
2664
* dtyper on either CPU or CUDA device.
@@ -55,7 +93,6 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
5593

5694
// Function to clear all tensors from internal storage
5795
void clear_all_tensors();
58-
5996
} // extern "C"
6097

6198
} // namespace cuda

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ def define_common_targets():
2929
"""
3030
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
3131
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
32+
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")

0 commit comments

Comments
 (0)