Commit 609b8db
[NOTFORLAND] Tensor granularity data program separation
1 parent 5ed5097 commit 609b8db

File tree

4 files changed: +223 -9

backends/aoti/aoti_model_container.cpp
backends/aoti/aoti_model_container.h
backends/cuda/cuda_backend.py
backends/cuda/runtime/cuda_backend.cpp


backends/aoti/aoti_model_container.cpp

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ AOTInductorModelContainerGetNumInputsFunc
 AOTInductorModelContainerGetNumOutputsFunc
     AOTInductorModelContainerGetNumOutputs = nullptr;
 AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;
+AOTInductorModelContainerUpdateUserManagedConstantBufferFunc
+    AOTInductorModelContainerUpdateUserManagedConstantBuffer = nullptr;
 
 // Additional global function pointers for AOT Inductor model container
 // operations needed by Metal backend

backends/aoti/aoti_model_container.h

Lines changed: 22 additions & 0 deletions

@@ -11,6 +11,9 @@
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
 
+#include <string>
+#include <vector>
+
 namespace executorch {
 namespace backends {
 namespace aoti {
@@ -30,6 +33,11 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
 using AOTInductorStreamHandle = void*;
 using AOTIProxyExecutorHandle = void*;
 
+// Constant map handle (opaque pointer to std::unordered_map<std::string,
+// AtenTensorHandle>*)
+struct AOTInductorConstantMap;
+using AOTInductorConstantMapHandle = AOTInductorConstantMap*;
+
 // Function pointer types for AOT Inductor model container operations
 using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
     AOTInductorModelContainerHandle* container_handle,
@@ -60,6 +68,13 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
     AOTInductorStreamHandle stream_handle,
     AOTIProxyExecutorHandle proxy_executor_handle);
 
+using AOTInductorModelContainerUpdateUserManagedConstantBufferFunc =
+    AOTIRuntimeError (*)(
+        AOTInductorModelContainerHandle container_handle,
+        AOTInductorConstantMapHandle constant_map_handle,
+        bool use_inactive,
+        bool validate_full_update);
+
 // Global function pointers (will be loaded dynamically)
 extern AOTInductorModelContainerCreateWithDeviceFunc
     AOTInductorModelContainerCreateWithDevice;
@@ -69,6 +84,8 @@ extern AOTInductorModelContainerGetNumInputsFunc
 extern AOTInductorModelContainerGetNumOutputsFunc
     AOTInductorModelContainerGetNumOutputs;
 extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
+extern AOTInductorModelContainerUpdateUserManagedConstantBufferFunc
+    AOTInductorModelContainerUpdateUserManagedConstantBuffer;
 
 // Retrieves the name of an input tensor by index from the AOTI model container.
 // Needed by Metal backend
@@ -99,6 +116,11 @@ struct AOTIDelegateHandle {
   AOTInductorModelContainerHandle container_handle;
   void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
                      // dependency
+  std::vector<std::string> weight_fqns; // Fully qualified names of weights
+  std::vector<std::unique_ptr<etensor::Tensor>>
+      weight_tensors; // Storage for weight tensors
+  std::vector<executorch::runtime::FreeableBuffer>
+      weight_buffers; // Storage for weight data - owns the actual data
 };
 
 } // namespace aoti
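
The header above only declares typedefs and nullptr-initialized globals; the real function is resolved from the compiled shared object at load time (see LOAD_SYMBOL in cuda_backend.cpp below). For readers who want to poke at an AOTInductor .so outside the runtime, here is a minimal sketch of that same resolution step using Python's ctypes. The .so path is hypothetical, and treating AOTIRuntimeError as an integral error code is an assumption for illustration:

# Sketch: resolve the constant-buffer update entry point from an
# AOTInductor-compiled .so, mirroring the function-pointer typedef above.
# The .so path is hypothetical; AOTIRuntimeError is assumed to be an
# integral error code.
import ctypes

def load_update_fn(so_path: str):
    lib = ctypes.CDLL(so_path)  # same role as dlopen in the runtime
    fn = lib.AOTInductorModelContainerUpdateUserManagedConstantBuffer
    fn.restype = ctypes.c_int  # AOTIRuntimeError (assumed integral)
    fn.argtypes = [
        ctypes.c_void_p,  # AOTInductorModelContainerHandle
        ctypes.c_void_p,  # AOTInductorConstantMapHandle
        ctypes.c_bool,    # use_inactive
        ctypes.c_bool,    # validate_full_update
    ]
    return fn

# Usage (hypothetical path): fn = load_update_fn("/tmp/model_wrapper.so")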

backends/cuda/cuda_backend.py

Lines changed: 54 additions & 4 deletions

@@ -6,10 +6,11 @@
 
 import contextlib
 import os
+import struct
 import typing
 from enum import Enum
 
-from typing import Any, Dict, final, List, Optional, Set
+from typing import Any, Dict, final, List, Optional, Set, Tuple, Union
 
 import torch
 from executorch.backends.cuda.replace_slice_copy_with_slice import (
@@ -25,6 +26,8 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
+
+from torch.export.pt2_archive._package_weights import TensorProperties
 from torch.nn.attention import SDPBackend
 
 # exist fallback operators in et namespace;
@@ -38,6 +41,34 @@ class COMPILE_SPEC_KEYS(Enum):
     METHOD_NAME = "method_name"
 
 
+def _extract_so_path_and_weight_dict(
+    file_paths_and_weights: List[
+        Union[str, Dict[str, Tuple[torch.nn.Parameter, TensorProperties]]]
+    ]
+):
+    so_path = None
+    weight_dict = {}
+    for item in file_paths_and_weights:
+        if isinstance(item, str) and item.endswith("wrapper.so"):
+            so_path = item
+        elif isinstance(item, dict):
+            weight_dict.update(item)
+    assert (
+        so_path is not None
+    ), f"so_path is None, all the strings are: {[x for x in file_paths_and_weights if isinstance(x, str)]}"
+    assert len(weight_dict) > 0, f"No weight dict found in {file_paths_and_weights}"
+    return so_path, weight_dict
+
+
+def _weight_fqn_list_to_bytes(weight_fqns: List[str]) -> bytes:
+    processed_bytes = bytearray()
+    processed_bytes.extend(struct.pack("<I", len(weight_fqns)))
+    for fqn in weight_fqns:
+        encoded_fqn = fqn.encode("utf-8")
+        processed_bytes.extend(struct.pack("<I", len(encoded_fqn)))
+        processed_bytes.extend(encoded_fqn)
+    return bytes(processed_bytes)
+
+
 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
 @contextlib.contextmanager
@@ -136,7 +167,10 @@ def preprocess(
             # Do not link against the full PyTorch/libtorch library
             "aot_inductor.link_libtorch": False,
             # Package model constants and other generated files directly in the shared object (.so) file
-            "aot_inductor.package_constants_in_so": True,
+            # Keep weights out of the .so: package them on disk so they can be shipped separately
+            "aot_inductor.package": True,
+            "aot_inductor.package_constants_in_so": False,
+            "aot_inductor.package_constants_on_disk": True,
             # Enable maximum automatic tuning for optimal performance
             "max_autotune": True,
             # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
@@ -151,13 +185,17 @@
             ]
         ), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
-            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
+            file_paths_and_weights = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
             raise RuntimeError(
                 f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                 "Please add them to the AOTI backend."
             )
+        assert isinstance(
+            file_paths_and_weights, list
+        ), f"Expected a list of file paths and weights, got type: {type(file_paths_and_weights)}"
+        so_path, weight_dict = _extract_so_path_and_weight_dict(file_paths_and_weights)
 
         # pyre-ignorep[6]: Incompatible parameter type
        with open(so_path, "rb") as f:
@@ -169,12 +207,24 @@
             method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
         )
 
+        # Add weights to named data store
+        for name, weight_tuple in weight_dict.items():
+            named_data_store.add_named_data(
+                name,
+                weight_tuple[0].cpu().numpy().tobytes(),
+                1,
+                None,  # Do not store it in .ptd
+            )
+
+        weight_fqns = sorted(weight_dict.keys())
+        processed_bytes = _weight_fqn_list_to_bytes(weight_fqns)
+
         # Clean up the generated so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
         os.remove(so_path)
 
         return PreprocessResult(
-            processed_bytes=b"",
+            processed_bytes=bytes(processed_bytes),
             debug_handle_map={},
             data_store_output=named_data_store.get_named_data_store_output(),
         )
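
The processed_bytes blob built by _weight_fqn_list_to_bytes is a little-endian uint32 entry count followed by one (uint32 length, UTF-8 bytes) pair per FQN. A self-contained round-trip of that encoding, with made-up FQNs, to make the layout concrete:

# Standalone illustration of the processed-bytes layout emitted above:
#   uint32 LE count | (uint32 LE length, UTF-8 bytes) per FQN
# The FQNs are made up for the example.
import struct

def encode_fqns(fqns):
    out = bytearray(struct.pack("<I", len(fqns)))  # entry count
    for fqn in fqns:
        raw = fqn.encode("utf-8")
        out.extend(struct.pack("<I", len(raw)))  # length prefix
        out.extend(raw)
    return bytes(out)

blob = encode_fqns(["linear.weight", "linear.bias"])
assert blob[:4] == struct.pack("<I", 2)  # two entries
print(blob.hex(" "))  # 02 00 00 00 | 0d 00 00 00 | 6c 69 6e ...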

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 145 additions & 5 deletions

@@ -12,12 +12,19 @@
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/core/tensor_layout.h>
 #include <unistd.h>
 #include <cstdio>
+#include <memory>
 
+#include <cstdint>
+#include <cstring>
 #include <filesystem>
 #include <fstream>
+#include <iostream>
 #include <string>
+#include <system_error>
+#include <unordered_map>
 #include <vector>
 
 // Include our shim layer headers
@@ -54,6 +61,62 @@ using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;
 
+namespace {
+
+Error parse_weight_fqns_from_processed(
+    const FreeableBuffer* processed,
+    std::vector<std::string>& weight_fqns) {
+  if (processed == nullptr || processed->data() == nullptr ||
+      processed->size() == 0) {
+    return Error::Ok;
+  }
+
+  const auto* cursor = static_cast<const uint8_t*>(processed->data());
+  size_t remaining = processed->size();
+
+  auto read_uint32 = [&](uint32_t& value) -> bool {
+    if (remaining < sizeof(uint32_t)) {
+      return false;
+    }
+    std::memcpy(&value, cursor, sizeof(uint32_t));
+    cursor += sizeof(uint32_t);
+    remaining -= sizeof(uint32_t);
+    return true;
+  };
+
+  uint32_t num_entries = 0;
+  ET_CHECK_OR_RETURN_ERROR(
+      read_uint32(num_entries),
+      InvalidArgument,
+      "Failed to read FQN count from processed bytes");
+
+  weight_fqns.reserve(num_entries);
+  for (uint32_t i = 0; i < num_entries; ++i) {
+    uint32_t length = 0;
+    ET_CHECK_OR_RETURN_ERROR(
+        read_uint32(length),
+        InvalidArgument,
+        "Failed to read FQN length from processed bytes");
+
+    ET_CHECK_OR_RETURN_ERROR(
+        remaining >= length,
+        InvalidArgument,
+        "Processed bytes exhausted while reading FQN %u (remaining=%zu, length=%u)",
+        i,
+        remaining,
+        length);
+
+    const char* str_begin = reinterpret_cast<const char*>(cursor);
+    weight_fqns.emplace_back(str_begin, length);
+    cursor += length;
+    remaining -= length;
+  }
+
+  return Error::Ok;
+}
+
+} // namespace
+
 class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
@@ -63,6 +126,8 @@ class ET_EXPERIMENTAL CudaBackend final
     LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
     LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
     LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
+    LOAD_SYMBOL(
+        AOTInductorModelContainerUpdateUserManagedConstantBuffer, so_handle);
 
     return Error::Ok;
   }
@@ -88,6 +153,15 @@ class ET_EXPERIMENTAL CudaBackend final
       }
     }
 
+    std::vector<std::string> weight_fqns;
+    Error parse_err = parse_weight_fqns_from_processed(processed, weight_fqns);
+    if (parse_err != Error::Ok) {
+      if (processed != nullptr) {
+        processed->Free();
+      }
+      return parse_err;
+    }
+
     std::string so_blob_key =
         method_name.empty() ? "so_blob" : method_name + "_so_blob";
 
@@ -99,7 +173,6 @@ class ET_EXPERIMENTAL CudaBackend final
         "Failed to get data for key %s: 0x%x",
         so_blob_key.c_str(),
         static_cast<uint32_t>(aoti_cuda_buffer.error()));
-
     // Generate dynamic temporary file path
     filesystem::path temp_dir = filesystem::temp_directory_path();
     filesystem::path so_path =
@@ -149,11 +222,78 @@ class ET_EXPERIMENTAL CudaBackend final
     handle->so_handle = so_handle;
     handle->so_path = so_path.string();
    handle->container_handle = container_handle;
+    handle->weight_fqns = weight_fqns; // Store weight FQNs in the handle
+
+    // Create a constant map and populate it with weights from the
+    // NamedDataMap. Store the Tensor objects in the handle so they persist
+    // for the lifetime of the container.
+    std::unordered_map<std::string, Tensor*> constant_map;
 
-    // Create a CUDA stream for asynchronous execution
-    cudaStream_t cuda_stream;
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
-    handle->cuda_stream = static_cast<void*>(cuda_stream);
+    for (const auto& fqn : weight_fqns) {
+      // Get tensor layout (metadata) for this weight
+      auto tensor_layout_result =
+          named_data_map->get_tensor_layout(fqn.c_str());
+      ET_CHECK_OR_RETURN_ERROR(
+          tensor_layout_result.ok(),
+          Internal,
+          "Failed to get tensor layout for key %s: 0x%x",
+          fqn.c_str(),
+          static_cast<uint32_t>(tensor_layout_result.error()));
+
+      auto weight_result = named_data_map->get_data(fqn.c_str());
+      ET_CHECK_OR_RETURN_ERROR(
+          weight_result.ok(),
+          Internal,
+          "Failed to get data for key %s: 0x%x",
+          fqn.c_str(),
+          static_cast<uint32_t>(weight_result.error()));
+
+      // Store the FreeableBuffer to keep the weight data alive.
+      // This is critical: the FreeableBuffer owns or references the actual
+      // weight data.
+      FreeableBuffer weight_buffer = weight_result.get();
+      void* weight_data = weight_buffer.data();
+
+      // Get tensor layout information
+      const TensorLayout& layout = tensor_layout_result.get();
+
+      // Create a Tensor from the weight data using the layout information.
+      // The Tensor is created as a view over the data owned by the
+      // FreeableBuffer.
+      auto weight_tensor = std::make_unique<Tensor>(
+          layout.scalar_type(),
+          layout.sizes().size(),
+          const_cast<Tensor::SizesType*>(layout.sizes().data()),
+          weight_data,
+          const_cast<Tensor::DimOrderType*>(layout.dim_order().data()),
+          const_cast<Tensor::StridesType*>(layout.strides().data()));
+
+      constant_map[fqn] = weight_tensor.get();
+      handle->weight_tensors.push_back(std::move(weight_tensor));
+      handle->weight_buffers.push_back(
+          std::move(weight_buffer)); // Store buffer to keep data alive
+    }
+
+    // Update the container with the user-managed constant buffer
+    if (!constant_map.empty()) {
+      AOTIRuntimeError update_err =
+          AOTInductorModelContainerUpdateUserManagedConstantBuffer(
+              container_handle,
+              reinterpret_cast<AOTInductorConstantMapHandle>(&constant_map),
+              /*use_inactive=*/false,
+              /*validate_full_update=*/true);
+
+      ET_CHECK_OR_RETURN_ERROR(
+          update_err == Error::Ok,
+          Internal,
+          "Failed to update constant buffer with error code %d",
+          update_err);
+
+      ET_LOG(
+          Info,
+          "Successfully populated %zu weights into container",
+          constant_map.size());
+    }
 
     return (DelegateHandle*)handle; // Return the handle post-processing
   }
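
For completeness, here is a Python reference decoder that mirrors parse_weight_fqns_from_processed above — the same empty-input short-circuit and the same two bounds checks — which can be handy for inspecting a delegate's processed bytes offline. The decode_fqns helper is a name introduced here for illustration:

# Reference decoder mirroring parse_weight_fqns_from_processed above;
# for offline inspection only (the helper name is made up here).
import struct
from typing import List

def decode_fqns(blob: bytes) -> List[str]:
    if not blob:
        return []  # empty processed bytes -> no weights, as in the C++ path
    off = 0

    def read_u32() -> int:
        nonlocal off
        if len(blob) - off < 4:
            raise ValueError("Failed to read uint32 from processed bytes")
        (value,) = struct.unpack_from("<I", blob, off)
        off += 4
        return value

    fqns = []
    for i in range(read_u32()):
        length = read_u32()
        if len(blob) - off < length:
            raise ValueError(f"Processed bytes exhausted while reading FQN {i}")
        fqns.append(blob[off : off + length].decode("utf-8"))
        off += length
    return fqns

assert decode_fqns(struct.pack("<II", 1, 3) + b"abc") == ["abc"]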
