Skip to content

Commit 91d322e

Browse files
committed
Update on "gemma3 e2e runner on cuda"
This diff introduces e2e runner for gemma3 model on cuda delegating using AOTI library, which is guarded by CI. Also other necessary infrastructure updates for building and running the `gemma3 e2e runner` on CUDA devices. Differential Revision: [D85087532](https://our.internmc.facebook.com/intern/diff/D85087532/) [ghstack-poisoned]
2 parents 73f19aa + a9ac599 commit 91d322e

File tree

8 files changed

+244
-45
lines changed

8 files changed

+244
-45
lines changed

backends/cuda/CMakeLists.txt

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,40 @@ find_package(CUDAToolkit REQUIRED)
3434
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
3535
find_package_torch()
3636

37+
# CUDA tensor maker for backends that support incontiguous tensors
38+
set(_tensor_maker_sources runtime/tensor/tensor_maker.cpp)
39+
add_library(cuda_tensor_maker STATIC ${_tensor_maker_sources})
40+
target_include_directories(
41+
cuda_tensor_maker
42+
PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
43+
$<INSTALL_INTERFACE:include>
44+
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
45+
)
46+
target_compile_options(
47+
cuda_tensor_maker
48+
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
49+
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
50+
)
51+
# Ensure symbols are exported properly
52+
if(APPLE)
53+
target_link_options(cuda_tensor_maker PUBLIC -Wl,-export_dynamic)
54+
else()
55+
target_link_options(
56+
cuda_tensor_maker PUBLIC
57+
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
58+
)
59+
endif()
60+
61+
# Link against ExecuTorch core libraries
62+
target_link_libraries(cuda_tensor_maker PUBLIC executorch ${CMAKE_DL_LIBS})
63+
executorch_target_link_options_shared_lib(cuda_tensor_maker)
64+
65+
install(
66+
TARGETS cuda_tensor_maker
67+
EXPORT ExecuTorchTargets
68+
DESTINATION lib
69+
)
70+
3771
# CUDA-specific AOTI functionality
3872
set(_aoti_cuda_sources
3973
runtime/cuda_backend.cpp
@@ -62,9 +96,9 @@ target_link_options(
6296
aoti_cuda PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
6397
)
6498

65-
# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
99+
# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and PyTorch CUDA libraries
66100
target_link_libraries(
67-
aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
101+
aoti_cuda PUBLIC aoti_common cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
68102
)
69103
# If you need other CUDA libraries, link them similarly:
70104
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)

backends/cuda/runtime/TARGETS

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@ runtime.cxx_library(
2727
],
2828
)
2929

