@@ -24,19 +24,23 @@ namespace duckdb {
2424 struct OpenPromptData : FunctionData {
2525 idx_t model_idx;
2626 idx_t json_schema_idx;
27+ idx_t json_system_prompt_idx;
2728 unique_ptr<FunctionData> Copy () const {
2829 auto res = make_uniq<OpenPromptData>();
2930 res->model_idx = model_idx;
3031 res->json_schema_idx = json_schema_idx;
32+ res->json_system_prompt_idx = json_system_prompt_idx;
3133 return res;
3234 };
3335 bool Equals (const FunctionData &other) const {
3436 return model_idx == other.Cast <OpenPromptData>().model_idx &&
35- json_schema_idx == other.Cast <OpenPromptData>().json_schema_idx ;
37+ json_schema_idx == other.Cast <OpenPromptData>().json_schema_idx &&
38+ json_system_prompt_idx==other.Cast <OpenPromptData>().json_system_prompt_idx ;
3639 };
3740 OpenPromptData () {
3841 model_idx = 0 ;
3942 json_schema_idx = 0 ;
43+ json_system_prompt_idx = 0 ;
4044 }
4145 };
4246
@@ -49,6 +53,8 @@ namespace duckdb {
4953 res->model_idx = i;
5054 } else if (argument->alias == " json_schema" ) {
5155 res->json_schema_idx = i;
56+ } else if (argument->alias == " system_prompt" ) {
57+ res->json_system_prompt_idx = i;
5258 }
5359 }
5460 return std::move (res);
@@ -182,26 +188,65 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
182188 std::string api_token = GetConfigValue (context, " openprompt_api_token" , " " );
183189 std::string model_name = GetConfigValue (context, " openprompt_model_name" , " qwen2.5:0.5b" );
184190 std::string json_schema;
191+ std::string system_prompt;
185192
186193 if (info.model_idx != 0 ) {
187194 model_name = args.data [info.model_idx ].GetValue (0 ).ToString ();
188195 }
189196 if (info.json_schema_idx != 0 ) {
190197 json_schema = args.data [info.json_schema_idx ].GetValue (0 ).ToString ();
191198 }
199+ if (info.json_system_prompt_idx != 0 ) {
200+ system_prompt = args.data [info.json_system_prompt_idx ].GetValue (0 ).ToString ();
201+ }
192202
193- std::string request_body = " {" ;
194- request_body += " \" model\" :\" " + model_name + " \" ," ;
203+ unique_ptr<duckdb_yyjson::yyjson_mut_doc, void (*)(duckdb_yyjson::yyjson_mut_doc*)> doc (
204+ duckdb_yyjson::yyjson_mut_doc_new (nullptr ), &duckdb_yyjson::yyjson_mut_doc_free);
205+ auto obj = duckdb_yyjson::yyjson_mut_obj (doc.get ());
206+ duckdb_yyjson::yyjson_mut_doc_set_root (doc.get (), obj);
207+ duckdb_yyjson::yyjson_mut_obj_add (obj,
208+ duckdb_yyjson::yyjson_mut_str (doc.get (), " model" ),
209+ duckdb_yyjson::yyjson_mut_str (doc.get (), model_name.c_str ())
210+ );
195211 if (!json_schema.empty ()) {
196- request_body += " \" response_format\" :{\" type\" :\" json_object\" , \" schema\" :" ;
197- request_body += json_schema;
198- request_body += " }," ;
212+ auto response_format = duckdb_yyjson::yyjson_mut_obj (doc.get ());
213+ duckdb_yyjson::yyjson_mut_obj_add (response_format,
214+ duckdb_yyjson::yyjson_mut_str (doc.get (), " type" ),
215+ duckdb_yyjson::yyjson_mut_str (doc.get (), " json_object" ));
216+ auto yyschema = duckdb_yyjson::yyjson_mut_raw (doc.get (), json_schema.c_str ());
217+ duckdb_yyjson::yyjson_mut_obj_add (response_format,
218+ duckdb_yyjson::yyjson_mut_str (doc.get (), " schema" ),
219+ yyschema);
220+ duckdb_yyjson::yyjson_mut_obj_add (obj,
221+ duckdb_yyjson::yyjson_mut_str (doc.get ()," response_format" ),
222+ response_format);
199223 }
200- request_body += " \" messages\" :[" ;
201- request_body += " {\" role\" :\" system\" ,\" content\" :\" You are a helpful assistant.\" }," ;
202- request_body += " {\" role\" :\" user\" ,\" content\" :\" " + user_prompt.GetString () + " \" }" ;
203- request_body += " ]}" ;
204-
224+ auto messages = duckdb_yyjson::yyjson_mut_arr (doc.get ());
225+ string str_messages[2 ][2 ] = {
226+ {" system" , system_prompt},
227+ {" user" , user_prompt.GetString ()}
228+ };
229+ for (auto message : str_messages) {
230+ if (message[1 ].empty ()) {
231+ continue ;
232+ }
233+ auto yymessage = duckdb_yyjson::yyjson_mut_arr_add_obj (doc.get (),messages);
234+ duckdb_yyjson::yyjson_mut_obj_add (yymessage,
235+ duckdb_yyjson::yyjson_mut_str (doc.get (), " role" ),
236+ duckdb_yyjson::yyjson_mut_str (doc.get (), message[0 ].c_str ()));
237+ duckdb_yyjson::yyjson_mut_obj_add (yymessage,
238+ duckdb_yyjson::yyjson_mut_str (doc.get (), " content" ),
239+ duckdb_yyjson::yyjson_mut_str (doc.get (), message[1 ].c_str ()));
240+ }
241+ duckdb_yyjson::yyjson_mut_obj_add (obj, duckdb_yyjson::yyjson_mut_str (doc.get (), " messages" ),
242+ messages);
243+ duckdb_yyjson::yyjson_write_err err;
244+ auto request_body = duckdb_yyjson::yyjson_mut_write_opts (doc.get (), 0 , nullptr , nullptr , &err);
245+ if (request_body == nullptr ) {
246+ throw std::runtime_error (err.msg );
247+ }
248+ string str_request_body (request_body);
249+ free (request_body);
205250
206251 try {
207252 auto client_and_path = SetupHttpClient (api_url);
@@ -214,7 +259,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V
214259 headers.emplace (" Authorization" , " Bearer " + api_token);
215260 }
216261
217- auto res = client.Post (path.c_str (), headers, request_body , " application/json" );
262+ auto res = client.Post (path.c_str (), headers, str_request_body , " application/json" );
218263
219264 if (!res) {
220265 HandleHttpError (res, " POST" );
@@ -286,10 +331,14 @@ static void LoadInternal(DatabaseInstance &instance) {
286331 open_prompt.AddFunction (ScalarFunction (
287332 {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction,
288333 OpenPromptBind));
289- open_prompt.AddFunction (ScalarFunction (
334+ open_prompt.AddFunction (ScalarFunction (
290335 {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
291336 LogicalType::VARCHAR, OpenPromptRequestFunction,
292337 OpenPromptBind));
338+ open_prompt.AddFunction (ScalarFunction (
339+ {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR},
340+ LogicalType::VARCHAR, OpenPromptRequestFunction,
341+ OpenPromptBind));
293342
294343 ExtensionUtil::RegisterFunction (instance, open_prompt);
295344
0 commit comments