diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 5179bf28fc7..d75b152be1f 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -20,6 +20,8 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
+DEFINE_string(data_path, "", "Data file for the model.");
+
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
 DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
@@ -49,6 +51,11 @@ int32_t main(int32_t argc, char** argv) {
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
+  // Optional external weights/data file; left empty (nullopt) when the
+  // --data_path flag is not supplied. NOTE: the template argument must be
+  // spelled out -- CTAD cannot deduce optional<T> from std::nullopt alone.
+  std::optional<const char*> data_path = std::nullopt;
+  if (!FLAGS_data_path.empty()) {
+    data_path = FLAGS_data_path.c_str();
+  }
+
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
   const char* prompt = FLAGS_prompt.c_str();
@@ -74,7 +81,7 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path);
+  example::Runner runner(model_path, tokenizer_path, data_path);
 
   if (warmup) {
     // @lint-ignore CLANGTIDY facebook-hte-Deprecated