Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions examples/models/voxtral/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Simple CMake build system for voxtral runner.
#
cmake_minimum_required(VERSION 3.24)
project(voxtral)

set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

if(CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")
set(CMAKE_TOOLCHAIN_IOS ON)
else()
set(CMAKE_TOOLCHAIN_IOS OFF)
endif()

# Let files say "include <executorch/path/to/header.h>"
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# Need this for gflags for some reason
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
find_package(gflags REQUIRED)

# Find `executorch` libraries, same as for gflags
list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
executorch_target_link_options_shared_lib(executorch)

set(LINK_LIBS executorch gflags)
set(link_libraries ${LINK_LIBS})
set(_srcs multimodal.cpp)

list(
APPEND
link_libraries
optimized_native_cpu_ops_lib
quantized_ops_lib
custom_ops
cpublas
eigen_blas
)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
executorch_target_link_options_shared_lib(quantized_ops_lib)
executorch_target_link_options_shared_lib(custom_ops)

# XNNPACK
if(TARGET xnnpack_backend)
set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
if(TARGET kleidiai)
list(APPEND xnnpack_backend_libs kleidiai)
endif()
list(APPEND link_libraries ${xnnpack_backend_libs})
executorch_target_link_options_shared_lib(xnnpack_backend)
endif()

# Add LLM runner and extension module
if(NOT TARGET extension_llm_runner)
message(
FATAL_ERROR
"ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER enabled."
)
endif()

# Needed for cpuinfo where it uses android specific log lib
if(ANDROID)
list(APPEND link_libraries log)
endif()

# Add the required ExecutorTorch extensions for multimodal LLM runner
list(
APPEND
link_libraries
extension_llm_runner
extension_module
extension_data_loader
extension_tensor
extension_flat_tensor
)

# Add tokenizers
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(voxtral_runner ${_srcs})
if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
target_link_options_gc_sections(voxtral_runner)
if(NOT APPLE)
target_link_options(voxtral_runner PRIVATE "LINKER:-s")
endif()
endif()

target_include_directories(voxtral_runner PUBLIC ${_common_include_directories})
target_link_libraries(voxtral_runner PUBLIC ${link_libraries})
target_compile_options(voxtral_runner PUBLIC ${_common_compile_options})
172 changes: 172 additions & 0 deletions examples/models/voxtral/multimodal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cmath>
#include <cstring>
#include <fstream>

#include <gflags/gflags.h>

#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/log.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

DEFINE_string(
model_path,
"multimodal.pte",
"Model serialized in flatbuffer format.");

DEFINE_string(tokenizer_path, "tekken.json", "Tokenizer stuff.");

DEFINE_string(prompt, "What is happening in this audio?", "Text prompt.");

DEFINE_string(audio_path, "", "Path to input audio file.");

DEFINE_double(
temperature,
0.8f,
"Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");

DEFINE_int32(
cpu_threads,
-1,
"Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");

DEFINE_bool(warmup, false, "Whether to run a warmup run.");

namespace {

using ::executorch::extension::llm::Image;
using ::executorch::extension::llm::make_image_input;
using ::executorch::extension::llm::make_text_input;
using ::executorch::extension::llm::MultimodalInput;

} // namespace

int32_t main(int32_t argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

const char* model_path = FLAGS_model_path.c_str();

const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
const char* prompt = FLAGS_prompt.c_str();
const char* audio_path = FLAGS_audio_path.c_str();
float temperature = FLAGS_temperature;
int32_t cpu_threads = FLAGS_cpu_threads;
bool warmup = FLAGS_warmup;

#if defined(ET_USE_THREADPOOL)
uint32_t num_performant_cores = cpu_threads == -1
? ::executorch::extension::cpuinfo::get_num_performant_cores()
: static_cast<uint32_t>(cpu_threads);
ET_LOG(
Info, "Resetting threadpool with num threads = %d", num_performant_cores);
if (num_performant_cores > 0) {
::executorch::extension::threadpool::get_threadpool()
->_unsafe_reset_threadpool(num_performant_cores);
}
#endif

// Load tokenizer
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
::executorch::extension::llm::load_tokenizer(tokenizer_path);
if (tokenizer == nullptr) {
ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path);
return 1;
}

// Create multimodal runner
std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner =
::executorch::extension::llm::create_multimodal_runner(
model_path, std::move(tokenizer));
if (runner == nullptr) {
ET_LOG(Error, "Failed to create multimodal runner");
return 1;
}

// Load runner
auto load_error = runner->load();
if (load_error != ::executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to load multimodal runner");
return 1;
}

// Prepare inputs
std::vector<MultimodalInput> inputs;

// 1. Add start bos-related text inputs and modality start token.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is show we run audio inputs with multimodal runner?

inputs.emplace_back(make_text_input("<s>[INST][BEGIN_AUDIO]"));

// 2. Add audio input
// Using a preprocessed audio, saved using:
// with open("tensor.bin", "wb") as f:
// f.write(t.numpy().tobytes())
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
int32_t n_bins = 128;
int32_t n_frames = 3000;
std::size_t n_floats =
f.tellg() / sizeof(float); // Number of floats in the audio file.
f.seekg(0, std::ios::beg);
int32_t batch_size = ceil(
n_floats /
(n_bins * n_frames)); // Batch in increments of n_frames, rounding up.
std::vector<float> audio_data(batch_size * n_bins * n_frames);
f.read(
reinterpret_cast<char*>(audio_data.data()),
audio_data.size() * sizeof(float));

ET_LOG(Info, "audio_data len = %d", audio_data.size());

auto audio = std::make_unique<::executorch::extension::llm::Audio>();
audio->batch_size = batch_size;
audio->n_bins = n_bins;
audio->n_frames = n_frames;
audio->data.resize(audio_data.size() * sizeof(float));
std::memcpy(
audio->data.data(), audio_data.data(), audio_data.size() * sizeof(float));
inputs.emplace_back(
::executorch::extension::llm::make_audio_input(std::move(*audio)));

// 3. Add text input
inputs.emplace_back(make_text_input(std::string(prompt) + "[/INST]"));

::executorch::extension::llm::GenerationConfig config;
config.max_new_tokens = 100;
config.temperature = temperature;

// Run warmup if requested
if (warmup) {
ET_LOG(Info, "Running warmup...");
auto warmup_error = runner->generate(inputs, config);
if (warmup_error != ::executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to run warmup");
return 1;
}
runner->reset();
}

// Generate
ET_LOG(Info, "Starting generation...");
auto error = runner->generate(inputs, config);
if (error != ::executorch::runtime::Error::Ok) {
ET_LOG(Error, "Failed to generate with multimodal runner");
return 1;
}

printf("\n");
return 0;
}
57 changes: 51 additions & 6 deletions extension/llm/runner/llm_runner_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
#include <executorch/extension/llm/runner/text_llm_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/runtime.h>
#include <pytorch/tokenizers/hf_tokenizer.h>
#include <pytorch/tokenizers/llama2c_tokenizer.h>
#include <pytorch/tokenizers/sentencepiece.h>
#include <pytorch/tokenizers/tekken.h>
#include <pytorch/tokenizers/tiktoken.h>