30+
runtime.cxx_library(
31+
name = "tensor_maker",
32+
srcs = [
33+
"tensor/tensor_maker.cpp",
34+
],
35+
headers = [
36+
"tensor/tensor_maker.h",
37+
],
38+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
39+
link_whole = True,
40+
supports_python_dlopen = True,
41+
visibility = ["@EXECUTORCH_CLIENTS"],
42+
deps = [
43+
"//executorch/runtime/core:core",
44+
"//executorch/runtime/core/exec_aten:lib",
45+
"//executorch/runtime/core/exec_aten/util:tensor_util",
46+
],
47+
)
48+
3049
runtime.cxx_library(
3150
name = "runtime_shims",
3251
srcs = [
@@ -52,8 +71,8 @@ runtime.cxx_library(
5271
compiler_flags = ["-Wno-global-constructors"],
5372
visibility = ["@EXECUTORCH_CLIENTS"],
5473
deps = [
74+
":tensor_maker",
5575
"//executorch/backends/aoti:common_shims",
56-
"//executorch/extension/tensor:tensor",
5776
"//executorch/runtime/core:core",
5877
"//executorch/runtime/core/exec_aten:lib",
5978
"//executorch/runtime/platform:platform",

backends/cuda/runtime/shims/memory.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/cuda/runtime/platform/platform.h>
1212
#include <executorch/backends/cuda/runtime/shims/memory.h>
1313
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
14+
#include <executorch/backends/cuda/runtime/tensor/tensor_maker.h>
1415
#include <executorch/backends/cuda/runtime/utils.h>
1516
#include <executorch/runtime/platform/log.h>
1617
#include <cstdint>
@@ -163,9 +164,11 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
163164

164165
// Create ExecutorTorch tensor that wraps the existing memory
165166
// Note: We're NOT copying the data, just wrapping it
166-
auto tensor = executorch::extension::from_blob(
167-
data, // existing memory (don't copy!)
167+
// Using CUDA-specific tensor maker that supports incontiguous tensors
168+
auto tensor = executorch::backends::cuda::make_tensor(
168169
sizes, // tensor dimensions
170+
data, // existing memory (don't copy!)
171+
{}, // dim_order (empty, will be auto-generated)
169172
strides, // tensor strides (allows different strides)
170173
dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
171174
);
@@ -268,8 +271,13 @@ AOTITorchError aoti_torch_empty_strided(
268271
auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
269272

270273
// ETensor creation with dynamic shape support for edge cases
271-
auto tensor = executorch::extension::from_blob(
272-
ptr, sizes, strides, dtype_to_scalar_type(dtype));
274+
// Using CUDA-specific tensor maker that supports incontiguous tensors
275+
auto tensor = executorch::backends::cuda::make_tensor(
276+
sizes,
277+
ptr,
278+
{}, // dim_order (empty, will be auto-generated)
279+
strides,
280+
dtype_to_scalar_type(dtype));
273281

274282
// Store the tensor so it doesn't get destroyed
275283
tensors.insert(tensor);
@@ -647,9 +655,11 @@ AOTITorchError aoti_torch__reinterpret_tensor(
647655

648656
// Create new tensor view that reinterprets the same memory with different
649657
// shape/strides This creates a view, not a copy - the data pointer is shared
650-
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
651-
data_ptr, // Reuse the same memory from source tensor
658+
// Using CUDA-specific tensor maker that supports incontiguous tensors
659+
std::shared_ptr<Tensor> tensor = executorch::backends::cuda::make_tensor(
652660
sizes, // New sizes with explicit SizesType
661+
data_ptr, // Reuse the same memory from source tensor
662+
{}, // dim_order (empty, will be auto-generated)
653663
strides, // New strides with explicit StridesType
654664
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
655665
);

backends/cuda/runtime/shims/tensor_attribute.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#pragma once
1010

11-
#include <executorch/extension/tensor/tensor.h>
1211
#include <executorch/runtime/core/error.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1313
#include <cstdint>
1414

1515
namespace executorch::backends::cuda {

backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -278,30 +278,6 @@ TEST_F(AOTITorchEmptyStridedTest, LargeTensor) {
278278
EXPECT_EQ(tensor->size(2), 50);
279279
}
280280

281-
// Test error handling with memory allocation failures
282-
TEST_F(AOTITorchEmptyStridedTest, MemoryAllocationStress) {
283-
// Try to create a very large tensor that might cause allocation failure
284-
// (This test may pass or fail depending on available memory)
285-
std::vector<int64_t> huge_sizes = {10000, 10000, 100}; // ~38GB for float32
286-
Tensor* tensor;
287-
288-
AOTITorchError error = aoti_torch_empty_strided(
289-
huge_sizes.size(),
290-
huge_sizes.data(),
291-
nullptr,
292-
6, // float32
293-
1, // CUDA device
294-
0, // device index
295-
&tensor);
296-
297-
// Either succeed or fail with memory allocation error
298-
if (error == Error::Ok) {
299-
EXPECT_NE(tensor, nullptr);
300-
} else {
301-
EXPECT_EQ(error, Error::MemoryAllocationFailed);
302-
}
303-
}
304-
305281
// Test aoti_torch_empty_strided with bfloat16 dtype
306282
TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) {
307283
// Test creating bfloat16 tensor on CUDA
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#include <executorch/backends/cuda/runtime/tensor/tensor_maker.h>
4+
5+
#include <numeric>
6+
7+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
8+
9+
namespace executorch::backends::cuda {
10+
11+
namespace {
12+
#ifndef USE_ATEN_LIB
13+
/**
14+
* A structure that consolidates the metadata (sizes, dim_order, strides) and
15+
* the data buffer associated with a Tensor. Since Tensor does not own
16+
* the memory for these metadata arrays or the data itself, this structure
17+
* ensures that they are managed together and have the same lifetime as the
18+
* Tensor. When the Tensor is destroyed, the Storage structure ensures
19+
* proper cleanup of the associated metadata and data if needed.
20+
*/
21+
struct Storage final {
22+
executorch::aten::TensorImpl tensor_impl;
23+
executorch::aten::Tensor tensor;
24+
std::vector<executorch::aten::SizesType> sizes;
25+
std::vector<executorch::aten::DimOrderType> dim_order;
26+
std::vector<executorch::aten::StridesType> strides;
27+
std::function<void(void*)> deleter;
28+
29+
Storage(
30+
executorch::aten::TensorImpl&& tensor_impl,
31+
std::vector<executorch::aten::SizesType>&& sizes,
32+
std::vector<executorch::aten::DimOrderType>&& dim_order,
33+
std::vector<executorch::aten::StridesType>&& strides,
34+
std::function<void(void*)>&& deleter)
35+
: tensor_impl(std::move(tensor_impl)),
36+
tensor(&this->tensor_impl),
37+
sizes(std::move(sizes)),
38+
dim_order(std::move(dim_order)),
39+
strides(std::move(strides)),
40+
deleter(std::move(deleter)) {}
41+
42+
~Storage() {
43+
if (deleter) {
44+
deleter(tensor_impl.mutable_data());
45+
}
46+
}
47+
};
48+
#endif // USE_ATEN_LIB
49+
} // namespace
50+
51+
TensorPtr make_tensor(
52+
std::vector<executorch::aten::SizesType> sizes,
53+
void* data,
54+
std::vector<executorch::aten::DimOrderType> dim_order,
55+
std::vector<executorch::aten::StridesType> strides,
56+
executorch::aten::ScalarType type,
57+
executorch::aten::TensorShapeDynamism dynamism,
58+
std::function<void(void*)> deleter) {
59+
const auto dim = sizes.size();
60+
ET_CHECK_MSG(
61+
dim_order.empty() || dim_order.size() == dim,
62+
"dim_order size must match sizes or be empty.");
63+
ET_CHECK_MSG(
64+
strides.empty() || strides.size() == dim,
65+
"strides size must match sizes or be empty.");
66+
67+
if (dim_order.empty()) {
68+
dim_order.resize(dim);
69+
std::iota(dim_order.begin(), dim_order.end(), 0);
70+
if (!strides.empty()) {
71+
std::sort(dim_order.begin(), dim_order.end(), [&](size_t a, size_t b) {
72+
return strides[a] > strides[b];
73+
});
74+
}
75+
}
76+
77+
// AOTI backends (like AOTI-CUDA) handle both contiguous and incontiguous
78+
// tensors, so we skip stride calculation and incontiguous tensor checks.
79+
// Strides are passed through as-is without validation.
80+
81+
#ifndef USE_ATEN_LIB
82+
executorch::aten::TensorImpl tensor_impl(
83+
type,
84+
dim,
85+
sizes.data(),
86+
data,
87+
dim_order.data(),
88+
strides.data(),
89+
dim > 0 ? dynamism : executorch::aten::TensorShapeDynamism::STATIC);
90+
auto storage = std::make_shared<Storage>(
91+
std::move(tensor_impl),
92+
std::move(sizes),
93+
std::move(dim_order),
94+
std::move(strides),
95+
std::move(deleter));
96+
const auto tensor_ptr = &storage->tensor;
97+
return std::shared_ptr<executorch::aten::Tensor>(
98+
std::move(storage), tensor_ptr);
99+
#else
100+
auto options = c10::TensorOptions()
101+
.dtype(c10::scalarTypeToTypeMeta(type))
102+
.device(c10::kCPU);
103+
auto storage = c10::Storage(
104+
c10::Storage::use_byte_size_t(),
105+
at::detail::computeStorageNbytes(
106+
sizes, strides, options.dtype().itemsize()),
107+
c10::InefficientStdFunctionContext::makeDataPtr(
108+
data, std::move(deleter), options.device()),
109+
nullptr,
110+
false);
111+
auto tensor_impl = c10::make_intrusive<executorch::aten::TensorImpl>(
112+
std::move(storage),
113+
c10::DispatchKeySet(c10::DispatchKey::CPU),
114+
options.dtype());
115+
tensor_impl->set_sizes_and_strides(sizes, strides);
116+
return std::make_shared<executorch::aten::Tensor>(std::move(tensor_impl));
117+
#endif // USE_ATEN_LIB
118+
}
119+
120+
} // namespace executorch::backends::cuda
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#pragma once
4+
5+
#include <functional>
6+
#include <memory>
7+
#include <vector>
8+
9+
#include <executorch/runtime/core/error.h>
10+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
12+
namespace executorch::backends::cuda {
13+
14+
/**
15+
* A smart pointer type for managing the lifecycle of a Tensor.
16+
* This is compatible with executorch::extension::TensorPtr.
17+
*/
18+
using TensorPtr = std::shared_ptr<executorch::aten::Tensor>;
19+
20+
/**
21+
* Creates a TensorPtr for AOTI backends that skips stride calculation and
22+
* incontiguous tensor checks. This is specifically designed for AOTI-CUDA
23+
* which handles both contiguous and incontiguous tensors.
24+
*
25+
* This function is similar to executorch::extension::make_tensor_ptr but
26+
* bypasses the stride validation that assumes contiguous tensors, making it
27+
* suitable for AOTI backends that support arbitrary strides.
28+
*
29+
* @param sizes A vector specifying the size of each dimension.
30+
* @param data A pointer to the data buffer.
31+
* @param dim_order A vector specifying the order of dimensions.
32+
* @param strides A vector specifying the strides of the tensor.
33+
* @param type The scalar type of the tensor elements.
34+
* @param dynamism Specifies the mutability of the tensor's shape.
35+
* @param deleter A custom deleter function for managing the lifetime of the
36+
* data buffer. If provided, this deleter will be called when the managed Tensor
37+
* object is destroyed.
38+
* @return A TensorPtr that manages the newly created Tensor.
39+
*/
40+
TensorPtr make_tensor(
41+
std::vector<executorch::aten::SizesType> sizes,
42+
void* data,
43+
std::vector<executorch::aten::DimOrderType> dim_order,
44+
std::vector<executorch::aten::StridesType> strides,
45+
executorch::aten::ScalarType type = executorch::aten::ScalarType::Float,
46+
executorch::aten::TensorShapeDynamism dynamism =
47+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND,
48+
std::function<void(void*)> deleter = nullptr);
49+
50+
} // namespace executorch::backends::cuda

extension/tensor/tensor_ptr.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,7 @@ TensorPtr make_tensor_ptr(
8383
// Skip stride calculation and incontiguous tensor check for CUDA backend since
8484
// AOTI-CUDA handles both contiguous and incontiguous tensors. This will be
8585
// removed after SlimTensor migration.
86-
#ifdef USE_CUDA_BACKEND
87-
if (strides.empty()) {
88-
std::vector<executorch::aten::StridesType> computed_strides(dim);
89-
90-
auto error = runtime::dim_order_to_stride(
91-
sizes.data(), dim_order.data(), dim, computed_strides.data());
92-
ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides.");
93-
94-
strides = std::move(computed_strides);
95-
}
96-
#else
86+
#ifndef USE_CUDA_BACKEND
9787
std::vector<executorch::aten::StridesType> computed_strides(dim);
9888

9989
auto error = runtime::dim_order_to_stride(

0 commit comments

Comments
 (0)