Skip to content

Commit 049c9fc

Browse files
authored
tensor empty strided
Differential Revision: D83094606 Pull Request resolved: pytorch#14549
1 parent df8d03b commit 049c9fc

File tree

8 files changed

+957
-0
lines changed

8 files changed

+957
-0
lines changed

backends/aoti/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
3636
switch (dtype) {
3737
case 6: // PyTorch's float32 dtype code
3838
return executorch::aten::ScalarType::Float;
39+
case 15: // PyTorch's bfloat16 dtype code
40+
return executorch::aten::ScalarType::BFloat16;
3941
// Future support for additional dtypes can be added here
4042
default:
4143
ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);

backends/cuda/runtime/TARGETS

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# CUDA backend runtime shim library: AOTI-style C shims (memory allocation,
# tensor attribute accessors) implemented in shims/memory.cpp and
# shims/tensor_attribute.cpp.
runtime.cxx_library(
    name = "runtime_shims",
    srcs = [
        "shims/memory.cpp",
        "shims/tensor_attribute.cpp",
    ],
    headers = [
        "shims/memory.h",
        "shims/tensor_attribute.h",
        "shims/utils.h",
    ],
    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
    link_whole = True,
    supports_python_dlopen = True,
    # Constructor needed for backend registration.
    compiler_flags = ["-Wno-global-constructors"],
    visibility = ["@EXECUTORCH_CLIENTS"],
    deps = [
        "//executorch/backends/aoti:common_shims",
        "//executorch/extension/tensor:tensor",
        "//executorch/runtime/core:core",
        "//executorch/runtime/core/exec_aten:lib",
        "//executorch/runtime/platform:platform",
    ],
    external_deps = [
        # Lazily loaded CUDA runtime dependency.
        ("cuda", None, "cuda-lazy"),
    ],
)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/aoti/common_shims.h>
10+
#include <executorch/backends/aoti/utils.h>
11+
#include <executorch/backends/cuda/runtime/shims/memory.h>
12+
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
13+
#include <executorch/backends/cuda/runtime/shims/utils.h>
14+
#include <executorch/runtime/platform/log.h>
15+
#include <cstdint>
16+
#include <cstdlib> // For posix_memalign
17+
#include <memory>
18+
#include <unordered_set>
19+
#include <vector>
20+
21+
// CUDA error checking macro: evaluates EXPR (an expression yielding a
// cudaError_t); on any error it logs the failing file/line plus the CUDA
// error string and returns Error::Internal from the enclosing function.
// Must only be used inside functions returning Error/AOTITorchError.
#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
  do {                                      \
    const cudaError_t err = EXPR;           \
    if (err == cudaSuccess) {               \
      break;                                \
    }                                       \
    ET_LOG(                                 \
        Error,                              \
        "%s:%d CUDA error: %s",             \
        __FILE__,                           \
        __LINE__,                           \
        cudaGetErrorString(err));           \
    return Error::Internal;                 \
  } while (0)

// Kernel launch check macro: surfaces any error recorded by the most recent
// CUDA API/kernel launch via cudaGetLastError, with the same
// return-on-error behavior as ET_CUDA_CHECK_OR_RETURN_ERROR.
#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
40+
41+
namespace executorch {
42+
namespace backends {
43+
namespace cuda {
44+
45+
using executorch::aten::SizesType;
46+
using executorch::aten::StridesType;
47+
using executorch::backends::aoti::dtype_to_element_size;
48+
using executorch::backends::aoti::dtype_to_scalar_type;
49+
50+
// Global registry keeping every tensor created by this shim alive until
// explicitly cleared.
// NOTE(review): from_blob-created tensors do not own their underlying
// buffers, so dropping entries from this set releases the Tensor objects
// but not the device/host allocations — confirm the deletion path frees
// those buffers.
std::unordered_set<std::shared_ptr<Tensor>> tensors;
52+
53+
extern "C" {
54+
55+
// Allocates a new uninitialized tensor with the given sizes/strides on the
// requested device (0 = CPU, 1 = CUDA) and registers it in the global
// `tensors` set so it stays alive after this call returns. See memory.h for
// the full parameter contract.
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor) {
  // Check that device_index is always 0
  if (device_index != 0) {
    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
    return Error::InvalidArgument;
  }

  // Reject malformed shape arguments before dereferencing them.
  if (ndim < 0 || (ndim > 0 && sizes_ptr == nullptr)) {
    ET_LOG(Error, "Invalid ndim or null sizes_ptr");
    return Error::InvalidArgument;
  }
  if (ret_new_tensor == nullptr) {
    ET_LOG(Error, "ret_new_tensor is null");
    return Error::InvalidArgument;
  }

  // Validate the dtype before doing any allocation work.
  AOTITorchError dtype_error = validate_dtype(dtype);
  if (dtype_error != Error::Ok) {
    return dtype_error;
  }

  size_t element_size = dtype_to_element_size(dtype);
  if (element_size == 0) {
    ET_LOG(Error, "Invalid element size for dtype: %d", dtype);
    return Error::InvalidArgument;
  }

  // Compute the element count; negative dimension sizes are invalid, and
  // numel * element_size must not overflow int64.
  int64_t numel = 1;
  for (int64_t i = 0; i < ndim; i++) {
    if (sizes_ptr[i] < 0) {
      ET_LOG(Error, "size at dim %d must be non-negative", (int)i);
      return Error::InvalidArgument;
    }
    numel *= sizes_ptr[i];
  }
  if (numel > 0 &&
      static_cast<uint64_t>(numel) >
          static_cast<uint64_t>(INT64_MAX) / element_size) {
    ET_LOG(Error, "Requested tensor byte size overflows int64");
    return Error::InvalidArgument;
  }
  int64_t nbytes = numel * static_cast<int64_t>(element_size);
  // Allocate at least one byte so zero-element tensors still receive a
  // valid, freeable pointer from both allocators.
  size_t alloc_bytes = nbytes > 0 ? static_cast<size_t>(nbytes) : 1;

  // This requires us to reserve memory and put it into an ETensor.
  void* ptr = nullptr;
  if (device_type == 1) { // cuda
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMallocManaged(&ptr, alloc_bytes));
  } else if (device_type == 0) { // cpu
    // Ensure 16-byte alignment for CPU memory to match CUDA requirements
    int result = posix_memalign(&ptr, 16, alloc_bytes);
    if (result != 0) {
      ET_LOG(Error, "Failed to allocate aligned CPU memory");
      return Error::MemoryAllocationFailed;
    }
    if (ptr == nullptr) {
      ET_LOG(Error, "Failed to call posix_memalign");
      return Error::MemoryAllocationFailed;
    }
  } else {
    ET_LOG(
        Error,
        "Need to implement empty_strided for non-CUDA non-CPU device type %d",
        device_type);
    return Error::NotImplemented;
  }

