@@ -57,28 +57,33 @@ bool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request
     return false;
   }
   ChatCompletionRequest request = request_res.Unwrap();
-  // get prompt: note, assistant was appended in the end.
-  Result<std::vector<Data>> inputs_obj =
-      CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);
-  if (inputs_obj.IsErr()) {
-    err_ = inputs_obj.UnwrapErr();
-    return false;
-  }
-  Array<Data> inputs = inputs_obj.Unwrap();
-
-  // generation_cfg
+  Array<Data> inputs;
   Array<String> stop_strs;
-  stop_strs.reserve(this->conv_template_.stop_str.size());
-  for (const std::string& stop_str : this->conv_template_.stop_str) {
-    stop_strs.push_back(stop_str);
-  }
-  if (request.stop.has_value()) {
-    stop_strs.reserve(stop_strs.size() + request.stop.value().size());
-    for (const std::string& stop_str : request.stop.value()) {
+  bool is_special_request =
+      (request.debug_config.has_value() &&
+       request.debug_config.value().special_request != SpecialRequestKind::kNone);
+  // special request does not have to go through prompt construction
+  if (!is_special_request) {
+    // get prompt: note, assistant was appended in the end.
+    Result<std::vector<Data>> inputs_obj =
+        CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);
+    if (inputs_obj.IsErr()) {
+      err_ = inputs_obj.UnwrapErr();
+      return false;
+    }
+    inputs = inputs_obj.Unwrap();
+
+    stop_strs.reserve(this->conv_template_.stop_str.size());
+    for (const std::string& stop_str : this->conv_template_.stop_str) {
       stop_strs.push_back(stop_str);
     }
+    if (request.stop.has_value()) {
+      stop_strs.reserve(stop_strs.size() + request.stop.value().size());
+      for (const std::string& stop_str : request.stop.value()) {
+        stop_strs.push_back(stop_str);
+      }
+    }
   }
-
   // create a generation config from request
   const auto& default_gen_cfg = default_generation_config_;
   auto gen_cfg = tvm::runtime::make_object<GenerationConfigNode>();
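
In the new flow a special request skips prompt construction entirely: inputs stays empty and no stop strings are collected, so the request proceeds straight to the generation-config handling below. A minimal sketch of what such a request could look like on the wire; the JSON field names and the "query_engine_metrics" kind are assumptions beyond what this diff shows:

// Hypothetical request JSON handed to JSONFFIEngine::AddRequest (shape
// assumed); any special_request kind other than SpecialRequestKind::kNone
// takes the bypass branch above.
std::string request_json = R"({
  "messages": [],
  "model": "json_ffi",
  "debug_config": {"special_request": "query_engine_metrics"}
})";
engine->AddRequest(request_json, /*request_id=*/"special-0");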
@@ -115,8 +120,6 @@ bool JSONFFIEngine::Abort(std::string request_id) {

 std::string JSONFFIEngine::GetLastError() { return err_; }

-std::string JSONFFIEngine::JSONMetrics() { return this->engine_->JSONMetrics(); }
-
 void JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); }

 JSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); }
@@ -131,7 +134,6 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {
   TVM_MODULE_VTABLE_ENTRY("chat_completion", &JSONFFIEngineImpl::ChatCompletion);
   TVM_MODULE_VTABLE_ENTRY("abort", &JSONFFIEngineImpl::Abort);
   TVM_MODULE_VTABLE_ENTRY("get_last_error", &JSONFFIEngineImpl::GetLastError);
-  TVM_MODULE_VTABLE_ENTRY("json_metrics", &JSONFFIEngineImpl::JSONMetrics);
   TVM_MODULE_VTABLE_ENTRY("run_background_loop", &JSONFFIEngineImpl::RunBackgroundLoop);
   TVM_MODULE_VTABLE_ENTRY("run_background_stream_back_loop",
                           &JSONFFIEngineImpl::RunBackgroundStreamBackLoop);
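
With both the JSONMetrics method and its vtable entry gone, the module no longer exposes a synchronous metrics accessor; judging by the streaming changes below, metrics presumably reach callers asynchronously as a final usage message instead. A sketch of the observable effect on the module interface, with an assumed factory name for illustration:

// After this change, looking up the removed entry yields an undefined
// PackedFunc ("CreateJSONFFIEngine" is a hypothetical factory name).
tvm::runtime::Module engine_mod = CreateJSONFFIEngine();
tvm::runtime::PackedFunc f = engine_mod.GetFunction("json_metrics");
ICHECK(f == nullptr);  // entry no longer registered in the vtable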
@@ -190,11 +192,35 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {

   String GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {
     std::unordered_map<std::string, std::vector<ChatCompletionStreamResponseChoice>> response_map;
+    std::vector<picojson::value> request_final_usage_messages;
+    std::string model = "json_ffi";
+
     for (const auto& delta_output : delta_outputs) {
       std::string request_id = delta_output->request_id;
       if (response_map.find(request_id) == response_map.end()) {
         response_map[request_id] = std::vector<ChatCompletionStreamResponseChoice>();
       }
+
+      // build the final usage messages
+      // invariant: we can always let other messages come first and
+      // the final usage messages last, as final usage is always last
+      if (delta_output->request_final_usage_json_str.defined()) {
+        ChatCompletionStreamResponse response;
+        response.id = request_id;
+        response.model = model;
+        response.system_fingerprint = "";
+        std::string usage_json_str = delta_output->request_final_usage_json_str.value();
+        picojson::value usage_json;
+        std::string err = picojson::parse(usage_json, usage_json_str);
+        if (!err.empty()) {
+          err_ = err;
+        } else {
+          response.usage = usage_json;
+        }
+        request_final_usage_messages.push_back(picojson::value(response.AsJSON()));
+        continue;
+      }
+      ICHECK_NE(delta_output->group_finish_reason.size(), 0);
       ChatCompletionStreamResponseChoice choice;

       if (delta_output->group_finish_reason.size() != 1) {
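
Usage-only outputs are parsed and buffered rather than emitted inline, and the ICHECK guarantees every remaining output carries at least one finish-reason group. A sketch of how a receiver could tell the final-usage chunk apart from delta chunks, assuming delta responses omit the usage field (names here are illustrative):

// response_str: the serialized array produced by the method below
picojson::value v;
std::string err = picojson::parse(v, response_str);
if (err.empty()) {
  for (const picojson::value& item : v.get<picojson::array>()) {
    const picojson::object& obj = item.get<picojson::object>();
    if (obj.count("usage")) {
      // final usage message: no deltas to consume, the request is complete
    }
  }
}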
@@ -232,13 +258,17 @@ class JSONFFIEngineImpl : public JSONFFIEngine, public ModuleNode {

     picojson::array response_arr;
     for (const auto& [request_id, choices] : response_map) {
+      if (choices.size() == 0) continue;
       ChatCompletionStreamResponse response;
       response.id = request_id;
       response.choices = choices;
       response.model = "json_ffi";  // TODO: Return model name from engine (or from args)
       response.system_fingerprint = "";
       response_arr.push_back(picojson::value(response.AsJSON()));
     }
+    for (auto&& item : request_final_usage_messages) {
+      response_arr.emplace_back(std::move(item));
+    }
     return picojson::value(response_arr).serialize();
   }
 };
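
Two details keep the stream well formed: a request whose only output this round is the final usage message would otherwise yield a chunk with an empty choices list, so such entries are skipped; and the buffered usage messages are appended after every content chunk, preserving the invariant that final usage arrives last. An illustrative serialized payload for a usage-only round (usage fields assumed to follow the OpenAI-style shape; values made up):

// [{"id": "req-0", "model": "json_ffi", "system_fingerprint": "",
//   "usage": {"prompt_tokens": 5, "completion_tokens": 42}}]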