
Commit e9a165f

Support OpenAI API for LLM generation. #27 (#111)
1 parent: d36d059 · commit: e9a165f

8 files changed: +126, -8 lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -48,3 +48,4 @@ openssl = { version = "0.10.71", features = ["vendored"] }
 console-subscriber = "0.4.1"
 env_logger = "0.11.7"
 reqwest = { version = "0.12.13", features = ["json"] }
+async-openai = "0.28.0"

examples/manual_extraction/manual_extraction.py

Lines changed: 7 additions & 2 deletions
@@ -84,8 +84,13 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
     doc["raw_module_info"] = doc["markdown"].transform(
         cocoindex.functions.ExtractByLlm(
             llm_spec=cocoindex.llm.LlmSpec(
-                api_type=cocoindex.llm.LlmApiType.OLLAMA,
-                model="llama3.2:latest"),
+                api_type=cocoindex.llm.LlmApiType.OLLAMA,
+                # See the full list of models: https://ollama.com/library
+                model="llama3.2:latest"
+            ),
+            # Replace by this spec below, to use OpenAI API model instead of ollama
+            # llm_spec=cocoindex.llm.LlmSpec(
+            #     api_type=cocoindex.llm.LlmApiType.OPENAI, model="gpt-4o"),
             output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
             instruction="Please extract Python module information from the manual."))
     doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@

 class LlmApiType(Enum):
     """The type of LLM API to use."""
+    OPENAI = "OpenAi"
     OLLAMA = "Ollama"

 @dataclass

src/base/json_schema.rs

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 use super::schema;
 use schemars::schema::{
-    ArrayValidation, InstanceType, ObjectValidation, SchemaObject, SingleOrVec,
+    ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
 };

 pub trait ToJsonSchema {
@@ -82,6 +82,7 @@ impl ToJsonSchema for schema::StructSchema {
                     .iter()
                     .filter_map(|f| (!f.value_type.nullable).then(|| f.name.to_string()))
                     .collect(),
+                additional_properties: Some(Schema::Bool(false).into()),
                 ..Default::default()
             })),
             ..Default::default()
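The added additional_properties line closes each generated object schema to unknown keys, which OpenAI's strict structured outputs require (every object must carry "additionalProperties": false). A minimal sketch of the effect, assuming schemars 0.8 and serde_json; the "title" field is purely illustrative:

```rust
// Minimal sketch (assumes schemars 0.8 + serde_json); the "title" field is illustrative.
use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec};

fn main() {
    let schema = SchemaObject {
        instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))),
        object: Some(Box::new(ObjectValidation {
            required: ["title".to_string()].into_iter().collect(),
            // Same call as in the diff above: forbid keys not named in the schema.
            additional_properties: Some(Schema::Bool(false).into()),
            ..Default::default()
        })),
        ..Default::default()
    };
    // Prints roughly: {"type":"object","required":["title"],"additionalProperties":false}
    println!("{}", serde_json::to_string(&schema).unwrap());
}
```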

src/llm/mod.rs

Lines changed: 9 additions & 1 deletion
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub enum LlmApiType {
     Ollama,
+    OpenAi,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -19,7 +20,10 @@ pub struct LlmSpec {

 #[derive(Debug)]
 pub enum OutputFormat<'a> {
-    JsonSchema(Cow<'a, SchemaObject>),
+    JsonSchema {
+        name: Cow<'a, str>,
+        schema: Cow<'a, SchemaObject>,
+    },
 }

 #[derive(Debug)]
@@ -43,12 +47,16 @@ pub trait LlmGenerationClient: Send + Sync {
 }

 mod ollama;
+mod openai;

 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
         LlmApiType::Ollama => {
             Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::OpenAi => {
+            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }
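With the struct-style variant, every call site now passes a schema name alongside the schema: the OpenAI request format wants a named schema, while the Ollama client destructures only the schema (the { schema, .. } pattern in the next file). A small construction sketch, with the enum restated from the diff so it stands alone:

```rust
// Restated from the diff above so this sketch compiles on its own.
use std::borrow::Cow;

use schemars::schema::SchemaObject;

pub enum OutputFormat<'a> {
    JsonSchema {
        name: Cow<'a, str>,
        schema: Cow<'a, SchemaObject>,
    },
}

// Mirrors how extract_by_llm.rs builds the value ("ExtractedData" is the name it uses).
pub fn json_schema_format(schema: &SchemaObject) -> OutputFormat<'_> {
    OutputFormat::JsonSchema {
        name: Cow::Borrowed("ExtractedData"),
        schema: Cow::Borrowed(schema),
    }
}
```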

src/llm/ollama.rs

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ impl LlmGenerationClient for Client {
             model: &self.model,
             prompt: request.user_prompt.as_ref(),
             format: match &request.output_format {
-                Some(super::OutputFormat::JsonSchema(schema)) => {
+                Some(super::OutputFormat::JsonSchema { schema, .. }) => {
                     Some(OllamaFormat::JsonSchema(schema.as_ref()))
                 }
                 None => None,

src/llm/openai.rs

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+use crate::api_bail;
+
+use super::LlmGenerationClient;
+use anyhow::Result;
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
+        ResponseFormatJsonSchema,
+    },
+    Client as OpenAIClient,
+};
+use async_trait::async_trait;
+
+pub struct Client {
+    client: async_openai::Client<OpenAIConfig>,
+    model: String,
+}
+
+impl Client {
+    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
+        if let Some(address) = spec.address {
+            api_bail!("OpenAI doesn't support custom API address: {address}");
+        }
+        // Verify API key is set
+        if std::env::var("OPENAI_API_KEY").is_err() {
+            api_bail!("OPENAI_API_KEY environment variable must be set");
+        }
+        Ok(Self {
+            // OpenAI client will use OPENAI_API_KEY env variable by default
+            client: OpenAIClient::new(),
+            model: spec.model,
+        })
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let mut messages = Vec::new();
+
+        // Add system prompt if provided
+        if let Some(system) = request.system_prompt {
+            messages.push(ChatCompletionRequestMessage::System(
+                ChatCompletionRequestSystemMessage {
+                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
+                    ..Default::default()
+                },
+            ));
+        }
+
+        // Add user message
+        messages.push(ChatCompletionRequestMessage::User(
+            ChatCompletionRequestUserMessage {
+                content: ChatCompletionRequestUserMessageContent::Text(
+                    request.user_prompt.into_owned(),
+                ),
+                ..Default::default()
+            },
+        ));
+
+        // Create the chat completion request
+        let request = CreateChatCompletionRequest {
+            model: self.model.clone(),
+            messages,
+            response_format: match request.output_format {
+                Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                    Some(ResponseFormat::JsonSchema {
+                        json_schema: ResponseFormatJsonSchema {
+                            name: name.into_owned(),
+                            description: None,
+                            schema: Some(serde_json::to_value(&schema)?),
+                            strict: Some(true),
+                        },
+                    })
+                }
+                None => None,
+            },
+            ..Default::default()
+        };
+
+        // Send request and get response
+        let response = self.client.chat().create(request).await?;
+
+        // Extract the response text from the first choice
+        let text = response
+            .choices
+            .into_iter()
+            .next()
+            .map(|choice| choice.message.content)
+            .flatten()
+            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+        Ok(super::LlmGenerateResponse { text })
+    }
+}
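For orientation, a hypothetical end-to-end caller of this client through the factory in src/llm/mod.rs. The LlmSpec and LlmGenerateRequest field names below are taken from their uses elsewhere in this diff (api_type, model, address; system_prompt, user_prompt, output_format); the prompt strings are placeholders, and OPENAI_API_KEY must be exported:

```rust
// Hypothetical caller; assumes LlmSpec, LlmApiType, LlmGenerateRequest and
// new_llm_generation_client from src/llm/mod.rs are in scope.
use std::borrow::Cow;

use anyhow::Result;

async fn extract_with_openai() -> Result<String> {
    let spec = LlmSpec {
        api_type: LlmApiType::OpenAi,
        model: "gpt-4o".to_string(),
        address: None, // the OpenAI client rejects a custom API address
    };
    let client = new_llm_generation_client(spec).await?;
    let res = client
        .generate(LlmGenerateRequest {
            system_prompt: Some(Cow::Borrowed("Extract module info as JSON.")),
            user_prompt: Cow::Borrowed("Placeholder document text."),
            output_format: None, // or OutputFormat::JsonSchema { name, schema }
        })
        .await?;
    Ok(res.text)
}
```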

src/ops/functions/extract_by_llm.rs

Lines changed: 4 additions & 3 deletions
@@ -66,9 +66,10 @@ impl SimpleFunctionExecutor for Executor {
         let req = LlmGenerateRequest {
             system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
             user_prompt: Cow::Borrowed(text),
-            output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
-                &self.output_json_schema,
-            ))),
+            output_format: Some(OutputFormat::JsonSchema {
+                name: Cow::Borrowed("ExtractedData"),
+                schema: Cow::Borrowed(&self.output_json_schema),
+            }),
         };
         let res = self.client.generate(req).await?;
         let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
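The "ExtractedData" literal above becomes the schema name in the OpenAI response_format. Roughly, the shape the OpenAI client assembles looks like this hand-written sketch (not a captured request; the inner schema is whatever json_schema.rs produced for the configured output type):

```rust
// Hand-written sketch of the response_format shape, built with serde_json for clarity.
use serde_json::json;

fn main() {
    let response_format = json!({
        "type": "json_schema",
        "json_schema": {
            "name": "ExtractedData",
            "schema": { "type": "object", "additionalProperties": false },
            "strict": true
        }
    });
    println!("{}", serde_json::to_string_pretty(&response_format).unwrap());
}
```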
