|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated |
| 8 | + */ |
| 9 | +#include <gflags/gflags.h> |
| 10 | + |
| 11 | +#include <executorch/examples/models/llama/runner/runner.h> |
| 12 | + |
| 13 | +#if defined(ET_USE_THREADPOOL) |
| 14 | +#include <executorch/extension/threadpool/cpuinfo_utils.h> |
| 15 | +#include <executorch/extension/threadpool/threadpool.h> |
| 16 | +#endif |
| 17 | + |
| 18 | +#if defined(__linux__) || defined(__ANDROID__) || defined(__unix__) |
| 19 | +#include <sys/resource.h> |
| 20 | +#endif |
/// Returns the peak resident set size (max RSS) of this process in bytes,
/// or 0 when the platform is unsupported or getrusage() fails.
///
/// __APPLE__/__MACH__ are deliberately excluded: ru_maxrss on macOS has
/// returned bytes on some releases and kilobytes on others, so the unit
/// cannot be relied upon there.
inline size_t get_rss_bytes() {
#if defined(__linux__) || defined(__ANDROID__) || defined(__unix__)
  struct rusage usage_info;
  if (getrusage(RUSAGE_SELF, &usage_info) != 0) {
    return 0;  // getrusage() failed.
  }
  // ru_maxrss is reported in kilobytes on these platforms; convert to bytes.
  return static_cast<size_t>(usage_info.ru_maxrss) * 1024;
#else
  // Unsupported platform (e.g. Windows).
  return 0;
#endif
}
| 35 | + |
// ---- Command-line flags (gflags) ----

// Path to the LoRA-adapted llama program (.pte). Its foundation weights are
// expected to live in the shared external data file (--data_path).
DEFINE_string(lora_model_path, "llama_3_2_1B_lora.pte",
              "LoRA model serialized in flatbuffer format.");
// Path to the base (non-LoRA) llama program (.pte), sharing the same
// external data file.
DEFINE_string(llama_model_path, "llama_3_2_1B.pte",
              "Model serialized in flatbuffer format.");
// Path to the external tensor data (.ptd) holding the foundation weights
// shared by both programs above.
DEFINE_string(data_path, "foundation.ptd",
              "Data serialized in flatbuffer format.");

// Path to the tokenizer model used by both runners.
DEFINE_string(tokenizer_path, "tokenizer.model", "Tokenizer stuff.");

// Prompt fed to both runners.
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");

// Sampling temperature; 0 makes generation a deterministic argmax.
DEFINE_double(temperature, 0.8f,
              "Temperature; Default is 0.8f. 0 = greedy argmax sampling "
              "(deterministic). Lower temperature = more deterministic");

// Token budget for each generate() call (prompt + generated tokens).
DEFINE_int32(
    seq_len, 128,
    "Total number of tokens to generate (prompt + output). Defaults to "
    "max_seq_len. If the number of input tokens + seq_len > max_seq_len, the "
    "output will be truncated to max_seq_len tokens.");
