@@ -1039,7 +1039,12 @@ namespace chatllm
         while (!aborted && !completed && (n_past + (int)curr_input_ids.size() < gen_config.max_length))
         {
             std::vector<float> lm_logits;
-            generate_next_token(curr_input_ids, gen_config, lm_logits);
+            if (!generate_next_token(curr_input_ids, gen_config, lm_logits))
+            {
+                ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
+                aborted = true;
+                break;
+            }
 
             if (first_call)
             {
@@ -1113,29 +1118,35 @@ namespace chatllm
         void text_embedding(const GenerationConfig &gen_config, const std::vector<int> &input_ids,
                             std::vector<float> &embedding) override
         {
-            run_model(input_ids, gen_config, 0, embedding);
+            auto r = run_model(input_ids, gen_config, 0, embedding);
+            if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
         }
 
         float qa_rank(const GenerationConfig &gen_config, const std::vector<int> &input_ids) override
         {
             std::vector<float> output;
-            run_model(input_ids, gen_config, 0, output);
+            auto r = run_model(input_ids, gen_config, 0, output);
+            if (!r) ggml::log(GGML_LOG_LEVEL_ERROR, "Out of memory");
             CHATLLM_CHECK(output.size() == 1) << "output must be scalar";
 
             return output[0];
         }
 
-        void generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
+        bool generate_next_token(const std::vector<int> &input_ids, const GenerationConfig &gen_config, std::vector<float> &lm_logits) override
         {
             if (batch_input)
             {
-                run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
+                return run_model(input_ids, gen_config, n_past + n_past_offset, lm_logits);
             }
             else
             {
                 int past = n_past + n_past_offset;
                 for (size_t i = 0; (i < input_ids.size()) && !aborted; i++, past++)
-                    run_model({input_ids[i]}, gen_config, past, lm_logits);
+                {
+                    if (!run_model({input_ids[i]}, gen_config, past, lm_logits))
+                        return false;
+                }
+                return true;
             }
         }
 
@@ -1218,7 +1229,7 @@ namespace chatllm
             return s;
         }
 
-        virtual void run_model(const std::vector<int> &input_ids,
+        virtual bool run_model(const std::vector<int> &input_ids,
                                const GenerationConfig &gen_config,
                                int past,
                                std::vector<float> &output)
@@ -1228,7 +1239,8 @@ namespace chatllm
                 initial_run = true;
                 int past = gen_config.max_length - (int)input_ids.size();
                 if (past < 0) past = 0;
-                CHATLLM_CHECK(before_initial_run(input_ids, gen_config, past)) << "failed to reserve memory.";
+                if (!before_initial_run(input_ids, gen_config, past))
+                    return false;
             }
 
             ForwardContext ctx(&backend_context);
@@ -1255,7 +1267,7 @@ namespace chatllm
 
             output.resize(ggml::nbytes(r) / sizeof(output[0]));
 
-            CHATLLM_CHECK(ctx.allocate()) << "failed to allocate memory for graph";
+            if (!ctx.allocate()) return false;
 
             Backend::write_tensor_data(input_ids_tensor, input_ids.data());
 
@@ -1270,6 +1282,8 @@ namespace chatllm
             Backend::read_tensor_data(r, output.data());
 
             ctx.reset();
+
+            return true;
         }
 
         virtual bool is_output_terminated(const std::vector<int> &output_ids, int &keep_idx, int &pop_output)
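
Taken together, these hunks replace hard CHATLLM_CHECK aborts with bool return values: run_model() returns false when graph or buffer allocation fails, generate_next_token() forwards that result, and the generation loop logs "Out of memory" and stops. The following self-contained sketch shows the same error-propagation pattern in isolation; ToyModel and all of its members are hypothetical stand-ins for illustration, not part of the chatllm.cpp API.

// Standalone sketch of the error-propagation pattern introduced by this
// commit: forward passes report out-of-memory by returning false instead
// of aborting, and the generation loop logs the failure and stops.
// ToyModel is a hypothetical stand-in, not the chatllm.cpp API.
#include <cstdio>
#include <vector>

struct ToyModel
{
    bool aborted = false;
    int n_past = 0;
    int max_length = 8;

    // Stands in for run_model(): a real backend returns false on OOM.
    bool run_model(const std::vector<int> &input_ids, int past, std::vector<float> &output)
    {
        (void)past;
        output.assign(input_ids.size(), 0.0f); // placeholder forward pass
        return true;
    }

    // Mirrors the per-token branch: stop at the first failed forward pass.
    bool generate_next_token(const std::vector<int> &input_ids, std::vector<float> &lm_logits)
    {
        int past = n_past;
        for (size_t i = 0; (i < input_ids.size()) && !aborted; i++, past++)
        {
            if (!run_model({input_ids[i]}, past, lm_logits))
                return false;
        }
        return true;
    }

    // Mirrors the generation loop: a false return is treated as OOM.
    void generate(std::vector<int> curr_input_ids)
    {
        while (!aborted && n_past + (int)curr_input_ids.size() < max_length)
        {
            std::vector<float> lm_logits;
            if (!generate_next_token(curr_input_ids, lm_logits))
            {
                std::fprintf(stderr, "Out of memory\n"); // stands in for ggml::log(...)
                aborted = true;
                break;
            }
            n_past += (int)curr_input_ids.size();
            curr_input_ids = {0}; // pretend-sampled next token
        }
    }
};

int main()
{
    ToyModel m;
    m.generate({1, 2, 3});
    return m.aborted ? 1 : 0;
}

Returning a status instead of asserting lets embedding, ranking, and chat callers decide how to surface the failure, at the cost of every call site having to check the result.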