Commit ff290ee

add api route for llm query (#488)
This PR adds the server-side mechanism to make API calls to OpenAI, so users can request SQL queries from the console. This is helpful for quickly generating sophisticated SQL queries from simple plain text.

---------

Co-authored-by: Satyam Singh <[email protected]>
1 parent 34a082c commit ff290ee

File tree

3 files changed: +217 -11 lines

server/src/handlers/http.rs
server/src/handlers/http/llm.rs
server/src/option.rs


server/src/handlers/http.rs

Lines changed: 18 additions & 1 deletion
@@ -34,6 +34,7 @@ use self::middleware::{DisAllowRootUser, RouteExt};
 mod about;
 mod health_check;
 mod ingest;
+mod llm;
 mod logstream;
 mod middleware;
 mod query;
@@ -229,6 +230,21 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
             .wrap(DisAllowRootUser),
         ),
     );
+
+    let llm_query_api = web::scope("/llm")
+        .service(
+            web::resource("").route(
+                web::post()
+                    .to(llm::make_llm_request)
+                    .authorize(Action::Query),
+            ),
+        )
+        .service(
+            // to check if the API key for an LLM has been set up as env var
+            web::resource("isactive")
+                .route(web::post().to(llm::is_llm_active).authorize(Action::Query)),
+        );
+
     // Deny request if username is same as the env variable P_USERNAME.
     cfg.service(
         // Base path "{url}/api/v1"
@@ -266,7 +282,8 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
                     logstream_api,
                 ),
             )
-            .service(user_api),
+            .service(user_api)
+            .service(llm_query_api),
     )
     // GET "/" ==> Serve the static frontend directory
     .service(ResourceFiles::new("/", generated).resolve_not_found_to_root());
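With these routes in place, the SQL generator is exposed as POST {url}/api/v1/llm, taking the prompt and target stream as JSON (the AiPrompt fields defined in llm.rs below). A minimal client sketch; the server address and the admin/admin credentials are placeholders, not part of this diff:

// Hypothetical client for the new endpoint.
// Needs reqwest (json feature), serde_json, and tokio (rt, macros).
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let sql = reqwest::Client::new()
        .post("http://localhost:8000/api/v1/llm") // base path "/api/v1" from the routes above
        .basic_auth("admin", Some("admin")) // placeholder credentials
        .json(&json!({
            // Fields of AiPrompt (see llm.rs below)
            "prompt": "count the number of 500 errors in the last hour",
            "stream": "frontend",
        }))
        .send()
        .await?
        .text()
        .await?;
    // On success the body is a JSON-encoded string holding the generated SQL.
    println!("{sql}");
    Ok(())
}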

server/src/handlers/http/llm.rs

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+/*
+ * Parseable Server (C) 2022 - 2023 Parseable, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+use actix_web::{http::header::ContentType, web, HttpResponse, Result};
+use http::{header, StatusCode};
+use itertools::Itertools;
+use reqwest;
+use serde_json::{json, Value};
+
+use crate::{
+    metadata::{error::stream_info::MetadataError, STREAM_INFO},
+    option::CONFIG,
+};
+
+const OPEN_AI_URL: &str = "https://api.openai.com/v1/chat/completions";
+
+// Deserialize types for OpenAI Response
+#[derive(serde::Deserialize, Debug)]
+struct ResponseData {
+    choices: Vec<Choice>,
+}
+
+#[derive(serde::Deserialize, Debug)]
+struct Choice {
+    message: Message,
+}
+
+#[derive(serde::Deserialize, Debug)]
+struct Message {
+    content: String,
+}
+
+// Request body
+#[derive(serde::Deserialize, Debug)]
+pub struct AiPrompt {
+    prompt: String,
+    stream: String,
+}
+
+// Temporary type
+#[derive(Debug, serde::Serialize)]
+struct Field {
+    name: String,
+    data_type: String,
+}
+
+impl From<&arrow_schema::Field> for Field {
+    fn from(field: &arrow_schema::Field) -> Self {
+        Self {
+            name: field.name().clone(),
+            data_type: field.data_type().to_string(),
+        }
+    }
+}
+
+fn build_prompt(stream: &str, prompt: &str, schema_json: &str) -> String {
+    format!(
+        r#"I have a table called {}.
+It has the columns:\n{}
+Based on this, generate valid SQL for the query: "{}"
+Generate only SQL as output. Also add comments in SQL syntax to explain your actions.
+Don't output anything else.
+If it is not possible to generate valid SQL, output an SQL comment saying so."#,
+        stream, schema_json, prompt
+    )
+}
+
+fn build_request_body(ai_prompt: String) -> impl serde::Serialize {
+    json!({
+        "model": "gpt-3.5-turbo",
+        "messages": [{ "role": "user", "content": ai_prompt}],
+        "temperature": 0.6,
+    })
+}
+
+pub async fn make_llm_request(body: web::Json<AiPrompt>) -> Result<HttpResponse, LLMError> {
+    let api_key = match &CONFIG.parseable.open_ai_key {
+        Some(api_key) if api_key.len() > 3 => api_key,
+        _ => return Err(LLMError::InvalidAPIKey),
+    };
+
+    let stream_name = &body.stream;
+    let schema = STREAM_INFO.schema(stream_name)?;
+    let filtered_schema = schema
+        .all_fields()
+        .into_iter()
+        .map(Field::from)
+        .collect_vec();
+
+    let schema_json =
+        serde_json::to_string(&filtered_schema).expect("always converted to valid json");
+
+    let prompt = build_prompt(stream_name, &body.prompt, &schema_json);
+    let body = build_request_body(prompt);
+
+    let client = reqwest::Client::new();
+    let response = client
+        .post(OPEN_AI_URL)
+        .header(header::CONTENT_TYPE, "application/json")
+        .bearer_auth(api_key)
+        .json(&body)
+        .send()
+        .await?;
+
+    if response.status().is_success() {
+        let body: ResponseData = response
+            .json()
+            .await
+            .expect("OpenAI response is always the same");
+        Ok(HttpResponse::Ok()
+            .content_type("application/json")
+            .json(&body.choices[0].message.content))
+    } else {
+        let body: Value = response.json().await?;
+        let message = body
+            .as_object()
+            .and_then(|body| body.get("error"))
+            .and_then(|error| error.as_object())
+            .and_then(|error| error.get("message"))
+            .map(|message| message.to_string())
+            .unwrap_or_else(|| "Error from OpenAI".to_string());
+
+        Err(LLMError::APIError(message))
+    }
+}
+
+pub async fn is_llm_active(_body: web::Json<AiPrompt>) -> HttpResponse {
+    let is_active = matches!(&CONFIG.parseable.open_ai_key, Some(api_key) if api_key.len() > 3);
+    HttpResponse::Ok()
+        .content_type("application/json")
+        .json(json!({"is_active": is_active}))
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum LLMError {
+    #[error("Either OpenAI key was not provided or was invalid")]
+    InvalidAPIKey,
+    #[error("Failed to call OpenAI endpoint: {0}")]
+    FailedRequest(#[from] reqwest::Error),
+    #[error("{0}")]
+    APIError(String),
+    #[error("{0}")]
+    StreamDoesNotExist(#[from] MetadataError),
+}
+
+impl actix_web::ResponseError for LLMError {
+    fn status_code(&self) -> http::StatusCode {
+        match self {
+            Self::InvalidAPIKey => StatusCode::INTERNAL_SERVER_ERROR,
+            Self::FailedRequest(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            Self::APIError(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            Self::StreamDoesNotExist(_) => StatusCode::INTERNAL_SERVER_ERROR,
+        }
+    }
+
+    fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
+        actix_web::HttpResponse::build(self.status_code())
+            .insert_header(ContentType::plaintext())
+            .body(self.to_string())
+    }
+}
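Of the whole OpenAI reply, the handler reads only choices[0].message.content. A self-contained sketch of that round trip using a canned response instead of a live call; the sample payload and SQL are invented for illustration, and only serde (derive) plus serde_json are needed:

use serde_json::json;

#[derive(serde::Deserialize, Debug)]
struct ResponseData {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize, Debug)]
struct Choice {
    message: Message,
}

#[derive(serde::Deserialize, Debug)]
struct Message {
    content: String,
}

fn main() {
    // Shape of the body build_request_body() produces.
    let request = json!({
        "model": "gpt-3.5-turbo",
        "messages": [{ "role": "user", "content": "generate valid SQL for ..." }],
        "temperature": 0.6,
    });
    println!("request: {request}");

    // Trimmed-down success payload; real replies carry more fields,
    // which serde simply ignores here.
    let reply = r#"{"choices":[{"message":{"content":"-- errors per minute\nSELECT COUNT(*) FROM frontend WHERE status = 500;"}}]}"#;
    let parsed: ResponseData = serde_json::from_str(reply).unwrap();
    println!("sql: {}", parsed.choices[0].message.content);
}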

server/src/option.rs

Lines changed: 23 additions & 10 deletions
@@ -184,6 +184,9 @@ pub struct Server {
     /// Server should send anonymous analytics or not
     pub send_analytics: bool,
 
+    /// Open AI access key
+    pub open_ai_key: Option<String>,
+
     /// Rows in Parquet Rowgroup
     pub row_group_size: usize,
 
@@ -232,6 +235,7 @@ impl FromArgMatches for Server {
            .get_one::<bool>(Self::SEND_ANALYTICS)
            .cloned()
            .expect("default for send analytics");
+        self.open_ai_key = m.get_one::<String>(Self::OPEN_AI_KEY).cloned();
         // converts Gib to bytes before assigning
         self.query_memory_pool_size = m
             .get_one::<u8>(Self::QUERY_MEM_POOL_SIZE)
@@ -271,6 +275,7 @@ impl Server {
     pub const PASSWORD: &str = "password";
     pub const CHECK_UPDATE: &str = "check-update";
     pub const SEND_ANALYTICS: &str = "send-analytics";
+    pub const OPEN_AI_KEY: &str = "open-ai-key";
     pub const QUERY_MEM_POOL_SIZE: &str = "query-mempool-size";
     pub const ROW_GROUP_SIZE: &str = "row-group-size";
     pub const PARQUET_COMPRESSION_ALGO: &str = "compression-algo";
@@ -351,6 +356,24 @@ impl Server {
                 .required(true)
                 .help("Password for the basic authentication on the server"),
         )
+        .arg(
+            Arg::new(Self::SEND_ANALYTICS)
+                .long(Self::SEND_ANALYTICS)
+                .env("P_SEND_ANONYMOUS_USAGE_DATA")
+                .value_name("BOOL")
+                .required(false)
+                .default_value("true")
+                .value_parser(value_parser!(bool))
+                .help("Disable/Enable sending anonymous user data"),
+        )
+        .arg(
+            Arg::new(Self::OPEN_AI_KEY)
+                .long(Self::OPEN_AI_KEY)
+                .env("OPENAI_API_KEY")
+                .value_name("STRING")
+                .required(false)
+                .help("Set OpenAI key to enable llm feature"),
+        )
         .arg(
             Arg::new(Self::CHECK_UPDATE)
                 .long(Self::CHECK_UPDATE)
@@ -380,16 +403,6 @@ impl Server {
                 .value_parser(value_parser!(usize))
                 .help("Number of rows in a row groups"),
         )
-        .arg(
-            Arg::new(Self::SEND_ANALYTICS)
-                .long(Self::SEND_ANALYTICS)
-                .env("P_SEND_ANONYMOUS_USAGE_DATA")
-                .value_name("BOOL")
-                .required(false)
-                .default_value("true")
-                .value_parser(value_parser!(bool))
-                .help("Disable/Enable sending anonymous user data"),
-        )
         .arg(
             Arg::new(Self::PARQUET_COMPRESSION_ALGO)
                 .long(Self::PARQUET_COMPRESSION_ALGO)
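The key is optional: open_ai_key stays None unless --open-ai-key or the OPENAI_API_KEY env var is set, which is what make_llm_request and is_llm_active check (any value longer than 3 characters counts as configured). A cut-down clap sketch of that wiring, standalone rather than Parseable's actual Server struct, assuming clap 4 with the env feature:

// Standalone illustration of the new flag, not Parseable's real CLI.
// Cargo.toml: clap = { version = "4", features = ["env"] }
use clap::{Arg, Command};

fn main() {
    let m = Command::new("server")
        .arg(
            Arg::new("open-ai-key")
                .long("open-ai-key")
                .env("OPENAI_API_KEY") // flag falls back to the env var
                .value_name("STRING")
                .required(false)
                .help("Set OpenAI key to enable llm feature"),
        )
        .get_matches();

    // Mirrors Server::from_arg_matches and the is_llm_active check.
    let open_ai_key: Option<String> = m.get_one::<String>("open-ai-key").cloned();
    let is_active = matches!(&open_ai_key, Some(k) if k.len() > 3);
    println!("llm feature active: {is_active}");
}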
