Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions program-data-separation/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,57 @@ option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
option(EXECUTORCH_BUILD_XNNPACK "" ON)

# Add ExecuTorch subdirectory
# Dependencies required for llm runner in lora demo.
if(EXECUTORCH_BUILD_LORA_DEMO)
option(EXECUTORCH_BUILD_EXTENSION_LLM "" ON)
option(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER "" ON)
option(EXECUTORCH_BUILD_KERNELS_LLM "" ON)
option(EXECUTORCH_BUILD_KERNELS_LLM_AOT "" ON)
endif()

# Add ExecuTorch subdirectory, after setting options.
add_subdirectory("executorch")

set(DEMO_SOURCES linear_example/main.cpp)
set(LINK_LIBS executorch
executorch::extensions
xnnpack_backend
# NOTE: xnnpack_backend has to go before
# kernels otherwise it doesn't get registered.
executorch::kernels
gflags
)

# Add sources and dependencies.
set(DEMO_SOURCES "")
if(EXECUTORCH_BUILD_LINEAR_DEMO)
list(APPEND DEMO_SOURCES "linear_example/main.cpp")
endif()
if(EXECUTORCH_BUILD_LORA_DEMO)
list(APPEND DEMO_SOURCES "lora_example/main.cpp")
endif()

# Create executable
add_executable(executorch_program_data_separation ${DEMO_SOURCES})

