@@ -23,33 +23,57 @@ namespace serve {
23
23
24
24
TVM_REGISTER_OBJECT_TYPE (GenerationConfigNode);
25
25
26
- GenerationConfig::GenerationConfig (String config_json_str, const GenerationConfig& default_config) {
26
+ Result<GenerationConfig> GenerationConfig::Validate (GenerationConfig cfg) {
27
+ using TResult = Result<GenerationConfig>;
28
+ if (cfg->n <= 0 ) {
29
+ return TResult::Error (" \" n\" should be at least 1" );
30
+ }
31
+ if (cfg->temperature < 0 ) {
32
+ return TResult::Error (" \" temperature\" should be non-negative" );
33
+ }
34
+ if (cfg->top_p < 0 || cfg->top_p > 1 ) {
35
+ return TResult::Error (" \" top_p\" should be in range [0, 1]" );
36
+ }
37
+ if (std::fabs (cfg->frequency_penalty ) > 2.0 ) {
38
+ return TResult::Error (" frequency_penalty must be in [-2, 2]!" );
39
+ }
40
+ if (cfg->repetition_penalty <= 0 ) {
41
+ return TResult::Error (" \" repetition_penalty\" must be positive" );
42
+ }
43
+ if (cfg->top_logprobs < 0 || cfg->top_logprobs > 5 ) {
44
+ return TResult::Error (" At most 5 top logprob tokens are supported" );
45
+ }
46
+ if (cfg->top_logprobs != 0 && !(cfg->logprobs )) {
47
+ return TResult::Error (" \" logprobs\" must be true to support \" top_logprobs\" " );
48
+ }
49
+ for (const auto & item : cfg->logit_bias ) {
50
+ double bias_value = item.second ;
51
+ if (std::fabs (bias_value) > 100.0 ) {
52
+ return TResult::Error (" Logit bias value should be in range [-100, 100]." );
53
+ }
54
+ }
55
+ return TResult::Ok (cfg);
56
+ }
57
+
58
+ Result<GenerationConfig> GenerationConfig::FromJSON (String config_json_str,
59
+ const GenerationConfig& default_config) {
60
+ using TResult = Result<GenerationConfig>;
27
61
picojson::object config = json::ParseToJSONObject (config_json_str);
28
62
ObjectPtr<GenerationConfigNode> n = make_object<GenerationConfigNode>();
29
63
30
64
n->n = json::LookupOrDefault<int64_t >(config, " n" , default_config->n );
31
- CHECK_GT (n->n , 0 ) << " \" n\" should be at least 1" ;
32
65
n->temperature =
33
66
json::LookupOrDefault<double >(config, " temperature" , default_config->temperature );
34
- CHECK_GE (n->temperature , 0 ) << " \" temperature\" should be non-negative" ;
35
67
n->top_p = json::LookupOrDefault<double >(config, " top_p" , default_config->top_p );
36
- CHECK (n->top_p >= 0 && n->top_p <= 1 ) << " \" top_p\" should be in range [0, 1]" ;
37
68
n->frequency_penalty =
38
69
json::LookupOrDefault<double >(config, " frequency_penalty" , default_config->frequency_penalty );
39
- CHECK (std::fabs (n->frequency_penalty ) <= 2.0 ) << " Frequency penalty must be in [-2, 2]!" ;
40
70
n->presence_penalty =
41
71
json::LookupOrDefault<double >(config, " presence_penalty" , default_config->presence_penalty );
42
- CHECK (std::fabs (n->presence_penalty ) <= 2.0 ) << " Presence penalty must be in [-2, 2]!" ;
43
72
n->repetition_penalty = json::LookupOrDefault<double >(config, " repetition_penalty" ,
44
73
default_config->repetition_penalty );
45
- CHECK (n->repetition_penalty > 0 ) << " Repetition penalty must be a positive number!" ;
46
74
n->logprobs = json::LookupOrDefault<bool >(config, " logprobs" , default_config->logprobs );
47
75
n->top_logprobs =
48
76
json::LookupOrDefault<int64_t >(config, " top_logprobs" , default_config->top_logprobs );
49
- CHECK (n->top_logprobs >= 0 && n->top_logprobs <= 5 )
50
- << " At most 5 top logprob tokens are supported" ;
51
- CHECK (n->top_logprobs == 0 || n->logprobs )
52
- << " \" logprobs\" must be true to support \" top_logprobs\" " ;
53
77
54
78
std::optional<picojson::object> logit_bias_obj =
55
79
json::LookupOptional<picojson::object>(config, " logit_bias" );
@@ -59,7 +83,6 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
59
83
for (auto [token_id_str, bias] : logit_bias_obj.value ()) {
60
84
CHECK (bias.is <double >());
61
85
double bias_value = bias.get <double >();
62
- CHECK_LE (std::fabs (bias_value), 100.0 ) << " Logit bias value should be in range [-100, 100]." ;
63
86
logit_bias.emplace_back (std::stoi (token_id_str), bias_value);
64
87
}
65
88
n->logit_bias = std::move (logit_bias);
@@ -78,7 +101,9 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
78
101
Array<String> stop_strs;
79
102
stop_strs.reserve (stop_strs_arr.value ().size ());
80
103
for (const picojson::value& v : stop_strs_arr.value ()) {
81
- CHECK (v.is <std::string>()) << " Invalid stop string in stop_strs" ;
104
+ if (!v.is <std::string>()) {
105
+ return TResult::Error (" Invalid stop string in stop_strs" );
106
+ }
82
107
stop_strs.push_back (v.get <std::string>());
83
108
}
84
109
n->stop_strs = std::move (stop_strs);
@@ -91,7 +116,9 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
91
116
std::vector<int > stop_token_ids;
92
117
stop_token_ids.reserve (stop_token_ids_arr.value ().size ());
93
118
for (const picojson::value& v : stop_token_ids_arr.value ()) {
94
- CHECK (v.is <int64_t >()) << " Invalid stop token in stop_token_ids" ;
119
+ if (!v.is <int64_t >()) {
120
+ return TResult::Error (" Invalid stop token in stop_token_ids" );
121
+ }
95
122
stop_token_ids.push_back (v.get <int64_t >());
96
123
}
97
124
n->stop_token_ids = std::move (stop_token_ids);
@@ -123,8 +150,7 @@ GenerationConfig::GenerationConfig(String config_json_str, const GenerationConfi
123
150
n->debug_config .ignore_eos =
124
151
json::LookupOrDefault<bool >(debug_config_obj.value (), " ignore_eos" , false );
125
152
}
126
-
127
- data_ = std::move (n);
153
+ return Validate (GenerationConfig (n));
128
154
}
129
155
130
156
GenerationConfig GenerationConfig::GetDefaultFromModelConfig (
0 commit comments