2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.cpp
@@ -24,6 +24,8 @@ AOTInductorModelContainerGetNumInputsFunc
AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs = nullptr;
AOTInductorModelContainerRunFunc AOTInductorModelContainerRun = nullptr;
AOTInductorModelContainerUpdateUserManagedConstantBufferFunc
AOTInductorModelContainerUpdateUserManagedConstantBuffer = nullptr;

// Additional global function pointers for AOT Inductor model container
// operations needed by Metal backend
22 changes: 22 additions & 0 deletions backends/aoti/aoti_model_container.h
@@ -11,6 +11,9 @@
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/freeable_buffer.h>

#include <memory>
#include <string>
#include <vector>

namespace executorch {
namespace backends {
namespace aoti {
@@ -30,6 +33,11 @@ using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
using AOTInductorStreamHandle = void*;
using AOTIProxyExecutorHandle = void*;

// Constant map handle (opaque pointer to std::unordered_map<std::string,
// AtenTensorHandle>*)
struct AOTInductorConstantMap;
using AOTInductorConstantMapHandle = AOTInductorConstantMap*;

// Function pointer types for AOT Inductor model container operations
using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle* container_handle,
@@ -60,6 +68,13 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);

using AOTInductorModelContainerUpdateUserManagedConstantBufferFunc =
AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
AOTInductorConstantMapHandle constant_map_handle,
bool use_inactive,
bool validate_full_update);

// Global function pointers (will be loaded dynamically)
extern AOTInductorModelContainerCreateWithDeviceFunc
AOTInductorModelContainerCreateWithDevice;
@@ -69,6 +84,8 @@ extern AOTInductorModelContainerGetNumInputsFunc
extern AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
extern AOTInductorModelContainerUpdateUserManagedConstantBufferFunc
AOTInductorModelContainerUpdateUserManagedConstantBuffer;

// Retrieves the name of an input tensor by index from the AOTI model container.
// Needed by Metal backend
@@ -99,6 +116,11 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
std::vector<std::string> weight_fqns; // Fully qualified names of weights
std::vector<std::unique_ptr<etensor::Tensor>>
weight_tensors; // Storage for weight tensors
std::vector<executorch::runtime::FreeableBuffer>
weight_buffers; // Storage for weight data - owns the actual data
};

} // namespace aoti
58 changes: 54 additions & 4 deletions backends/cuda/cuda_backend.py
@@ -6,10 +6,11 @@

import contextlib
import os
import struct
import typing
from enum import Enum

from typing import Any, Dict, final, List, Optional, Set
from typing import Any, Dict, final, List, Optional, Set, Tuple, Union

import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import (
@@ -25,6 +26,8 @@
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass

from torch.export.pt2_archive._package_weights import TensorProperties
from torch.nn.attention import SDPBackend

# exist fallback operators in et namespace;
@@ -38,6 +41,34 @@ class COMPILE_SPEC_KEYS(Enum):
METHOD_NAME = "method_name"


def _extract_so_path_and_weight_dict(
file_paths_and_weights: List[
Union[str, Dict[str, Tuple[torch.nn.Parameter, TensorProperties]]]
]
):
so_path = None
weight_dict = {}
for item in file_paths_and_weights:
if isinstance(item, str) and item.endswith("wrapper.so"):
so_path = item
elif isinstance(item, dict):
weight_dict.update(item)
assert (
so_path is not None
), f"so_path is None, all the strings are: {[x for x in file_paths_and_weights if isinstance(x, str)]}"
assert len(weight_dict) > 0, f"No weight dict found in {file_paths_and_weights}"
return so_path, weight_dict
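
For illustration, a minimal sketch of the input this helper expects. The paths and FQN below are hypothetical, and `None` stands in for the `TensorProperties` metadata:

```python
import torch

# Hypothetical aot_compile output: generated file paths mixed with a weight dict.
weight_param = torch.nn.Parameter(torch.zeros(2, 2))
file_paths_and_weights = [
    "/tmp/aoti/model/wrapper.so",  # hypothetical path
    "/tmp/aoti/model/kernel.cubin",  # hypothetical path
    {"model.linear.weight": (weight_param, None)},  # None stands in for TensorProperties
]

so_path, weight_dict = _extract_so_path_and_weight_dict(file_paths_and_weights)
assert so_path == "/tmp/aoti/model/wrapper.so"
assert list(weight_dict) == ["model.linear.weight"]
```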


def _weight_fqn_list_to_bytes(weight_fqns: List[str]) -> bytes:
processed_bytes = bytearray()
processed_bytes.extend(struct.pack("<I", len(weight_fqns)))
for fqn in weight_fqns:
encoded_fqn = fqn.encode("utf-8")
processed_bytes.extend(struct.pack("<I", len(encoded_fqn)))
processed_bytes.extend(encoded_fqn)
return bytes(processed_bytes)

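A round-trip check of this length-prefixed layout (little-endian u32 count, then one u32 length plus the UTF-8 bytes per name). The decoder below is a hypothetical Python mirror of `parse_weight_fqns_from_processed` in `backends/cuda/runtime/cuda_backend.cpp`, shown only to document the format:

```python
import struct
from typing import List


def _decode_weight_fqns(processed: bytes) -> List[str]:
    # Read the little-endian u32 entry count.
    (count,) = struct.unpack_from("<I", processed, 0)
    offset = 4
    fqns = []
    for _ in range(count):
        # Each entry is a u32 byte length followed by the UTF-8 encoded FQN.
        (length,) = struct.unpack_from("<I", processed, offset)
        offset += 4
        fqns.append(processed[offset : offset + length].decode("utf-8"))
        offset += length
    return fqns


# Round trip against the encoder above.
fqns = ["model.linear.weight", "model.linear.bias"]
assert _decode_weight_fqns(_weight_fqn_list_to_bytes(fqns)) == fqns
```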

# context manager for non-fallback guarantee
# it will raise exception when generating fallback kernels during aoti compile
@contextlib.contextmanager
@@ -136,7 +167,10 @@ def preprocess(
# Do not link against the full PyTorch/libtorch library
"aot_inductor.link_libtorch": False,
# Package model constants and other generated files directly in the shared object (.so) file
"aot_inductor.package_constants_in_so": True,
# Keep constants out of the .so and package them on disk instead, so weights can be registered with the NamedDataStore
"aot_inductor.package": True,
"aot_inductor.package_constants_in_so": False,
"aot_inductor.package_constants_on_disk": True,
# Enable maximum automatic tuning for optimal performance
"max_autotune": True,
# Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
@@ -151,13 +185,17 @@
]
), torch.no_grad():
# torch._logging.set_logs(post_grad_graphs=True)
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
file_paths_and_weights = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
if len(missing_fallback_kernels) > 0:
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
raise RuntimeError(
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
"Please add them to the AOTI backend."
)
assert isinstance(
file_paths_and_weights, list
), f"Expected a list of file paths and weights, got type: {type(file_paths_and_weights)}"
so_path, weight_dict = _extract_so_path_and_weight_dict(file_paths_and_weights)

# pyre-ignore[6]: Incompatible parameter type
with open(so_path, "rb") as f:
@@ -169,12 +207,24 @@
method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
)

# Add weights to named data store
for name, weight_tuple in weight_dict.items():
named_data_store.add_named_data(
name,
weight_tuple[0].cpu().numpy().tobytes(),
1,
None, # Do not store it in .ptd
)

weight_fqns = sorted(weight_dict.keys())
processed_bytes = _weight_fqn_list_to_bytes(weight_fqns)

# Clean up the generated so file; it has been packaged into the NamedDataStore
# pyre-ignore[6]: Incompatible parameter type
os.remove(so_path)

return PreprocessResult(
processed_bytes=b"",
processed_bytes=bytes(processed_bytes),
debug_handle_map={},
data_store_output=named_data_store.get_named_data_store_output(),
)
150 changes: 145 additions & 5 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -12,12 +12,19 @@
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/core/tensor_layout.h>
#include <unistd.h>
#include <cstdio>
#include <memory>

#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
#include <unordered_map>
#include <vector>

// Include our shim layer headers
@@ -54,6 +61,62 @@ using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

namespace {

Error parse_weight_fqns_from_processed(
const FreeableBuffer* processed,
std::vector<std::string>& weight_fqns) {
if (processed == nullptr || processed->data() == nullptr ||
processed->size() == 0) {
return Error::Ok;
}

const auto* cursor = static_cast<const uint8_t*>(processed->data());
size_t remaining = processed->size();

auto read_uint32 = [&](uint32_t& value) -> bool {
if (remaining < sizeof(uint32_t)) {
return false;
}
std::memcpy(&value, cursor, sizeof(uint32_t));
cursor += sizeof(uint32_t);
remaining -= sizeof(uint32_t);
return true;
};

uint32_t num_entries = 0;
ET_CHECK_OR_RETURN_ERROR(
read_uint32(num_entries),
InvalidArgument,
"Failed to read FQN count from processed bytes");

weight_fqns.reserve(num_entries);
for (uint32_t i = 0; i < num_entries; ++i) {
uint32_t length = 0;
ET_CHECK_OR_RETURN_ERROR(
read_uint32(length),
InvalidArgument,
"Failed to read FQN length from processed bytes")

ET_CHECK_OR_RETURN_ERROR(
remaining >= length,
InvalidArgument,
"Processed bytes exhausted while reading FQN %u (remaining=%zu, length=%u)",
i,
remaining,
length);

const char* str_begin = reinterpret_cast<const char*>(cursor);
weight_fqns.emplace_back(str_begin, length);
cursor += length;
remaining -= length;
}

return Error::Ok;
}

} // namespace

class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
@@ -63,6 +126,8 @@ class ET_EXPERIMENTAL CudaBackend final
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
LOAD_SYMBOL(
AOTInductorModelContainerUpdateUserManagedConstantBuffer, so_handle);

return Error::Ok;
}
@@ -88,6 +153,15 @@
}
}

std::vector<std::string> weight_fqns;
Error parse_err = parse_weight_fqns_from_processed(processed, weight_fqns);
if (parse_err != Error::Ok) {
if (processed != nullptr) {
processed->Free();
}
return parse_err;
}

std::string so_blob_key =
method_name.empty() ? "so_blob" : method_name + "_so_blob";

@@ -99,7 +173,6 @@
"Failed to get data for key %s: 0x%x",
so_blob_key.c_str(),
static_cast<uint32_t>(aoti_cuda_buffer.error()));

// Generate dynamic temporary file path
filesystem::path temp_dir = filesystem::temp_directory_path();
filesystem::path so_path =
@@ -149,11 +222,78 @@
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;
handle->weight_fqns = weight_fqns; // Store weight FQNs in the handle

// Create a constant map and populate it with weights from NamedDataMap
// Store the Tensor objects in the handle so they persist for the lifetime
// of the container
std::unordered_map<std::string, Tensor*> constant_map;

// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
handle->cuda_stream = static_cast<void*>(cuda_stream);
for (const auto& fqn : weight_fqns) {
// Get tensor layout (metadata) for this weight
auto tensor_layout_result =
named_data_map->get_tensor_layout(fqn.c_str());
ET_CHECK_OR_RETURN_ERROR(
tensor_layout_result.ok(),
Internal,
"Failed to get tensor layout for key %s: 0x%x",
fqn.c_str(),
static_cast<uint32_t>(tensor_layout_result.error()));

auto weight_result = named_data_map->get_data(fqn.c_str());
ET_CHECK_OR_RETURN_ERROR(
weight_result.ok(),
Internal,
"Failed to get data for key %s: 0x%x",
fqn.c_str(),
static_cast<uint32_t>(weight_result.error()));

// Store the FreeableBuffer to keep the weight data alive
// This is critical: the FreeableBuffer owns or references the actual
// weight data
FreeableBuffer weight_buffer = weight_result.get();
void* weight_data = weight_buffer.data();

// Get tensor layout information
const TensorLayout& layout = tensor_layout_result.get();

// Create a Tensor from the weight data using the layout information
// The Tensor is created as a view over the data owned by the
// FreeableBuffer
auto weight_tensor = std::make_unique<Tensor>(
layout.scalar_type(),
layout.sizes().size(),
const_cast<Tensor::SizesType*>(layout.sizes().data()),
weight_data,
const_cast<Tensor::DimOrderType*>(layout.dim_order().data()),
const_cast<Tensor::StridesType*>(layout.strides().data()));

constant_map[fqn] = weight_tensor.get();
handle->weight_tensors.push_back(std::move(weight_tensor));
handle->weight_buffers.push_back(
std::move(weight_buffer)); // Store buffer to keep data alive
}

// Update the container with user-managed constant buffer
if (!constant_map.empty()) {
AOTIRuntimeError update_err =
AOTInductorModelContainerUpdateUserManagedConstantBuffer(
container_handle,
reinterpret_cast<AOTInductorConstantMapHandle>(&constant_map),
/*use_inactive=*/false,
/*validate_full_update=*/true);

ET_CHECK_OR_RETURN_ERROR(
update_err == Error::Ok,
Internal,
"Failed to update constant buffer with error code %d",
update_err);

ET_LOG(
Info,
"Successfully populated %zu weights into container",
constant_map.size());
}

return (DelegateHandle*)handle; // Return the handle post-processing
}