From c24ec7d0ee306fbaf84d6a9b31d75c2a06b8292b Mon Sep 17 00:00:00 2001 From: lucylq Date: Thu, 24 Apr 2025 16:57:58 -0700 Subject: [PATCH] Add data path to runner Pull Request resolved: https://github.com/pytorch/executorch/pull/10445 ^ ghstack-source-id: 280183822 @exported-using-ghexport Differential Revision: [D73617661](https://our.internmc.facebook.com/intern/diff/D73617661/) --- examples/models/llama/main.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp index 5179bf28fc7..d75b152be1f 100644 --- a/examples/models/llama/main.cpp +++ b/examples/models/llama/main.cpp @@ -20,6 +20,8 @@ DEFINE_string( "llama2.pte", "Model serialized in flatbuffer format."); +DEFINE_string(data_path, "", "Data file for the model."); + DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); @@ -49,6 +51,11 @@ int32_t main(int32_t argc, char** argv) { // and users can create their own DataLoaders to load from arbitrary sources. const char* model_path = FLAGS_model_path.c_str(); + std::optional data_path = std::nullopt; + if (!FLAGS_data_path.empty()) { + data_path = FLAGS_data_path.c_str(); + } + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); const char* prompt = FLAGS_prompt.c_str(); @@ -74,7 +81,7 @@ int32_t main(int32_t argc, char** argv) { #endif // create llama runner // @lint-ignore CLANGTIDY facebook-hte-Deprecated - example::Runner runner(model_path, tokenizer_path); + example::Runner runner(model_path, tokenizer_path, data_path); if (warmup) { // @lint-ignore CLANGTIDY facebook-hte-Deprecated