  // ETensor sizes
  auto sizes = convert_sizes_to_vector(ndim, sizes_ptr);

  // ETensor strides
  auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);

  // ETensor creation with dynamic shape support for edge cases.
  // NOTE(review): from_blob wraps `ptr` without taking ownership — the raw
  // allocation must be freed by the deletion path; confirm.
  auto tensor = executorch::extension::from_blob(
      ptr, sizes, strides, dtype_to_scalar_type(dtype));
  if (!tensor) {
    // Don't leak the freshly allocated buffer on failure.
    if (device_type == 1) {
      cudaFree(ptr);
    } else {
      free(ptr);
    }
    ET_LOG(Error, "Failed to create tensor from allocated memory");
    return Error::Internal;
  }

  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);
  *ret_new_tensor = tensor.get();

  return Error::Ok;
}
125+
126+
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
// Clears the global tensor registry, dropping this shim's references to
// every tensor created via aoti_torch_empty_strided.
// NOTE(review): from_blob tensors don't own their buffers, so this releases
// the Tensor objects but appears to leak the cudaMallocManaged /
// posix_memalign allocations — confirm against the deletion path.
void clear_all_tensors() {
  tensors.clear();
}
130+
131+
} // extern "C"
132+
133+
} // namespace cuda
134+
} // namespace backends
135+
} // namespace executorch
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cuda_runtime.h>
12+
#include <executorch/backends/aoti/common_shims.h>
13+
#include <cstdint>
14+
15+
namespace executorch {
16+
namespace backends {
17+
namespace cuda {
18+
19+
using executorch::backends::aoti::AOTITorchError;
20+
using executorch::backends::aoti::Tensor;
21+
22+
extern "C" {
23+
24+
/**
25+
* Creates an uninitialized tensor with specified dimensions, strides, and
26+
* dtype on either CPU or CUDA device.
27+
*
28+
* @param ndim Number of dimensions in the tensor
29+
* @param sizes_ptr Pointer to array of dimension sizes
30+
* @param strides_ptr Pointer to array of strides for each dimension
31+
* @param dtype Data type identifier (matches PyTorch scalar types)
32+
* @param device_type Device type (0=CPU, 1=CUDA)
33+
* @param device_index Device index (must be 0 for current implementation)
34+
* @param ret_new_tensor Output parameter for the created tensor
35+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
36+
* failure)
37+
*/
38+
AOTITorchError aoti_torch_empty_strided(
39+
int64_t ndim,
40+
const int64_t* sizes_ptr,
41+
const int64_t* strides_ptr,
42+
int32_t dtype,
43+
int32_t device_type,
44+
int32_t device_index,
45+
Tensor** ret_new_tensor);
46+
47+
// Function to clear all tensors from internal storage
48+
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
49+
void clear_all_tensors();
50+
51+
} // extern "C"
52+
53+
} // namespace cuda
54+
} // namespace backends
55+
} // namespace executorch
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

# Test targets are declared in targets.bzl so they can be shared between
# fbcode and xplat.
define_common_targets()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
3+
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
4+
5+
def cuda_shim_cpp_unittest(name):
    """Declares a cpp_unittest `test_<name>` built from `test_<name>.cpp`.

    The test links against the CUDA runtime shims, their ExecuTorch runtime
    dependencies, and the lazily-loaded CUDA external dependency.
    """
    test_name = "test_" + name
    cpp_unittest(
        name = test_name,
        srcs = [test_name + ".cpp"],
        deps = [
            "//executorch/backends/aoti:common_shims",
            "//executorch/backends/cuda/runtime:runtime_shims",
            "//executorch/extension/tensor:tensor",
            "//executorch/runtime/core:core",
            "//executorch/runtime/platform:platform",
            "//executorch/runtime/core/exec_aten:lib",
        ],
        external_deps = [
            ("cuda", None, "cuda-lazy"),
        ],
    )
23+
24+
def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """
    # One unittest per CUDA AOTI shim under test.
    cuda_shim_cpp_unittest("aoti_torch_empty_strided")

0 commit comments

Comments
 (0)