| 56 | + |
| 57 | +using namespace ::executorch::extension; |
| 58 | + |
| 59 | +int main(int argc, char *argv[]) { |
| 60 | + ET_LOG(Info, "Running program-data separation lora example..."); |
| 61 | + |
| 62 | + auto rss_0 = get_rss_bytes() / 1024.0 / 1024.0; |
| 63 | + ET_LOG(Info, "0 RSS start: %f MiB (0 if unsupported)", rss_0); |
| 64 | + |
| 65 | + gflags::ParseCommandLineFlags(&argc, &argv, true); |
| 66 | + |
| 67 | + const char *lora_model_path = FLAGS_lora_model_path.c_str(); |
| 68 | + const char *llama_model_path = FLAGS_llama_model_path.c_str(); |
| 69 | + const char *data_path = FLAGS_data_path.c_str(); |
| 70 | + |
| 71 | + const char *tokenizer_path = FLAGS_tokenizer_path.c_str(); |
| 72 | + const char *prompt = FLAGS_prompt.c_str(); |
| 73 | + float temperature = FLAGS_temperature; |
| 74 | + int32_t seq_len = 128; |
| 75 | + int32_t cpu_threads = -1; |
| 76 | + |
| 77 | +#if defined(ET_USE_THREADPOOL) |
| 78 | + uint32_t num_performant_cores = |
| 79 | + cpu_threads == -1 |
| 80 | + ? ::executorch::extension::cpuinfo::get_num_performant_cores() |
| 81 | + : static_cast<uint32_t>(cpu_threads); |
| 82 | + ET_LOG(Info, "Resetting threadpool with num threads = %d", |
| 83 | + num_performant_cores); |
| 84 | + if (num_performant_cores > 0) { |
| 85 | + ::executorch::extension::threadpool::get_threadpool() |
| 86 | + ->_unsafe_reset_threadpool(num_performant_cores); |
| 87 | + } |
| 88 | +#endif |
| 89 | + |
| 90 | + // Create runner for lora model. |
| 91 | + auto rss_1 = get_rss_bytes() / 1024.0 / 1024.0; |
| 92 | + ET_LOG(Info, "1 RSS before creating lora_runner: %f MiB (0 if unsupported)", |
| 93 | + rss_1); |
| 94 | + std::unique_ptr<::executorch::extension::llm::TextLLMRunner> lora_runner = |
| 95 | + example::create_llama_runner(lora_model_path, tokenizer_path, data_path); |
| 96 | + if (lora_runner == nullptr) { |
| 97 | + ET_LOG(Error, "Failed to create lora_runner."); |
| 98 | + return 1; |
| 99 | + } |
| 100 | + |
| 101 | + // create runner for llama model |
| 102 | + auto rss_2 = get_rss_bytes() / 1024.0 / 1024.0; |
| 103 | + ET_LOG(Info, "2 RSS before creating llama_runner: %f MiB (0 if unsupported)", |
| 104 | + rss_2); |
| 105 | + std::unique_ptr<::executorch::extension::llm::TextLLMRunner> llama_runner = |
| 106 | + example::create_llama_runner(llama_model_path, tokenizer_path, data_path); |
| 107 | + if (llama_runner == nullptr) { |
| 108 | + ET_LOG(Error, "Failed to create llama_runner."); |
| 109 | + return 1; |
| 110 | + } |
| 111 | + auto rss_3 = get_rss_bytes() / 1024.0 / 1024.0; |
| 112 | + ET_LOG(Info, |
| 113 | + "3 RSS before after creating the runners: %f MiB (0 if unsupported)", |
| 114 | + rss_3); |
| 115 | + |
| 116 | + // generate |
| 117 | + executorch::extension::llm::GenerationConfig config{ |
| 118 | + .seq_len = seq_len, .temperature = temperature}; |
| 119 | + |
| 120 | + ET_LOG(Info, "Generating with lora..."); |
| 121 | + auto rss_4 = get_rss_bytes() / 1024.0 / 1024.0; |
| 122 | + ET_LOG(Info, "4 RSS before running lora_runner: %f MiB (0 if unsupported)", |
| 123 | + rss_4); |
| 124 | + auto error = lora_runner->generate(prompt, config); |
| 125 | + if (error != executorch::runtime::Error::Ok) { |
| 126 | + ET_LOG(Error, "Failed to generate with lora_runner, error code %zu.", |
| 127 | + error); |
| 128 | + return 1; |
| 129 | + } |
| 130 | + auto rss_5 = get_rss_bytes() / 1024.0 / 1024.0; |
| 131 | + ET_LOG(Info, |
| 132 | + "5 RSS after lora_runner/before llama_runner: %f MiB " |
| 133 | + "(0 if unsupported)", |
| 134 | + rss_5); |
| 135 | + ET_LOG(Info, "Generating with llama..."); |
| 136 | + error = lora_runner->generate(prompt, config); |
| 137 | + if (error != executorch::runtime::Error::Ok) { |
| 138 | + ET_LOG(Error, "Failed to generate with llama_runner, error code %zu.", |
| 139 | + error); |
| 140 | + return 1; |
| 141 | + } |
| 142 | + auto rss_6 = get_rss_bytes() / 1024.0 / 1024.0; |
| 143 | + ET_LOG(Info, |
| 144 | + "6 RSS after llama_runner: %f MiB " |
| 145 | + "(0 if unsupported)", |
| 146 | + rss_6); |
| 147 | + |
| 148 | + return 0; |
| 149 | +} |
0 commit comments