@@ -32,7 +32,6 @@ int main(int argc, char ** argv) {
     //     }
     // }, NULL);
 
-    auto mparams = common_model_params_to_llama(params);
     auto cparams = common_context_params_to_llama(params);
 
     int dev_count = ggml_backend_dev_count();
@@ -43,7 +42,7 @@ int main(int argc, char ** argv) {
             gpu_dev_count++;
         }
     }
-    const int num_models = gpu_dev_count + 1; // GPUs + 1 CPU model
+    const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
     // const int num_models = std::max(1, gpu_dev_count);
     const int num_contexts = std::max(1, params.n_parallel);
 
@@ -52,8 +51,17 @@ int main(int argc, char ** argv) {
     std::atomic<bool> failed = false;
 
     for (int m = 0; m < num_models; ++m) {
-        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
-        mparams.main_gpu = m < gpu_dev_count ? m : -1;
+        auto mparams = common_model_params_to_llama(params);
+
+        if (m < gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = m;
+        } else if (m == gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = -1; // CPU model
+        } else {
+            mparams.split_mode = LLAMA_SPLIT_MODE_LAYER; // split layers across all GPUs
+        }
 
         llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
         if (model == NULL) {
@@ -111,20 +119,21 @@ int main(int argc, char ** argv) {
                         token = llama_vocab_bos(vocab);
                     }
 
+                    result += common_token_to_piece(ctx.get(), token);
+
                     if (llama_vocab_is_eog(vocab, token)) {
                         break;
                     }
-                    result += common_token_to_piece(ctx.get(), token);
 
                     batch = llama_batch_get_one(&token, 1);
                     if (llama_decode(ctx.get(), batch)) {
-                        LOG_ERR("failed to decode\n");
+                        LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
                        failed.store(true);
                        return;
                    }
                }
 
-                LOG_INF("Model %d/%d, Context %d/%d: Result: '%s'\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
+                LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
            });
        }
    }
@@ -138,6 +147,6 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    LOG_INF("All threads completed successfully.\n");
+    LOG_INF("All threads finished without errors.\n");
     return 0;
 }
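
For reference, the per-model device selection introduced above can be read as a standalone helper. This is a minimal sketch, not code from the patch: configure_model_params is a hypothetical name, and it assumes gpu_dev_count has already been counted via the ggml backend device API as in the second hunk; the llama_model_params fields and LLAMA_SPLIT_MODE_* values are the ones declared in llama.h.

#include "llama.h"

// Hypothetical helper (not part of the test): apply the same per-model
// device selection as the diff above. Assumes gpu_dev_count was computed
// with ggml_backend_dev_count()/ggml_backend_dev_type().
static void configure_model_params(llama_model_params & mparams, int m, int gpu_dev_count) {
    if (m < gpu_dev_count) {
        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;  // pin the whole model to GPU m
        mparams.main_gpu   = m;
    } else if (m == gpu_dev_count) {
        mparams.split_mode = LLAMA_SPLIT_MODE_NONE;  // CPU-only model
        mparams.main_gpu   = -1;
    } else {
        mparams.split_mode = LLAMA_SPLIT_MODE_LAYER; // split layers across all GPUs
    }
}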
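With this change the test exercises num_models = gpu_dev_count + 2 models: one pinned to each GPU, one on the CPU, and one split across GPUs by layer, each driven from num_contexts concurrent contexts. Moving the common_token_to_piece() call ahead of the end-of-generation check also means the EOG token's piece is included in the result string that the reworked LOG_INF prints.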