
Commit da85148

Gasoonjia authored and facebook-github-bot committed
tensor empty strided (#14549)
Summary: This diff introduces aoti_torch_empty_strided to the ExecuTorch CUDA backend; it will be one of the main functions for creating an empty tensor with a given set of strides.

Differential Revision: D83094606
1 parent 83daccf commit da85148
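
For orientation, here is a minimal caller sketch (not part of this diff) showing how the new shim might be invoked. The wrapper function and the dtype code are illustrative assumptions: the code 6 follows PyTorch's ScalarType numbering for float32, and device_type 1 selects the CUDA path in the implementation below.

// Hypothetical caller sketch -- not from this diff.
#include <executorch/backends/cuda/runtime/shims/memory.h>

using namespace executorch::backends::cuda;

bool make_empty_cuda_tensor(Tensor** out) {
  const int64_t sizes[] = {2, 3};
  // Passing nullptr for strides_ptr asks the shim to derive contiguous
  // strides ({3, 1} here) from the sizes.
  AOTITorchError err = aoti_torch_empty_strided(
      /*ndim=*/2,
      sizes,
      /*strides_ptr=*/nullptr,
      /*dtype=*/6, // float32, assuming ScalarType numbering
      /*device_type=*/1, // 1 == CUDA, 0 == CPU in this shim
      /*device_index=*/0,
      out);
  return err == executorch::runtime::Error::Ok;
}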

File tree

7 files changed: +709 −0 lines changed

backends/cuda/runtime/TARGETS

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.cxx_library(
    name = "runtime_shims",
    srcs = [
        "shims/memory.cpp",
        "shims/tensor_attribute.cpp",
    ],
    headers = [
        "shims/memory.h",
        "shims/tensor_attribute.h",
        "shims/utils.h",
    ],
    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
    link_whole = True,
    supports_python_dlopen = True,
    # Constructor needed for backend registration.
    compiler_flags = ["-Wno-global-constructors"],
    visibility = ["@EXECUTORCH_CLIENTS"],
    deps = [
        "//executorch/backends/aoti:common_shims",
        "//executorch/extension/tensor:tensor",
        "//executorch/runtime/core:core",
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/platform:platform",
    ],
    external_deps = [
        ("cuda", None, "cuda-lazy"),
    ],
)
backends/cuda/runtime/shims/memory.cpp

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/aoti/utils.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
#include <executorch/backends/cuda/runtime/shims/utils.h>
#include <executorch/runtime/platform/log.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib> // For posix_memalign
#include <cstring>
#include <memory>
#include <unordered_set>
#include <vector>

namespace executorch {
namespace backends {
namespace cuda {

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;

extern "C" {

AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor) {
  // This requires us to reserve CUDA memory and put it into an ETensor
  void* ptr;
  int64_t numel = 1;
  for (int i = 0; i < ndim; i++) {
    numel *= sizes_ptr[i];
  }

  AOTITorchError dtype_error = validate_dtype(dtype);
  if (dtype_error != Error::Ok) {
    return dtype_error;
  }

  size_t element_size = dtype_to_element_size(dtype);
  if (element_size == 0) {
    ET_LOG(Error, "Invalid element size for dtype: %d", dtype);
    return Error::InvalidArgument;
  }
  int64_t nbytes = numel * element_size;

  if (device_type == 1) { // cuda
    cudaError_t err = cudaMalloc(&ptr, nbytes);
    if (err != cudaSuccess) {
      ET_LOG(
          Error,
          "failed to allocate %ld bytes: %s",
          nbytes,
          cudaGetErrorString(err));
      return Error::MemoryAllocationFailed;
    }
  } else if (device_type == 0) { // cpu
    // Ensure 16-byte alignment for CPU memory to match CUDA requirements.
    // TODO: do we need to do this in the CUDA backend?
    int result = posix_memalign(&ptr, 16, nbytes);
    if (result != 0) {
      ET_LOG(Error, "Failed to allocate aligned CPU memory");
      return Error::MemoryAllocationFailed;
    }
    if (ptr == nullptr) {
      ET_LOG(Error, "Failed to call posix_memalign");
      return Error::MemoryAllocationFailed;
    }
  } else {
    ET_LOG(
        Error,
        "Need to implement empty_strided for non-CUDA, non-CPU device type %d",
        device_type);
    return Error::NotImplemented;
  }

  // ETensor sizes
  std::vector<int32_t> sizes(ndim);
  for (int i = 0; i < ndim; i++) {
    sizes[i] = sizes_ptr[i];
  }

  // ETensor strides
  std::vector<int32_t> strides(ndim);
  if (strides_ptr != nullptr) {
    // Use the provided strides. It is OK if they are not contiguous
    // strides, since they are only used internally in the CUDA delegate.
    for (int i = 0; i < ndim; i++) {
      strides[i] = strides_ptr[i];
    }
  } else {
    // Calculate contiguous strides from sizes using ExecuTorch's algorithm
    if (ndim > 0) {
      strides[ndim - 1] = 1; // Last dimension has stride 1
      for (int i = ndim - 2; i >= 0; i--) {
        if (sizes_ptr[i + 1] == 0) {
          strides[i] = strides[i + 1]; // Copy stride when size is 0
        } else {
          strides[i] = strides[i + 1] * sizes_ptr[i + 1];
        }
      }
    }
  }

  // ETensor creation with dynamic shape support for edge cases
  auto tensor = executorch::extension::from_blob(
      ptr, sizes, strides, dtype_to_scalar_type(dtype));

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);
  *ret_new_tensor = tensor.get();

  return Error::Ok;
}

// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
void clear_all_tensors() {
  tensors.clear();
}

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
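
As an aside on the stride fallback in aoti_torch_empty_strided above, a standalone restatement (illustration only, not part of the diff) makes the computation and its zero-size edge case easier to see: for sizes {2, 3, 4} it yields strides {12, 4, 1}.

#include <cstdint>
#include <vector>

// Mirror of the shim's fallback stride computation, for illustration.
std::vector<int32_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  const int64_t ndim = static_cast<int64_t>(sizes.size());
  std::vector<int32_t> strides(ndim);
  if (ndim == 0) {
    return strides;
  }
  strides[ndim - 1] = 1; // innermost dimension is densely packed
  for (int64_t i = ndim - 2; i >= 0; i--) {
    // A zero-sized dimension contributes no elements, so the previous
    // stride is carried over instead of multiplied through.
    strides[i] = (sizes[i + 1] == 0)
        ? strides[i + 1]
        : static_cast<int32_t>(strides[i + 1] * sizes[i + 1]);
  }
  return strides;
}

// e.g. contiguous_strides({2, 3, 4}) returns {12, 4, 1}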
backends/cuda/runtime/shims/memory.h

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/aoti/common_shims.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using namespace executorch::backends::aoti;

extern "C" {

AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor);

// Function to clear all tensors from internal storage
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
void clear_all_tensors();

} // extern "C"

} // namespace cuda
} // namespace backends
} // namespace executorch
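
Read together, the two declarations suggest a simple lifecycle: tensors created through the shim are kept alive by the global set until clear_all_tensors() drops them. A hedged sketch of that flow (hypothetical calling code, not from the diff):

#include <executorch/backends/cuda/runtime/shims/memory.h>

void run_delegate_once() { // hypothetical driver
  using namespace executorch::backends::cuda;
  Tensor* t = nullptr;
  const int64_t sizes[] = {4};
  if (aoti_torch_empty_strided(
          /*ndim=*/1, sizes, /*strides_ptr=*/nullptr,
          /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &t) !=
      executorch::runtime::Error::Ok) {
    return;
  }
  // ... t is valid here; the shim's global set keeps it alive ...
  // Drops the shared_ptr bookkeeping. Since from_blob does not take
  // ownership of the underlying buffer, freeing the device memory is a
  // separate concern (hence the aoti_torch_delete_tensor_object TODO).
  clear_all_tensors();
}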
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
targets.bzl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")

def cuda_shim_cpp_unittest(name):
    cpp_unittest(
        name = "test_" + name,
        srcs = [
            "test_" + name + ".cpp",
        ],
        deps = [
            "//executorch/backends/aoti:common_shims",
            "//executorch/backends/cuda/runtime:runtime_shims",
            "//executorch/extension/tensor:tensor",
            "//executorch/runtime/core:core",
            "//executorch/runtime/platform:platform",
            "//executorch/runtime/core/exec_aten:lib",
        ],
        external_deps = [
            ("cuda", None, "cuda-lazy"),
        ],
    )

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    cuda_shim_cpp_unittest("aoti_torch_empty_strided")
