@@ -32,13 +32,15 @@ int main(int argc, char ** argv) {
     gpt_params params;

     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n", argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n", argv[0]);
         printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
-        printf("  example: %s ggml-model-f16.gguf 2048 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
+        printf("  example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1;
     }

     int n_kv_max     = 2048;
+    int n_batch      = 2048;
+    int n_ubatch     = 512;
     int is_pp_shared = 0;
     int n_gpu_layers = 0;

@@ -56,23 +58,31 @@ int main(int argc, char ** argv) {
     }

     if (argc >= 4) {
-        is_pp_shared = std::atoi(argv[3]);
+        n_batch = std::atoi(argv[3]);
     }

     if (argc >= 5) {
-        n_gpu_layers = std::atoi(argv[4]);
+        n_ubatch = std::atoi(argv[4]);
     }

     if (argc >= 6) {
-        n_pp = parse_list(argv[5]);
+        is_pp_shared = std::atoi(argv[5]);
     }

     if (argc >= 7) {
-        n_tg = parse_list(argv[6]);
+        n_gpu_layers = std::atoi(argv[6]);
     }

     if (argc >= 8) {
-        n_pl = parse_list(argv[7]);
+        n_pp = parse_list(argv[7]);
+    }
+
+    if (argc >= 9) {
+        n_tg = parse_list(argv[8]);
+    }
+
+    if (argc >= 10) {
+        n_pl = parse_list(argv[9]);
     }

     // init LLM
@@ -100,7 +110,8 @@ int main(int argc, char ** argv) {

     ctx_params.seed      = 1234;
     ctx_params.n_ctx     = n_kv_max;
-    ctx_params.n_batch   = 512;
+    ctx_params.n_batch   = n_batch;
+    ctx_params.n_ubatch  = n_ubatch;

     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
@@ -158,7 +169,7 @@ int main(int argc, char ** argv) {
     }

     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");

     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
0 commit comments