Skip to content

Commit 7eba612

Browse files
authored
[REFACTOR] Enable validation logic in GenerationConfig (#2411)
This PR enables a centralized validation logic in GenerationConfig.
1 parent 97df697 commit 7eba612

File tree

7 files changed

+83
-39
lines changed

7 files changed

+83
-39
lines changed

cpp/json_ffi/json_ffi_engine.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,14 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
9797
gen_cfg->stop_token_ids = conv_template_.stop_token_ids;
9898
gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());
9999

100-
Request engine_request(request_id, inputs, GenerationConfig(gen_cfg));
101-
this->engine_->AddRequest(engine_request);
100+
Result<GenerationConfig> res_gen_config = GenerationConfig::Validate(GenerationConfig(gen_cfg));
101+
if (res_gen_config.IsErr()) {
102+
err_ = res_gen_config.UnwrapErr();
103+
return false;
104+
}
102105

106+
Request engine_request(request_id, inputs, res_gen_config.Unwrap());
107+
this->engine_->AddRequest(engine_request);
103108
return true;
104109
}
105110

cpp/serve/config.cc

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,57 @@ namespace serve {
2323

2424
TVM_REGISTER_OBJECT_TYPE(GenerationConfigNode);
2525

26-
GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfig& default_config) {
26+
Result<GenerationConfig> GenerationConfig::Validate(GenerationConfig cfg) {
27+
using TResult = Result<GenerationConfig>;
28+
if (cfg->n <= 0) {
29+
return TResult::Error("\"n\" should be at least 1");
30+
}
31+
if (cfg->temperature < 0) {
32+
return TResult::Error("\"temperature\" should be non-negative");
33+
}
34+
if (cfg->top_p < 0 || cfg->top_p > 1) {
35+
return TResult::Error("\"top_p\" should be in range [0, 1]");
36+
}
37+
if (std::fabs(cfg->frequency_penalty) > 2.0) {
38+
return TResult::Error("frequency_penalty must be in [-2, 2]!");
39+
}
40+
if (cfg->repetition_penalty <= 0) {
41+
return TResult::Error("\"repetition_penalty\" must be positive");
42+
}
43+
if (cfg->top_logprobs < 0 || cfg->top_logprobs > 5) {
44+
return TResult::Error("At most 5 top logprob tokens are supported");
45+
}
46+
if (cfg->top_logprobs != 0 && !(cfg->logprobs)) {
47+
return TResult::Error("\"logprobs\" must be true to support \"top_logprobs\"");
48+
}
49+
for (const auto& item : cfg->logit_bias) {
50+
double bias_value = item.second;
51+
if (std::fabs(bias_value) > 100.0) {
52+
return TResult::Error("Logit bias value should be in range [-100, 100].");
53+
}
54+
}
55+
return TResult::Ok(cfg);
56+
}
57+
58+
Result<GenerationConfig> GenerationConfig::FromJSON(String config_json_str,
59+
const GenerationConfig& default_config) {
60+
using TResult = Result<GenerationConfig>;
2761
picojson::object config = json::ParseToJSONObject(config_json_str);
2862
ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();
2963

3064
n->n = json::LookupOrDefault<int64_t>(config, "n", default_config->n);
31-
CHECK_GT(n->n, 0) << "\"n\" should be at least 1";
3265
n->temperature =
3366
json::LookupOrDefault<double>(config, "temperature", default_config->temperature);
34-
CHECK_GE(n->temperature, 0) << "\"temperature\" should be non-negative";
3567
n->top_p = json::LookupOrDefault<double>(config, "top_p", default_config->top_p);
36-
CHECK(n->top_p >= 0 && n->top_p <= 1) << "\"top_p\" should be in range [0, 1]";
3768
n->frequency_penalty =
3869
json::LookupOrDefault<double>(config, "frequency_penalty", default_config->frequency_penalty);
39-
CHECK(std::fabs(n->frequency_penalty) <= 2.0) << "Frequency penalty must be in [-2, 2]!";
4070
n->presence_penalty =
4171
json::LookupOrDefault<double>(config, "presence_penalty", default_config->presence_penalty);
42-
CHECK(std::fabs(n->presence_penalty) <= 2.0) << "Presence penalty must be in [-2, 2]!";
4372
n->repetition_penalty = json::LookupOrDefault<double>(config, "repetition_penalty",
4473
default_config->repetition_penalty);
45-
CHECK(n->repetition_penalty > 0) << "Repetition penalty must be a positive number!";
4674
n->logprobs = json::LookupOrDefault<bool>(config, "logprobs", default_config->logprobs);
4775
n->top_logprobs =
4876
json::LookupOrDefault<int64_t>(config, "top_logprobs", default_config->top_logprobs);
49-
CHECK(n->top_logprobs >= 0 && n->top_logprobs <= 5)
50-
<< "At most 5 top logprob tokens are supported";
51-
CHECK(n->top_logprobs == 0 || n->logprobs)
52-
<< "\"logprobs\" must be true to support \"top_logprobs\"";
5377

5478
std::optional<picojson::object> logit_bias_obj =
5579
json::LookupOptional<picojson::object>(config, "logit_bias");
@@ -59,7 +83,6 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
5983
for (auto [token_id_str, bias] : logit_bias_obj.value()) {
6084
CHECK(bias.is<double>());
6185
double bias_value = bias.get<double>();
62-
CHECK_LE(std::fabs(bias_value), 100.0) << "Logit bias value should be in range [-100, 100].";
6386
logit_bias.emplace_back(std::stoi(token_id_str), bias_value);
6487
}
6588
n->logit_bias = std::move(logit_bias);
@@ -78,7 +101,9 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
78101
Array<String> stop_strs;
79102
stop_strs.reserve(stop_strs_arr.value().size());
80103
for (const picojson::value& v : stop_strs_arr.value()) {
81-
CHECK(v.is<std::string>()) << "Invalid stop string in stop_strs";
104+
if (!v.is<std::string>()) {
105+
return TResult::Error("Invalid stop string in stop_strs");
106+
}
82107
stop_strs.push_back(v.get<std::string>());
83108
}
84109
n->stop_strs = std::move(stop_strs);
@@ -91,7 +116,9 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
91116
std::vector<int> stop_token_ids;
92117
stop_token_ids.reserve(stop_token_ids_arr.value().size());
93118
for (const picojson::value& v : stop_token_ids_arr.value()) {
94-
CHECK(v.is<int64_t>()) << "Invalid stop token in stop_token_ids";
119+
if (!v.is<int64_t>()) {
120+
return TResult::Error("Invalid stop token in stop_token_ids");
121+
}
95122
stop_token_ids.push_back(v.get<int64_t>());
96123
}
97124
n->stop_token_ids = std::move(stop_token_ids);
@@ -123,8 +150,7 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
123150
n->debug_config.ignore_eos =
124151
json::LookupOrDefault<bool>(debug_config_obj.value(), "ignore_eos", false);
125152
}
126-
127-
data_ = std::move(n);
153+
return Validate(GenerationConfig(n));
128154
}
129155

