@@ -59,9 +59,22 @@ DEFINE_string(
5959 " SmartMask" );
6060DEFINE_int32 (num_iters, 1 , " total num of iterations to run." );
6161
62+ std::vector<std::string> CollectPrompts (int argc, char ** argv) {
63+ // Collect all prompts from command line, example usage:
64+ // --prompt "prompt1" --prompt "prompt2" --prompt "prompt3"
65+ std::vector<std::string> prompts;
66+ for (int i = 1 ; i < argc; i++) {
67+ if (std::string (argv[i]) == " --prompt" && i + 1 < argc) {
68+ prompts.push_back (argv[i + 1 ]);
69+ i++; // Skip the next argument
70+ }
71+ }
72+ return prompts;
73+ }
74+
6275int main (int argc, char ** argv) {
76+ std::vector<std::string> prompts = CollectPrompts (argc, argv);
6377 gflags::ParseCommandLineFlags (&argc, &argv, true );
64-
6578 // create llama runner
6679 example::Runner runner (
6780 {FLAGS_model_path},
@@ -83,11 +96,10 @@ int main(int argc, char** argv) {
8396 };
8497 // generate tokens & store inference output
8598 for (int i = 0 ; i < FLAGS_num_iters; i++) {
86- runner.generate (
87- FLAGS_seq_len,
88- FLAGS_prompt.c_str (),
89- FLAGS_system_prompt.c_str (),
90- callback);
99+ for (const auto & prompt : prompts) {
100+ runner.generate (
101+ FLAGS_seq_len, prompt.c_str (), FLAGS_system_prompt.c_str (), callback);
102+ }
91103 }
92104 fout.write (buf.data (), buf.size ());
93105 fout.close ();
0 commit comments