diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8cb8d0033f7d9..5e8583244eb73 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1321,17 +1321,29 @@ struct server_slot { && are_lora_equal(lora, other_slot.lora); } - bool has_budget(const common_params & global_params) { + bool has_budget(const common_params & global_params, int32_t slot_n_ctx) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; + if (global_params.n_predict == -1) { + if (params.n_predict == -2) + n_remaining = slot_n_ctx - n_decoded; + else + n_remaining = params.n_predict - n_decoded; + } else if (global_params.n_predict == -2) { + if (params.n_predict == -1 || params.n_predict == -2) + n_remaining = slot_n_ctx - n_decoded; + else + n_remaining = std::min(params.n_predict - n_decoded, slot_n_ctx - n_decoded); + } else { + if (params.n_predict == -1) + n_remaining = global_params.n_predict - n_decoded; + else if (params.n_predict == -2) + n_remaining = std::min(global_params.n_predict - n_decoded, slot_n_ctx - n_decoded); + else + n_remaining = std::min(params.n_predict - n_decoded, global_params.n_predict - n_decoded); } return n_remaining > 0; // no budget @@ -2153,7 +2165,7 @@ struct server_context { } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base, slot.n_ctx)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false;