namespace executorch::extension::llm {
Expand All @@ -35,6 +37,18 @@ std::unique_ptr<tokenizers::Tokenizer> load_tokenizer(
size_t bos_token_index,
size_t eos_token_index) {
runtime::runtime_init();
auto tekken_tokenizer = std::make_unique<tokenizers::Tekken>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I follow what this is doing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf tokenizer can "load" the tekken tokenizer since it's also a json, which we don't want since the pattern I see here is that we keep loading different tokenizers until one works

// Prevent the case where tekken tokenizer accidentally successfully loads a
// HuggingFace tokenizer, which is also .json.
const std::string tekken_name = "tekken.json";
if (tokenizer_path.size() >= tekken_name.size() &&
tokenizer_path.rfind(tekken_name) ==
tokenizer_path.size() - tekken_name.size()) {
if (tekken_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
ET_LOG(Info, "Loaded tekken tokenizer");
return tekken_tokenizer;
}
}
auto json_tokenizer = std::make_unique<tokenizers::HFTokenizer>();
if (json_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) {
ET_LOG(Info, "Loaded json tokenizer");
Expand Down Expand Up @@ -73,9 +87,8 @@ std::unique_ptr<tokenizers::Tokenizer> load_tokenizer(
return nullptr;
}

std::unordered_map<std::string, int64_t> get_llm_metadata(
tokenizers::Tokenizer* tokenizer,
Module* module) {
::executorch::runtime::Result<std::unordered_map<std::string, int64_t>>
get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module) {
// Initialize metadata with default values
std::unordered_map<std::string, int64_t> metadata({
{llm::kEnableDynamicShape, false},
Expand All @@ -89,10 +102,20 @@ std::unordered_map<std::string, int64_t> get_llm_metadata(
auto method_names_result = module->method_names();
if (method_names_result.error() != Error::Ok) {
ET_LOG(Error, "Failed reading method names");
return metadata;
return ::executorch::runtime::Error::InvalidArgument;
}
const auto& method_names = method_names_result.get();

// Error out if the max seq len metadata method is not present, since
// it is hard to figure out from just the .pte itself.
if (!method_names.count(llm::kMaxSeqLen)) {
ET_LOG(
Error,
"Required metadata method %s not found in model",
llm::kMaxSeqLen);
return ::executorch::runtime::Error::InvalidArgument;
}

for (auto& pair : metadata) {
const auto& method_name = pair.first;
auto& value = pair.second;
Expand All @@ -109,6 +132,18 @@ std::unordered_map<std::string, int64_t> get_llm_metadata(
}
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
}

// If kMaxContextLen method not found but kMaxSeqLen is
// available, set kMaxContextLen to the value of kMaxSeqLen.
if (!method_names.count(llm::kMaxContextLen) &&
method_names.count(llm::kMaxSeqLen)) {
metadata[llm::kMaxContextLen] = metadata[llm::kMaxSeqLen];
ET_LOG(
Info,
"Setting kMaxContextLen to kMaxSeqLen value: %" PRId64,
metadata[llm::kMaxContextLen]);
}

// Set tokenizer-related metadata
metadata[llm::kBosId] = tokenizer->bos_tok();
metadata[llm::kVocabSize] = tokenizer->vocab_size();
Expand Down Expand Up @@ -165,7 +200,12 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(

// Get metadata from Module
ET_LOG(Info, "Reading metadata from model");
auto metadata = llm::get_llm_metadata(tokenizer.get(), module.get());
auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());
if (metadata_result.error() != Error::Ok) {
ET_LOG(Error, "Failed to get metadata from model");
return nullptr;
}
auto metadata = metadata_result.get();

auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
llm::get_eos_ids(tokenizer.get(), module.get()));
Expand Down Expand Up @@ -228,7 +268,12 @@ std::unique_ptr<MultimodalRunner> create_multimodal_runner(

// Get metadata from Module
ET_LOG(Info, "Reading metadata from model");
auto metadata = get_llm_metadata(tokenizer.get(), module.get());
auto metadata_result = get_llm_metadata(tokenizer.get(), module.get());
if (metadata_result.error() != Error::Ok) {
ET_LOG(Error, "Failed to get metadata from model");
return nullptr;
}
auto metadata = metadata_result.get();

auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
get_eos_ids(tokenizer.get(), module.get()));
Expand Down
Loading
Loading