130156
GenerationConfig GenerationConfig::GetDefaultFromModelConfig(

cpp/serve/config.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,19 @@ class GenerationConfigNode : public Object {
6868

6969
class GenerationConfig : public ObjectRef {
7070
public:
71-
explicit GenerationConfig(String config_json_str, const GenerationConfig& default_config);
71+
/*!
72+
* \brief Run validation of generation config and ensure values are in bound.
73+
* \return The validtaed Generation config or error.
74+
*/
75+
static Result<GenerationConfig> Validate(GenerationConfig cfg);
76+
77+
/*!
78+
* \brief Create generation config from JSON.
79+
* \param config_json_str The json string for generation config
80+
* \param default_config The default config
81+
*/
82+
static Result<GenerationConfig> FromJSON(String config_json_str,
83+
const GenerationConfig& default_config);
7284

7385
/*! \brief Get the default generation config from the model config. */
7486
static GenerationConfig GetDefaultFromModelConfig(const picojson::object& json);
@@ -192,7 +204,7 @@ class EngineConfigNode : public Object {
192204
/*************** Debug ***************/
193205
bool verbose = false;
194206

195-
TVM_DLL String AsJSONString() const;
207+
String AsJSONString() const;
196208

197209
static constexpr const char* _type_key = "mlc.serve.EngineConfig";
198210
static constexpr const bool _type_has_method_sequal_reduce = false;
@@ -203,14 +215,14 @@ class EngineConfigNode : public Object {
203215
class EngineConfig : public ObjectRef {
204216
public:
205217
/*! \brief Create EngineConfig from JSON object and inferred config. */
206-
TVM_DLL static EngineConfig FromJSONAndInferredConfig(
207-
const picojson::object& json, const InferrableEngineConfig& inferred_config);
218+
static EngineConfig FromJSONAndInferredConfig(const picojson::object& json,
219+
const InferrableEngineConfig& inferred_config);
208220

209221
/*!
210222
* \brief Get all the models and model libs from the JSON string for engine initialization.
211223
* \return The parsed models/model libs from config or error message.
212224
*/
213-
TVM_DLL static Result<std::vector<std::pair<std::string, std::string>>>
225+
static Result<std::vector<std::pair<std::string, std::string>>>
214226
GetModelsAndModelLibsFromJSONString(const std::string& json_str);
215227

216228
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(EngineConfig, ObjectRef, EngineConfigNode);
@@ -225,13 +237,13 @@ struct InferrableEngineConfig {
225237
std::optional<int64_t> max_history_size;
226238

227239
/*! \brief Infer the config for KV cache from a given initial config. */
228-
TVM_DLL static Result<InferrableEngineConfig> InferForKVCache(
240+
static Result<InferrableEngineConfig> InferForKVCache(
229241
EngineMode mode, Device device, double gpu_memory_utilization,
230242
const std::vector<picojson::object>& model_configs,
231243
const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,
232244
bool verbose);
233245
/*! \brief Infer the config for RNN state from a given initial config. */
234-
TVM_DLL static Result<InferrableEngineConfig> InferForRNNState(
246+
static Result<InferrableEngineConfig> InferForRNNState(
235247
EngineMode mode, Device device, double gpu_memory_utilization,
236248
const std::vector<picojson::object>& model_configs,
237249
const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,

cpp/serve/engine.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,10 @@ class EngineModule : public ModuleNode {
601601
void Abort(const String& request_id) { return GetEngine()->AbortRequest(request_id); }
602602

603603
Request CreateRequest(String id, Array<Data> inputs, String generation_cfg_json_str) {
604-
return Request(
605-
std::move(id), std::move(inputs),
606-
GenerationConfig(std::move(generation_cfg_json_str), default_generation_config_));
604+
auto gen_config =
605+
GenerationConfig::FromJSON(std::move(generation_cfg_json_str), default_generation_config_);
606+
CHECK(gen_config.IsOk()) << gen_config.UnwrapErr();
607+
return Request(std::move(id), std::move(inputs), gen_config.Unwrap());
607608
}
608609
/*! \brief Redirection to `Engine::Step`. */
609610
void Step() { return GetEngine()->Step(); }

cpp/serve/model.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,16 +355,16 @@ class Model : public ObjectRef {
355355
* \param trace_enabled A boolean indicating whether tracing is enabled.
356356
* \return The created runtime module.
357357
*/
358-
TVM_DLL static Model Create(String reload_lib_path, String model_path,
359-
const picojson::object& model_config, DLDevice device,
360-
const Optional<Session>& session, bool trace_enabled);
358+
static Model Create(String reload_lib_path, String model_path,
359+
const picojson::object& model_config, DLDevice device,
360+
const Optional<Session>& session, bool trace_enabled);
361361

362362
/*!
363363
* Load the model config from the given model path.
364364
* \param model_path The path to the model weight parameters.
365365
* \return The model config json object.
366366
*/
367-
TVM_DLL static Result<picojson::object> LoadModelConfig(const String& model_path);
367+
static Result<picojson::object> LoadModelConfig(const String& model_path);
368368

369369
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Model, ObjectRef, ModelObj);
370370
};

cpp/serve/sampler/sampler.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class SamplerObj : public Object {
126126
class Sampler : public ObjectRef {
127127
public:
128128
/*! * \brief Create a CPU sampler. */
129-
TVM_DLL static Sampler CreateCPUSampler(Optional<EventTraceRecorder> trace_recorder);
129+
static Sampler CreateCPUSampler(Optional<EventTraceRecorder> trace_recorder);
130130
/*!
131131
* \brief Create a GPU sampler.
132132
* \param max_num_sample The max number of samples to sample at a time.
@@ -135,9 +135,8 @@ class Sampler : public ObjectRef {
135135
* \param device The device that the model runs on.
136136
* \param trace_recorder The event trace recorder.
137137
*/
138-
TVM_DLL static Sampler CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,
139-
DLDevice device,
140-
Optional<EventTraceRecorder> trace_recorder);
138+
static Sampler CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,
139+
DLDevice device, Optional<EventTraceRecorder> trace_recorder);
141140

142141
/*! \brief Check if the given device supports GPU sampling. */
143142
static bool SupportGPUSampler(Device device) {

cpp/serve/threaded_engine.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,10 @@ class ThreadedEngineImpl : public ThreadedEngine {
220220
}
221221

222222
Request CreateRequest(String id, Array<Data> inputs, String generation_cfg_json_str) const {
223-
return Request(
224-
std::move(id), std::move(inputs),
225-
GenerationConfig(std::move(generation_cfg_json_str), GetDefaultGenerationConfig()));
223+
auto gen_config = GenerationConfig::FromJSON(std::move(generation_cfg_json_str),
224+
GetDefaultGenerationConfig());
225+
CHECK(gen_config.IsOk()) << gen_config.UnwrapErr();
226+
return Request(std::move(id), std::move(inputs), gen_config.Unwrap());
226227
}
227228

228229
EngineConfig GetCompleteEngineConfig() const final {

0 commit comments

Comments
 (0)