1212#include < executorch/runtime/core/error.h>
1313#include < executorch/runtime/core/evalue.h>
1414#include < executorch/runtime/core/exec_aten/util/tensor_util.h>
15+ #include < executorch/runtime/core/tensor_layout.h>
1516#include < unistd.h>
1617#include < cstdio>
18+ #include < memory>
1719
20+ #include < cstdint>
21+ #include < cstring>
1822#include < filesystem>
1923#include < fstream>
24+ #include < iostream>
2025#include < string>
26+ #include < system_error>
27+ #include < unordered_map>
2128#include < vector>
2229
2330// Include our shim layer headers
@@ -54,6 +61,62 @@ using executorch::runtime::Result;
5461using executorch::runtime::Span;
5562using executorch::runtime::etensor::Tensor;
5663
64+ namespace {
65+
66+ Error parse_weight_fqns_from_processed (
67+ const FreeableBuffer* processed,
68+ std::vector<std::string>& weight_fqns) {
69+ if (processed == nullptr || processed->data () == nullptr ||
70+ processed->size () == 0 ) {
71+ return Error::Ok;
72+ }
73+
74+ const auto * cursor = static_cast <const uint8_t *>(processed->data ());
75+ size_t remaining = processed->size ();
76+
77+ auto read_uint32 = [&](uint32_t & value) -> bool {
78+ if (remaining < sizeof (uint32_t )) {
79+ return false ;
80+ }
81+ std::memcpy (&value, cursor, sizeof (uint32_t ));
82+ cursor += sizeof (uint32_t );
83+ remaining -= sizeof (uint32_t );
84+ return true ;
85+ };
86+
87+ uint32_t num_entries = 0 ;
88+ ET_CHECK_OR_RETURN_ERROR (
89+ read_uint32 (num_entries),
90+ InvalidArgument,
91+ " Failed to read FQN count from processed bytes" );
92+
93+ weight_fqns.reserve (num_entries);
94+ for (uint32_t i = 0 ; i < num_entries; ++i) {
95+ uint32_t length = 0 ;
96+ ET_CHECK_OR_RETURN_ERROR (
97+ read_uint32 (length),
98+ InvalidArgument,
99+ " Failed to read FQN length from processed bytes" )
100+
101+ ET_CHECK_OR_RETURN_ERROR (
102+ remaining >= length,
103+ InvalidArgument,
104+ " Processed bytes exhausted while reading FQN %u (remaining=%zu, length=%u)" ,
105+ i,
106+ remaining,
107+ length);
108+
109+ const char * str_begin = reinterpret_cast <const char *>(cursor);
110+ weight_fqns.emplace_back (str_begin, length);
111+ cursor += length;
112+ remaining -= length;
113+ }
114+
115+ return Error::Ok;
116+ }
117+
118+ } // namespace
119+
57120class ET_EXPERIMENTAL CudaBackend final
58121 : public ::executorch::runtime::BackendInterface {
59122 private:
@@ -63,6 +126,8 @@ class ET_EXPERIMENTAL CudaBackend final
63126 LOAD_SYMBOL (AOTInductorModelContainerGetNumInputs, so_handle);
64127 LOAD_SYMBOL (AOTInductorModelContainerGetNumOutputs, so_handle);
65128 LOAD_SYMBOL (AOTInductorModelContainerRun, so_handle);
129+ LOAD_SYMBOL (
130+ AOTInductorModelContainerUpdateUserManagedConstantBuffer, so_handle);
66131
67132 return Error::Ok;
68133 }
@@ -88,6 +153,15 @@ class ET_EXPERIMENTAL CudaBackend final
88153 }
89154 }
90155
156+ std::vector<std::string> weight_fqns;
157+ Error parse_err = parse_weight_fqns_from_processed (processed, weight_fqns);
158+ if (parse_err != Error::Ok) {
159+ if (processed != nullptr ) {
160+ processed->Free ();
161+ }
162+ return parse_err;
163+ }
164+
91165 std::string so_blob_key =
92166 method_name.empty () ? " so_blob" : method_name + " _so_blob" ;
93167
@@ -99,7 +173,6 @@ class ET_EXPERIMENTAL CudaBackend final
99173 " Failed to get data for key %s: 0x%x" ,
100174 so_blob_key.c_str (),
101175 static_cast <uint32_t >(aoti_cuda_buffer.error ()));
102-
103176 // Generate dynamic temporary file path
104177 filesystem::path temp_dir = filesystem::temp_directory_path ();
105178 filesystem::path so_path =
@@ -149,11 +222,78 @@ class ET_EXPERIMENTAL CudaBackend final
149222 handle->so_handle = so_handle;
150223 handle->so_path = so_path.string ();
151224 handle->container_handle = container_handle;
225+ handle->weight_fqns = weight_fqns; // Store weight FQNs in the handle
226+
227+ // Create a constant map and populate it with weights from NamedDataMap
228+ // Store the Tensor objects in the handle so they persist for the lifetime
229+ // of the container
230+ std::unordered_map<std::string, Tensor*> constant_map;
152231
153- // Create a CUDA stream for asynchronous execution
154- cudaStream_t cuda_stream;
155- ET_CUDA_CHECK_OR_RETURN_ERROR (cudaStreamCreate (&cuda_stream));
156- handle->cuda_stream = static_cast <void *>(cuda_stream);
232+ for (const auto & fqn : weight_fqns) {
233+ // Get tensor layout (metadata) for this weight
234+ auto tensor_layout_result =
235+ named_data_map->get_tensor_layout (fqn.c_str ());
236+ ET_CHECK_OR_RETURN_ERROR (
237+ tensor_layout_result.ok (),
238+ Internal,
239+ " Failed to get tensor layout for key %s: 0x%x" ,
240+ fqn.c_str (),
241+ static_cast <uint32_t >(tensor_layout_result.error ()));
242+
243+ auto weight_result = named_data_map->get_data (fqn.c_str ());
244+ ET_CHECK_OR_RETURN_ERROR (
245+ weight_result.ok (),
246+ Internal,
247+ " Failed to get data for key %s: 0x%x" ,
248+ fqn.c_str (),
249+ static_cast <uint32_t >(weight_result.error ()));
250+
251+ // Store the FreeableBuffer to keep the weight data alive
252+ // This is critical: the FreeableBuffer owns or references the actual
253+ // weight data
254+ FreeableBuffer weight_buffer = weight_result.get ();
255+ void * weight_data = weight_buffer.data ();
256+
257+ // Get tensor layout information
258+ const TensorLayout& layout = tensor_layout_result.get ();
259+
260+ // Create a Tensor from the weight data using the layout information
261+ // The Tensor is created as a view over the data owned by the
262+ // FreeableBuffer
263+ auto weight_tensor = std::make_unique<Tensor>(
264+ layout.scalar_type (),
265+ layout.sizes ().size (),
266+ const_cast <Tensor::SizesType*>(layout.sizes ().data ()),
267+ weight_data,
268+ const_cast <Tensor::DimOrderType*>(layout.dim_order ().data ()),
269+ const_cast <Tensor::StridesType*>(layout.strides ().data ()));
270+
271+ constant_map[fqn] = weight_tensor.get ();
272+ handle->weight_tensors .push_back (std::move (weight_tensor));
273+ handle->weight_buffers .push_back (
274+ std::move (weight_buffer)); // Store buffer to keep data alive
275+ }
276+
277+ // Update the container with user-managed constant buffer
278+ if (!constant_map.empty ()) {
279+ AOTIRuntimeError update_err =
280+ AOTInductorModelContainerUpdateUserManagedConstantBuffer (
281+ container_handle,
282+ reinterpret_cast <AOTInductorConstantMapHandle>(&constant_map),
283+ /* use_inactive=*/ false ,
284+ /* validate_full_update=*/ true );
285+
286+ ET_CHECK_OR_RETURN_ERROR (
287+ update_err == Error::Ok,
288+ Internal,
289+ " Failed to update constant buffer with error code %d" ,
290+ update_err);
291+
292+ ET_LOG (
293+ Info,
294+ " Successfully populated %zu weights into container" ,
295+ constant_map.size ());
296+ }
157297
158298 return (DelegateHandle*)handle; // Return the handle post-processing
159299 }
0 commit comments