2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
53a2908a10f414a2f85caa06703a26a40e873869
e6f766c7d750d40603eee3f66c5915bac606b3ea
39 changes: 39 additions & 0 deletions .ci/scripts/utils.sh
@@ -44,6 +44,44 @@ install_pip_dependencies() {
popd || return
}

dedupe_macos_loader_path_rpaths() {
if [[ "$(uname)" != "Darwin" ]]; then
return
fi

local torch_lib_dir
pushd ..
torch_lib_dir=$(python -c "import importlib.util; print(importlib.util.find_spec('torch').submodule_search_locations[0])")/lib
popd

if [[ -z "${torch_lib_dir}" || ! -d "${torch_lib_dir}" ]]; then
return
fi

local torch_libs=(
"libtorch_cpu.dylib"
"libtorch.dylib"
"libc10.dylib"
)

for lib_name in "${torch_libs[@]}"; do
local lib_path="${torch_lib_dir}/${lib_name}"
if [[ ! -f "${lib_path}" ]]; then
continue
fi

local removed=0
# Repeatedly remove the @loader_path rpath entries until none remain.
while install_name_tool -delete_rpath @loader_path "${lib_path}" 2>/dev/null; do
removed=1
done

if [[ "${removed}" == "1" ]]; then
install_name_tool -add_rpath @loader_path "${lib_path}" || true
fi
done
}

install_domains() {
echo "Install torchvision and torchaudio"
pip install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${TORCHAUDIO_VERSION}"
@@ -101,6 +139,7 @@ install_pytorch_and_domains() {
echo "Use cached wheel at ${cached_torch_wheel}"
fi

dedupe_macos_loader_path_rpaths
# Grab the pinned audio and vision commits from PyTorch
TORCHAUDIO_VERSION=$(cat .github/ci_commit_pins/audio.txt)
export TORCHAUDIO_VERSION
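A quick way to verify the effect of dedupe_macos_loader_path_rpaths is to list the LC_RPATH load commands on the patched dylibs; each library should end up with at most one @loader_path entry. A minimal sketch, assuming a macOS host with torch installed and otool available from the Xcode command-line tools:

torch_lib_dir=$(python -c "import importlib.util; print(importlib.util.find_spec('torch').submodule_search_locations[0])")/lib
otool -l "${torch_lib_dir}/libtorch_cpu.dylib" | grep -A2 LC_RPATH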
1 change: 1 addition & 0 deletions .github/workflows/pull.yml
@@ -351,6 +351,7 @@ jobs:
# reinstall executorch
bash ./install_executorch.sh --minimal
pip list
# run python unittest
python -m unittest examples.models.moshi.mimi.test_mimi
33 changes: 6 additions & 27 deletions CMakeLists.txt
@@ -99,28 +99,6 @@ announce_configured_options(CCACHE_PROGRAM)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Setup RPATH. See
# https://gitlab.kitware.com/cmake/community/-/wikis/doc/cmake/RPATH-handling
# Use separate rpaths during build and install phases
set(CMAKE_SKIP_BUILD_RPATH OFF)
# Don't use the install-rpath during the build phase
set(CMAKE_BUILD_WITH_INSTALL_RPATH ON)
# Automatically add all linked folders that are NOT in the build directory to
# the rpath (per library?)
#
# TODO: Doesn't work for us right now because we are not installing .so's into
# the correct locations. For example we have libcustom_ops_aot_lib.so depending
# on _portable_lib.so, which was eventually put under
# <site-packages>/executorch/extension/pybindings/ but this rpath is not
# automatically added because at build time it seems `portable_lib` is being
# built under the same directory, so no extra rpath is being added. To properly
# fix this we need to install `portable_lib` into the correct path.
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH ON)
# ------------------------------ OPTIONS -------------------------------------
# WARNING: Please don't add example specific options in this CMakeLists.txt.
# Instead please use `find_package(executorch REQUIRED)` in the example
# directory and add a new executable in the example `CMakeLists.txt`.

if(NOT EXECUTORCH_ENABLE_LOGGING)
# Avoid pulling in the logging strings, which can be large. Note that this
# will set the compiler flag for all targets in this directory, and for all
@@ -909,12 +887,13 @@ if(EXECUTORCH_BUILD_PYBIND)

# Set RPATH to find PyTorch libraries relative to the installation location
# This goes from executorch/extension/pybindings up to site-packages, then to
# torch/lib
# torch/lib. Don't do this on APPLE, as it fails with the following error:
#
if(APPLE)
set_target_properties(
portable_lib PROPERTIES BUILD_RPATH "@loader_path/../../../torch/lib"
INSTALL_RPATH "@loader_path/../../../torch/lib"
)
# Skip setting @loader_path for APPLE, since it causes error like ld:
# duplicate LC_RPATH '@loader_path' in '<site-packages>/torch/lib/
# libtorch_cpu.dylib'
else()
set_target_properties(
portable_lib PROPERTIES BUILD_RPATH "$ORIGIN/../../../torch/lib"
6 changes: 6 additions & 0 deletions backends/aoti/aoti_delegate_handle.h
@@ -71,6 +71,11 @@ using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants);

// Update the model container with the constant tensors
using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
const uint8_t* weight_blob_ptr);

} // extern "C"

// AOTI Delegate Handle structure
@@ -87,6 +92,7 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
AOTInductorModelContainerRunFunc run;
AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;
};

} // namespace aoti
37 changes: 32 additions & 5 deletions backends/cuda/cuda_backend.py
@@ -146,8 +146,11 @@ def preprocess(
"aot_inductor.embed_kernel_binary": True,
# Do not link against the full PyTorch/libtorch library
"aot_inductor.link_libtorch": False,
# Package model constants and other generated files directly in the shared object (.so) file
"aot_inductor.package_constants_in_so": True,
# Separate weight constants from the .so file
"aot_inductor.package": True,
"aot_inductor.package_constants_in_so": False,
# Store weight constants on disk in a binary blob
"aot_inductor.package_constants_on_disk_format": "binary_blob",
# Enable maximum automatic tuning for optimal performance
"max_autotune": True,
# Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
@@ -162,25 +165,49 @@
]
), torch.no_grad():
# torch._logging.set_logs(post_grad_graphs=True)
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
# Here we should expect 1 so file and 1 weight blob in the same directory.
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
if len(missing_fallback_kernels) > 0:
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
raise RuntimeError(
f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
"Please add them to the AOTI backend."
)

# Extract the .so and .blob paths from the returned list
so_path = None
blob_path = None
for path in paths:
if path.endswith(".wrapper.so"):
so_path = path
elif path.endswith(".wrapper_weights.blob"):
blob_path = path

if so_path is None or blob_path is None:
raise RuntimeError(
f"Could not find required files in compiled paths, got {paths}"
)

Contributor: check if existence/non-existence of blob_path matches the options

Contributor: Only check for blob_path is None if package_constant_on_disk_format option is set to "binary_blob"

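A minimal sketch of the check the reviewers suggest, tying the presence of the weights blob to the packaging options. Variable names follow the surrounding diff (paths, options), but this is illustrative rather than the PR's final code:

# Expect a separate weights blob only when constants are packaged on disk as a binary blob.
expects_weight_blob = (
    options.get("aot_inductor.package_constants_on_disk_format") == "binary_blob"
    and not options.get("aot_inductor.package_constants_in_so", False)
)
so_path = next((p for p in paths if p.endswith(".wrapper.so")), None)
blob_path = next((p for p in paths if p.endswith(".wrapper_weights.blob")), None)
if so_path is None:
    raise RuntimeError(f"Could not find the compiled .so in {paths}")
if expects_weight_blob and blob_path is None:
    raise RuntimeError(f"Expected a weights blob in {paths} but found none")
if not expects_weight_blob and blob_path is not None:
    raise RuntimeError(f"Unexpected weights blob produced: {blob_path}")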
# pyre-ignorep[6]: Incompatible parameter type
with open(so_path, "rb") as f:
so_data = f.read()

named_data_store = NamedDataStore()
method_name = CudaBackend.method_name_from_compile_specs(compile_specs)

# Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file.
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)

# Add weights blob to named data store
with open(blob_path, "rb") as f:
blob_data = f.read()
named_data_store.add_named_data(
method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob"
)
# Clean up the weights blob file
os.remove(blob_path)

# Clean up the generated so file; it has been packaged into the NamdeDataStore
# Clean up the generated so file; it has been packaged into the NamedDataStore
# pyre-ignorep[6]: Incompatible parameter type
os.remove(so_path)

80 changes: 47 additions & 33 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -27,15 +27,6 @@

namespace executorch::backends::cuda {

#define LOAD_SYMBOL(handle, member, name, so_handle) \
do { \
auto symbol_res = get_function(so_handle, #name); \
if (!symbol_res.ok()) { \
return symbol_res.error(); \
} \
handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \
} while (0)

using namespace std;
using namespace aoti;

@@ -61,29 +52,37 @@ class ET_EXPERIMENTAL CudaBackend final
Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
LOAD_SYMBOL(
handle,
create_with_device,
AOTInductorModelContainerCreateWithDevice,
so_handle);
#define LOAD_SYMBOL(member, name) \
do { \
auto symbol_res = get_function(so_handle, #name); \
if (!symbol_res.ok()) { \
return symbol_res.error(); \
} \
handle->member = reinterpret_cast<name##Func>(symbol_res.get()); \
} while (0)

LOAD_SYMBOL(create_with_device, AOTInductorModelContainerCreateWithDevice);

LOAD_SYMBOL(
handle, delete_container, AOTInductorModelContainerDelete, so_handle);
LOAD_SYMBOL(delete_container, AOTInductorModelContainerDelete);

LOAD_SYMBOL(
handle,
get_num_inputs,
AOTInductorModelContainerGetNumInputs,
so_handle);
LOAD_SYMBOL(get_num_inputs, AOTInductorModelContainerGetNumInputs);

LOAD_SYMBOL(
handle,
get_num_outputs,
AOTInductorModelContainerGetNumOutputs,
so_handle);
LOAD_SYMBOL(get_num_outputs, AOTInductorModelContainerGetNumOutputs);

LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);
LOAD_SYMBOL(run, AOTInductorModelContainerRun);
#undef LOAD_SYMBOL

auto symbol_res =
get_function(so_handle, "AOTInductorModelUpdateConstantsFromBlob");
if (symbol_res.ok()) {
handle->update_constants_from_blob =
reinterpret_cast<AOTInductorModelUpdateConstantsFromBlobFunc>(
symbol_res.get());
} else {
ET_LOG(
Info,
"Failed to load AOTInductorModelUpdateConstantsFromBlob. This .so is probably compiled on an old version of torch (<2.9.0)");
}
return Error::Ok;
}

@@ -112,13 +111,13 @@ class ET_EXPERIMENTAL CudaBackend final
method_name.empty() ? "so_blob" : method_name + "_so_blob";

const NamedDataMap* named_data_map = context.get_named_data_map();
auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
auto aoti_dso_buffer = named_data_map->get_data(so_blob_key.c_str());
ET_CHECK_OR_RETURN_ERROR(
aoti_cuda_buffer.ok(),
aoti_dso_buffer.ok(),
Internal,
"Failed to get data for key %s: 0x%x",
so_blob_key.c_str(),
static_cast<uint32_t>(aoti_cuda_buffer.error()));
static_cast<uint32_t>(aoti_dso_buffer.error()));

// Generate dynamic temporary file path
filesystem::path temp_dir = filesystem::temp_directory_path();
@@ -132,19 +131,21 @@ class ET_EXPERIMENTAL CudaBackend final
ET_LOG(
Info,
"Writing %zu bytes to %s",
aoti_cuda_buffer->size(),
aoti_dso_buffer->size(),
so_path.c_str());

outfile.write(
static_cast<const char*>(aoti_cuda_buffer->data()),
aoti_cuda_buffer->size());
static_cast<const char*>(aoti_dso_buffer->data()),
aoti_dso_buffer->size());

ET_CHECK_OR_RETURN_ERROR(
outfile, AccessFailed, "Failed to write to file %s", so_path.c_str());

// Finish writing the file to disk
outfile.close();

// Free the buffer immediately after writing to disk
aoti_dso_buffer->Free();
// Load the lib
Result<void*> lib_handle_res = load_library(so_path);
if (!lib_handle_res.ok()) {
@@ -172,6 +173,19 @@

handle->container_handle = container_handle;

// Look into named data map for constant data
std::string weights_blob_key =
method_name.empty() ? "weights_blob" : method_name + "_weights_blob";
auto buffer_res = named_data_map->get_data(weights_blob_key.c_str());
if (buffer_res.ok() && handle->update_constants_from_blob != nullptr) {
ET_LOG(Info, "Found %s in named data map", weights_blob_key.c_str());
const void* weights_blob = buffer_res->data();
// Feed the weights blob into the container. Under the hood it's copying
// weights, so we should free the buffer immediately.
Comment on lines +183 to +184

Contributor: the weights are mmapped, so this isn't halving the maximum amount of weights we can handle, right? even so, seems unfortunate that we have to copy and therefore can't keep them simply mmapped though; peak CPU memory now needs to hold them, right?

Contributor Author:
> this isn't halving the maximum amount of weights we can handle, right?
Good point, let me test this on my RTX 5080.
> seems unfortunate that we have to copy and therefore can't keep them simply mmapped though
Yeah would be good if aoti can just take it without copying.

Contributor: I might be missing something here, but I assume you're mmaping into the CPU memory right? AOTI copies it into the CUDA memory, and since we're running this on CUDA, we have to copy it to CUDA some time.

Contributor Author: Is there a mmap equivalent on CUDA? If so, on ET side we can create a dataloader to directly load into CUDA memory.

Contributor: I guess you can use GPUDirect Storage.

ET_CHECK_OK_OR_RETURN_ERROR(handle->update_constants_from_blob(
handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
buffer_res->Free();
}
// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
6 changes: 3 additions & 3 deletions examples/models/moshi/mimi/install_requirements.sh
@@ -8,9 +8,9 @@
set -x

conda install -c conda-forge "ffmpeg<8" -y
pip install torchcodec==0.7.0.dev20250929 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install moshi==0.2.4
pip install bitsandbytes soundfile
pip install torchcodec==0.7.0.dev20251012 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install moshi==0.2.11
pip install bitsandbytes soundfile einops
# Run llama2/install requirements for torchao deps
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
bash "$SCRIPT_DIR"/../../llama/install_requirements.sh
3 changes: 1 addition & 2 deletions examples/models/moshi/mimi/test_mimi.py
@@ -189,8 +189,7 @@ def forward(self, x):
x = self.mimi_model.upsample(x)
(emb,) = self.mimi_model.decoder_transformer(x)
emb.transpose(1, 2)
with self.mimi_model._context_for_encoder_decoder:
out = self.mimi_model.decoder(emb)
out = self.mimi_model.decoder(emb)
return out

emb_input = torch.rand(1, 1, 512, device="cpu")
2 changes: 1 addition & 1 deletion examples/models/voxtral/multimodal.cpp
@@ -319,7 +319,7 @@ int32_t main(int32_t argc, char** argv) {
// Create multimodal runner
std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner =
::executorch::extension::llm::create_multimodal_runner(
model_path, std::move(tokenizer), data_path);
model_path, std::move(tokenizer), data_path, Module::LoadMode::Mmap);
if (runner == nullptr) {
ET_LOG(Error, "Failed to create multimodal runner");
return 1;
8 changes: 4 additions & 4 deletions extension/llm/runner/llm_runner_helper.cpp
@@ -268,7 +268,8 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
std::unique_ptr<MultimodalRunner> create_multimodal_runner(
const std::string& model_path,
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
std::optional<const std::string> data_path) {
std::optional<const std::string> data_path,
Module::LoadMode load_mode) {
// Sanity check tokenizer
if (!tokenizer || !tokenizer->is_loaded()) {
ET_LOG(Error, "Tokenizer is null or not loaded");
@@ -278,10 +279,9 @@ std::unique_ptr<MultimodalRunner> create_multimodal_runner(
// Create the Module
std::unique_ptr<Module> module;
if (data_path.has_value()) {
module = std::make_unique<Module>(
model_path, data_path.value(), Module::LoadMode::File);
module = std::make_unique<Module>(model_path, data_path.value(), load_mode);
} else {
module = std::make_unique<Module>(model_path, Module::LoadMode::File);
module = std::make_unique<Module>(model_path, load_mode);
}

// Get metadata from Module
3 changes: 2 additions & 1 deletion extension/llm/runner/llm_runner_helper.h
@@ -140,6 +140,7 @@ ET_EXPERIMENTAL std::unique_ptr<TextLLMRunner> create_text_llm_runner(
ET_EXPERIMENTAL std::unique_ptr<MultimodalRunner> create_multimodal_runner(
const std::string& model_path,
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
std::optional<const std::string> data_path = std::nullopt);
std::optional<const std::string> data_path = std::nullopt,
Module::LoadMode load_mode = Module::LoadMode::File);

} // namespace executorch::extension::llm