# Include directories
target_include_directories(executorch_program_data_separation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

# Link libraries
target_link_libraries(
executorch_program_data_separation
PRIVATE executorch
extension_module_static
extension_flat_tensor
extension_tensor
xnnpack_backend
portable_ops_lib
portable_kernels
gflags
PRIVATE ${LINK_LIBS}
)

# Include directories for lora demo.
if(EXECUTORCH_BUILD_LORA_DEMO)
# Include directories
target_include_directories(executorch_program_data_separation PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/executorch/extension/llm/tokenizers/include
)
target_link_libraries(
executorch_program_data_separation
PUBLIC tokenizers::tokenizers
)
endif()

# Set output directory
set_target_properties(executorch_program_data_separation
PROPERTIES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mkdir -p build
cd build

# Configure CMake
cmake -DCMAKE_BUILD_TYPE=Release ../..
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LINEAR_DEMO=True ../..

# Build the project
cmake --build . -j$(nproc)
Expand Down
88 changes: 88 additions & 0 deletions program-data-separation/cpp/lora_example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# ExecuTorch Program Data Separation Demo C++.

This directory contains the C++ code to run the examples generated in [program-data-separation](../program-data-separation/README.md).


## Virtual environment setup.
Create and activate a Python virtual environment:
```bash
python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip
```
Or alternatively, [install conda on your machine](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
```bash
conda create -yn executorch-ptd python=3.10.0 && conda activate executorch-ptd
```

Install dependencies:
LoRA isn't available in the 0.7.0 release of ExecuTorch. Instead, please install from source until ExecuTorch 1.0 is released.

[Install ExecuTorch pip package from source](https://docs.pytorch.org/executorch/stable/using-executorch-building-from-source.html#install-executorch-pip-package-from-source).

Currently, the LoRA changes aren't in nightlies. Once they are in, you can also install from the nightly build.
```
pip install executorch==0.8.0.devYYYYMMDD --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

## Export the model/s.
Change into the program-data-separation directory and create a directory to hold exported artifacts.
```bash
cd ~/executorch-examples/program-data-separation
mkdir models
```

Export models into the `models` directory. The first command will generated undelegated model/data files, and the second will generate XNNPACK-delegated model/data files.
```bash
sh export_lora.sh
```
Expect the files:
- llama_3_2_1B.pte
- llama_3_2_1B.ptd
- llama_3_2_1B_lora.pte
- foundation_weights.ptd
- tokenizer.model

llama_3_2_1B.ptd and foundation_weights.ptd contain the same contents, and you can remove llama_3_2_1B.ptd.
tokenizer.model is copied from the temp directory where we downloaded the HF artifacts. It will be used at runtime.

Note:
- PTE: contains the program execution logic.
- PTD: contains the constant tensors used by the PTE.

## Install runtime dependencies.
The ExecuTorch repository is configured as a git submodule at `~/executorch-examples/program-data-separation/cpp/executorch`. To initialize it:
```bash
cd ~/executorch-examples/
git submodule sync
git submodule update --init --recursive
```
Install dev requirements for ExecuTorch

```bash
cd ~/executorch-examples/program-data-separation/cpp/executorch
pip install -r requirements-dev.txt
```

## Build the runtime.
Install some dependencies:
```bash
cd ~/executorch-examples/program-data-separation/cpp/executorch
sh examples/models/llama/install_requirements.sh
```

Build the executable:
```bash
cd ~/executorch-examples/program-data-separation/cpp/lora_example
sh build_example.sh
```

## Run the executable.
```bash
cd ~/executorch-examples/program-data-separation/cpp/lora_example

./build/bin/executorch_program_data_separation --lora_model_path=../../llama_3_2_1B_lora.pte --llama_model_path=../../llama_3_2_1B.pte --tokenizer_path=../../tokenizer.model --data_path=../../foundation.ptd
```

## Clean up.
rm -rf build
cd ~/executorch-examples/program-data-separation
rm -rf *.pte *.ptd tokenizer.model
15 changes: 15 additions & 0 deletions program-data-separation/cpp/lora_example/build_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash
set -e

# Clean and create build directory if it doesn't exist
rm -rf build
mkdir -p build
cd build

# Configure CMake
cmake -DCMAKE_BUILD_TYPE=Release -DEXECUTORCH_BUILD_LORA_DEMO=True -DEXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE=True ../..

# Build the project
cmake --build . -j$(nproc)

echo "Build complete! Executable located at: ./build/bin/executorch_program_data_separation"
128 changes: 128 additions & 0 deletions program-data-separation/cpp/lora_example/main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
*/

#include <memory>
#include <string>
#include <vector>

#include <gflags/gflags.h>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_llm_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

DEFINE_string(lora_model_path, "llama_3_2_1B_lora.pte",
"LoRA model serialized in flatbuffer format.");
DEFINE_string(llama_model_path, "llama_3_2_1B.pte",
"Model serialized in flatbuffer format.");
DEFINE_string(data_path, "foundation.ptd",
"Data serialized in flatbuffer format.");

DEFINE_string(tokenizer_path, "tokenizer.model", "Tokenizer stuff.");

DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");

DEFINE_double(temperature, 0,
"Temperature; Default is 0. 0 = greedy argmax sampling "
"(deterministic). Lower temperature = more deterministic");

DEFINE_int32(
seq_len, 128,
"Total number of tokens to generate (prompt + output). Defaults to "
"max_seq_len. If the number of input tokens + seq_len > max_seq_len, the "
"output will be truncated to max_seq_len tokens.");

using executorch::extension::Module;
using executorch::runtime::Error;
namespace llm = executorch::extension::llm;

namespace {
static constexpr int32_t kSpecialTokensSize = 256;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to add these to tokenizer or llm API.

static inline std::unique_ptr<std::vector<std::string>>
_get_default_special_tokens() {
auto special_tokens =
std::make_unique<std::vector<std::string>>(std::vector<std::string>{
"<|begin_of_text|>", "<|end_of_text|>",
"<|reserved_special_token_0|>", "<|reserved_special_token_1|>",
"<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>",
"<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"});
// pad the rest of the special tokens with reserved tokens
ssize_t reserved_special_token_num = 2;
while (special_tokens->size() < kSpecialTokensSize) {
special_tokens->emplace_back("<|reserved_special_token_" +
std::to_string(reserved_special_token_num++) +
"|>");
}
return special_tokens;
}
} // namespace

int main(int argc, char *argv[]) {
ET_LOG(Info, "Running program-data separation lora example...");

gflags::ParseCommandLineFlags(&argc, &argv, true);

const char *lora_model_path = FLAGS_lora_model_path.c_str();
const char *llama_model_path = FLAGS_llama_model_path.c_str();
const char *data_path = FLAGS_data_path.c_str();

const char *tokenizer_path = FLAGS_tokenizer_path.c_str();
const char *prompt = FLAGS_prompt.c_str();
float temperature = FLAGS_temperature;
int32_t seq_len = 128;
int32_t cpu_threads = -1;

// Create tokenizers.
std::unique_ptr<tokenizers::Tokenizer> tokenizer1 =
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());
std::unique_ptr<tokenizers::Tokenizer> tokenizer2 =
llm::load_tokenizer(tokenizer_path, _get_default_special_tokens());

if (tokenizer1 == nullptr || tokenizer2 == nullptr) {
ET_LOG(Info,
"Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c "
"tokenizer, make sure the artifact is one of these types",
tokenizer_path);
return 1;
}

// Create runners.
std::unique_ptr<llm::TextLLMRunner> llama_runner =
llm::create_text_llm_runner(llama_model_path, std::move(tokenizer1),
data_path, temperature);
std::unique_ptr<llm::TextLLMRunner> lora_runner = llm::create_text_llm_runner(
lora_model_path, std::move(tokenizer2), data_path, temperature);

// Generate.
llm::GenerationConfig config{.seq_len = seq_len, .temperature = temperature};

ET_LOG(Info, "Generating with llama...");
auto error = llama_runner->generate(prompt, config);
if (error != Error::Ok) {
ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.",
error);
return 1;
}

error = lora_runner->generate(prompt, config);
if (error != Error::Ok) {
ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.",
error);
return 1;
}

return 0;
}
53 changes: 53 additions & 0 deletions program-data-separation/export_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -exu

python -m pip install torchtune==0.7.0.dev20250730 --extra-index-url https://download.pytorch.org/whl/nightly/cpu

# Download model artifacts from HF.
DOWNLOADED_PATH=$(python -c "
from huggingface_hub import snapshot_download
path=snapshot_download(
repo_id=\"lucylq/llama3_1B_lora\",
)
import os
print(path)
")

# Copy over tokenizer, for use at runtime.
cp "${DOWNLOADED_PATH}/tokenizer.model" .

# Export a non-LoRA model with program-data separated.
MODEL="llama_3_2_1B"
python -m executorch.extension.llm.export.export_llm \
base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
base.params="${DOWNLOADED_PATH}/params.json" \
base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
model.use_kv_cache=true \
model.use_sdpa_with_kv_cache=true \
model.dtype_override="fp32" \
backend.xnnpack.enabled=true \
backend.xnnpack.extended_ops=true \
export.output_name="${MODEL}.pte" \
export.foundation_weights_file="${MODEL}.ptd"

# Export a LoRA model, with program and data separated.
LORA_MODEL="llama_3_2_1B_lora"
python -m executorch.extension.llm.export.export_llm \
base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
base.params="${DOWNLOADED_PATH}/params.json" \
base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
model.use_kv_cache=true \
model.use_sdpa_with_kv_cache=true \
model.dtype_override="fp32" \
backend.xnnpack.enabled=true \
backend.xnnpack.extended_ops=true \
export.output_name="${LORA_MODEL}.pte" \
export.foundation_weights_file="foundation.ptd"