diff --git a/backends/aoti/aoti_model_container.cpp b/backends/aoti/aoti_model_container.cpp index 46a246faeb8..08ba114ab11 100644 --- a/backends/aoti/aoti_model_container.cpp +++ b/backends/aoti/aoti_model_container.cpp @@ -24,6 +24,8 @@ AOTInductorModelContainerGetNumInputsFunc AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs = nullptr; AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr; +AOTInductorModelContainerUpdateUserManagedConstantBufferFunc + AOTInductorModelContainerUpdateUserManagedConstantBuffer = nullptr; // Additional global function pointers for AOT Inductor model container // operations needed by Metal backend diff --git a/backends/aoti/aoti_model_container.h b/backends/aoti/aoti_model_container.h index 877f019c457..1876b037694 100644 --- a/backends/aoti/aoti_model_container.h +++ b/backends/aoti/aoti_model_container.h @@ -11,6 +11,9 @@ #include #include +#include +#include + namespace executorch { namespace backends { namespace aoti { @@ -30,6 +33,11 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*; using AOTInductorStreamHandle = void*; using AOTIProxyExecutorHandle = void*; +// Constant map handle (opaque pointer to std::unordered_map*) +struct AOTInductorConstantMap; +using AOTInductorConstantMapHandle = AOTInductorConstantMap*; + // Function pointer types for AOT Inductor model container operations using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)( AOTInductorModelContainerHandle* container_handle, @@ -60,6 +68,13 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)( AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle); +using AOTInductorModelContainerUpdateUserManagedConstantBufferFunc = + AOTIRuntimeError (*)( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update); + // Global function pointers (will be loaded dynamically) extern AOTInductorModelContainerCreateWithDeviceFunc AOTInductorModelContainerCreateWithDevice; @@ -69,6 +84,8 @@ extern AOTInductorModelContainerGetNumInputsFunc extern AOTInductorModelContainerGetNumOutputsFunc AOTInductorModelContainerGetNumOutputs; extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun; +extern AOTInductorModelContainerUpdateUserManagedConstantBufferFunc + AOTInductorModelContainerUpdateUserManagedConstantBuffer; // Retrieves the name of an input tensor by index from the AOTI model container. // Needed by Metal backend @@ -99,6 +116,11 @@ struct AOTIDelegateHandle { AOTInductorModelContainerHandle container_handle; void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header // dependency + std::vector weight_fqns; // Fully qualified names of weights + std::vector> + weight_tensors; // Storage for weight tensors + std::vector + weight_buffers; // Storage for weight data - owns the actual data }; } // namespace aoti diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index ef98de29f23..04ce9d6d762 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -6,10 +6,11 @@ import contextlib import os +import struct import typing from enum import Enum -from typing import Any, Dict, final, List, Optional, Set +from typing import Any, Dict, final, List, Optional, Set, Tuple, Union import torch from executorch.backends.cuda.replace_slice_copy_with_slice import ( @@ -25,6 +26,8 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch.export.passes import move_to_device_pass + +from torch.export.pt2_archive._package_weights import TensorProperties from torch.nn.attention import SDPBackend # exist fallback operators in et namespace; @@ -38,6 +41,34 @@ class COMPILE_SPEC_KEYS(Enum): METHOD_NAME = "method_name" +def _extract_so_path_and_weight_dict( + file_paths_and_weights: List[ + Union[str, Dict[str, Tuple[torch.nn.Parameter, TensorProperties]]] + ] +): + so_path = None + weight_dict = {} + for item in file_paths_and_weights: + if isinstance(item, str) and item.endswith("wrapper.so"): + so_path = item + elif isinstance(item, dict): + weight_dict.update(item) + assert ( + so_path is not None + ), f"so_path is None, all the strings are: {[x for x in file_paths_and_weights if isinstance(x, str)]}" + assert len(weight_dict) > 0, f"No weight dict found in {file_paths_and_weights}" + return so_path, weight_dict + + +def _weight_fqn_list_to_bytes(weight_fqns: List[str]) -> bytes: + processed_bytes = bytearray() + processed_bytes.extend(struct.pack(" 0: formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) raise RuntimeError( f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" "Please add them to the AOTI backend." ) + assert isinstance( + file_paths_and_weights, list + ), f"Expected a list of file paths and weights, got type: {type(file_paths_and_weights)}" + so_path, weight_dict = _extract_so_path_and_weight_dict(file_paths_and_weights) # pyre-ignorep[6]: Incompatible parameter type with open(so_path, "rb") as f: @@ -169,12 +207,24 @@ def preprocess( method_name + "_so_blob", so_data, 1, "aoti_cuda_blob" ) + # Add weights to named data store + for name, weight_tuple in weight_dict.items(): + named_data_store.add_named_data( + name, + weight_tuple[0].cpu().numpy().tobytes(), + 1, + None, # Do not store it in .ptd + ) + + weight_fqns = sorted(weight_dict.keys()) + processed_bytes = _weight_fqn_list_to_bytes(weight_fqns) + # Clean up the generated so file; it has been packaged into the NamdeDataStore # pyre-ignorep[6]: Incompatible parameter type os.remove(so_path) return PreprocessResult( - processed_bytes=b"", + processed_bytes=bytes(processed_bytes), debug_handle_map={}, data_store_output=named_data_store.get_named_data_store_output(), ) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 805c54ff55c..968610c834f 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -12,12 +12,19 @@ #include #include #include +#include #include #include +#include +#include +#include #include #include +#include #include +#include +#include #include // Include our shim layer headers @@ -54,6 +61,62 @@ using executorch::runtime::Result; using executorch::runtime::Span; using executorch::runtime::etensor::Tensor; +namespace { + +Error parse_weight_fqns_from_processed( + const FreeableBuffer* processed, + std::vector& weight_fqns) { + if (processed == nullptr || processed->data() == nullptr || + processed->size() == 0) { + return Error::Ok; + } + + const auto* cursor = static_cast(processed->data()); + size_t remaining = processed->size(); + + auto read_uint32 = [&](uint32_t& value) -> bool { + if (remaining < sizeof(uint32_t)) { + return false; + } + std::memcpy(&value, cursor, sizeof(uint32_t)); + cursor += sizeof(uint32_t); + remaining -= sizeof(uint32_t); + return true; + }; + + uint32_t num_entries = 0; + ET_CHECK_OR_RETURN_ERROR( + read_uint32(num_entries), + InvalidArgument, + "Failed to read FQN count from processed bytes"); + + weight_fqns.reserve(num_entries); + for (uint32_t i = 0; i < num_entries; ++i) { + uint32_t length = 0; + ET_CHECK_OR_RETURN_ERROR( + read_uint32(length), + InvalidArgument, + "Failed to read FQN length from processed bytes") + + ET_CHECK_OR_RETURN_ERROR( + remaining >= length, + InvalidArgument, + "Processed bytes exhausted while reading FQN %u (remaining=%zu, length=%u)", + i, + remaining, + length); + + const char* str_begin = reinterpret_cast(cursor); + weight_fqns.emplace_back(str_begin, length); + cursor += length; + remaining -= length; + } + + return Error::Ok; +} + +} // namespace + class ET_EXPERIMENTAL CudaBackend final : public ::executorch::runtime::BackendInterface { private: @@ -63,6 +126,8 @@ class ET_EXPERIMENTAL CudaBackend final LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle); LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle); LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle); + LOAD_SYMBOL( + AOTInductorModelContainerUpdateUserManagedConstantBuffer, so_handle); return Error::Ok; } @@ -88,6 +153,15 @@ class ET_EXPERIMENTAL CudaBackend final } } + std::vector weight_fqns; + Error parse_err = parse_weight_fqns_from_processed(processed, weight_fqns); + if (parse_err != Error::Ok) { + if (processed != nullptr) { + processed->Free(); + } + return parse_err; + } + std::string so_blob_key = method_name.empty() ? "so_blob" : method_name + "_so_blob"; @@ -99,7 +173,6 @@ class ET_EXPERIMENTAL CudaBackend final "Failed to get data for key %s: 0x%x", so_blob_key.c_str(), static_cast(aoti_cuda_buffer.error())); - // Generate dynamic temporary file path filesystem::path temp_dir = filesystem::temp_directory_path(); filesystem::path so_path = @@ -149,11 +222,78 @@ class ET_EXPERIMENTAL CudaBackend final handle->so_handle = so_handle; handle->so_path = so_path.string(); handle->container_handle = container_handle; + handle->weight_fqns = weight_fqns; // Store weight FQNs in the handle + + // Create a constant map and populate it with weights from NamedDataMap + // Store the Tensor objects in the handle so they persist for the lifetime + // of the container + std::unordered_map constant_map; - // Create a CUDA stream for asynchronous execution - cudaStream_t cuda_stream; - ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream)); - handle->cuda_stream = static_cast(cuda_stream); + for (const auto& fqn : weight_fqns) { + // Get tensor layout (metadata) for this weight + auto tensor_layout_result = + named_data_map->get_tensor_layout(fqn.c_str()); + ET_CHECK_OR_RETURN_ERROR( + tensor_layout_result.ok(), + Internal, + "Failed to get tensor layout for key %s: 0x%x", + fqn.c_str(), + static_cast(tensor_layout_result.error())); + + auto weight_result = named_data_map->get_data(fqn.c_str()); + ET_CHECK_OR_RETURN_ERROR( + weight_result.ok(), + Internal, + "Failed to get data for key %s: 0x%x", + fqn.c_str(), + static_cast(weight_result.error())); + + // Store the FreeableBuffer to keep the weight data alive + // This is critical: the FreeableBuffer owns or references the actual + // weight data + FreeableBuffer weight_buffer = weight_result.get(); + void* weight_data = weight_buffer.data(); + + // Get tensor layout information + const TensorLayout& layout = tensor_layout_result.get(); + + // Create a Tensor from the weight data using the layout information + // The Tensor is created as a view over the data owned by the + // FreeableBuffer + auto weight_tensor = std::make_unique( + layout.scalar_type(), + layout.sizes().size(), + const_cast(layout.sizes().data()), + weight_data, + const_cast(layout.dim_order().data()), + const_cast(layout.strides().data())); + + constant_map[fqn] = weight_tensor.get(); + handle->weight_tensors.push_back(std::move(weight_tensor)); + handle->weight_buffers.push_back( + std::move(weight_buffer)); // Store buffer to keep data alive + } + + // Update the container with user-managed constant buffer + if (!constant_map.empty()) { + AOTIRuntimeError update_err = + AOTInductorModelContainerUpdateUserManagedConstantBuffer( + container_handle, + reinterpret_cast(&constant_map), + /*use_inactive=*/false, + /*validate_full_update=*/true); + + ET_CHECK_OR_RETURN_ERROR( + update_err == Error::Ok, + Internal, + "Failed to update constant buffer with error code %d", + update_err); + + ET_LOG( + Info, + "Successfully populated %zu weights into container", + constant_map.size()); + } return (DelegateHandle*)handle; // Return the handle post-processing }