From d8ee81e4899733c95f69a0f187f9b3473f3ad156 Mon Sep 17 00:00:00 2001
From: LJ
Date: Wed, 12 Mar 2025 15:12:06 -0700
Subject: [PATCH] Support OpenAI API for LLM generation.

---
 Cargo.toml                                  |   1 +
 .../manual_extraction/manual_extraction.py  |   9 +-
 python/cocoindex/llm.py                     |   1 +
 src/base/json_schema.rs                     |   3 +-
 src/llm/mod.rs                              |  10 +-
 src/llm/ollama.rs                           |   2 +-
 src/llm/openai.rs                           | 101 ++++++++++++++++++
 src/ops/functions/extract_by_llm.rs         |   7 +-
 8 files changed, 126 insertions(+), 8 deletions(-)
 create mode 100644 src/llm/openai.rs

diff --git a/Cargo.toml b/Cargo.toml
index b452d87eb..b21738d03 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -48,3 +48,4 @@ openssl = { version = "0.10.71", features = ["vendored"] }
 console-subscriber = "0.4.1"
 env_logger = "0.11.7"
 reqwest = { version = "0.12.13", features = ["json"] }
+async-openai = "0.28.0"
diff --git a/examples/manual_extraction/manual_extraction.py b/examples/manual_extraction/manual_extraction.py
index 96b370aed..26ca28ba8 100644
--- a/examples/manual_extraction/manual_extraction.py
+++ b/examples/manual_extraction/manual_extraction.py
@@ -84,8 +84,13 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
         doc["raw_module_info"] = doc["markdown"].transform(
             cocoindex.functions.ExtractByLlm(
                 llm_spec=cocoindex.llm.LlmSpec(
-                        api_type=cocoindex.llm.LlmApiType.OLLAMA,
-                        model="llama3.2:latest"),
+                    api_type=cocoindex.llm.LlmApiType.OLLAMA,
+                    # See the full list of models: https://ollama.com/library
+                    model="llama3.2:latest"
+                ),
+                # Replace by this spec below, to use OpenAI API model instead of ollama
+                # llm_spec=cocoindex.llm.LlmSpec(
+                #     api_type=cocoindex.llm.LlmApiType.OPENAI, model="gpt-4o"),
                 output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
                 instruction="Please extract Python module information from the manual."))
         doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index e26c688f5..ab1a6f040 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -3,6 +3,7 @@
 
 class LlmApiType(Enum):
     """The type of LLM API to use."""
+    OPENAI = "OpenAi"
     OLLAMA = "Ollama"
 
 @dataclass
diff --git a/src/base/json_schema.rs b/src/base/json_schema.rs
index e3fc86c71..c26567b5d 100644
--- a/src/base/json_schema.rs
+++ b/src/base/json_schema.rs
@@ -1,6 +1,6 @@
 use super::schema;
 use schemars::schema::{
-    ArrayValidation, InstanceType, ObjectValidation, SchemaObject, SingleOrVec,
+    ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
 };
 
 pub trait ToJsonSchema {
@@ -82,6 +82,7 @@ impl ToJsonSchema for schema::StructSchema {
                     .iter()
                     .filter_map(|f| (!f.value_type.nullable).then(|| f.name.to_string()))
                     .collect(),
+                additional_properties: Some(Schema::Bool(false).into()),
                 ..Default::default()
             })),
             ..Default::default()
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 8e8190f53..7c0b2b324 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub enum LlmApiType {
     Ollama,
+    OpenAi,
 }
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct LlmSpec {
@@ -19,7 +20,10 @@ pub struct LlmSpec {
 
 #[derive(Debug)]
 pub enum OutputFormat<'a> {
-    JsonSchema(Cow<'a, SchemaObject>),
+    JsonSchema {
+        name: Cow<'a, str>,
+        schema: Cow<'a, SchemaObject>,
+    },
 }
 
 #[derive(Debug)]
@@ -43,12 +47,16 @@ pub trait LlmGenerationClient: Send + Sync {
 }
 
 mod ollama;
+mod openai;
 
 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
         LlmApiType::Ollama => {
             Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::OpenAi => {
+            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }
diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs
index 31bb22447..8d6acd4ce 100644
--- a/src/llm/ollama.rs
+++ b/src/llm/ollama.rs
@@ -56,7 +56,7 @@ impl LlmGenerationClient for Client {
             model: &self.model,
             prompt: request.user_prompt.as_ref(),
             format: match &request.output_format {
-                Some(super::OutputFormat::JsonSchema(schema)) => {
+                Some(super::OutputFormat::JsonSchema { schema, .. }) => {
                     Some(OllamaFormat::JsonSchema(schema.as_ref()))
                 }
                 None => None,
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
new file mode 100644
index 000000000..dd662b7e4
--- /dev/null
+++ b/src/llm/openai.rs
@@ -0,0 +1,101 @@
+use crate::api_bail;
+
+use super::LlmGenerationClient;
+use anyhow::Result;
+use async_openai::{
+    config::OpenAIConfig,
+    types::{
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
+        ResponseFormatJsonSchema,
+    },
+    Client as OpenAIClient,
+};
+use async_trait::async_trait;
+
+pub struct Client {
+    client: async_openai::Client<OpenAIConfig>,
+    model: String,
+}
+
+impl Client {
+    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
+        if let Some(address) = spec.address {
+            api_bail!("OpenAI doesn't support custom API address: {address}");
+        }
+        // Verify API key is set
+        if std::env::var("OPENAI_API_KEY").is_err() {
+            api_bail!("OPENAI_API_KEY environment variable must be set");
+        }
+        Ok(Self {
+            // OpenAI client will use OPENAI_API_KEY env variable by default
+            client: OpenAIClient::new(),
+            model: spec.model,
+        })
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let mut messages = Vec::new();
+
+        // Add system prompt if provided
+        if let Some(system) = request.system_prompt {
+            messages.push(ChatCompletionRequestMessage::System(
+                ChatCompletionRequestSystemMessage {
+                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
+                    ..Default::default()
+                },
+            ));
+        }
+
+        // Add user message
+        messages.push(ChatCompletionRequestMessage::User(
+            ChatCompletionRequestUserMessage {
+                content: ChatCompletionRequestUserMessageContent::Text(
+                    request.user_prompt.into_owned(),
+                ),
+                ..Default::default()
+            },
+        ));
+
+        // Create the chat completion request
+        let request = CreateChatCompletionRequest {
+            model: self.model.clone(),
+            messages,
+            response_format: match request.output_format {
+                Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                    Some(ResponseFormat::JsonSchema {
+                        json_schema: ResponseFormatJsonSchema {
+                            name: name.into_owned(),
+                            description: None,
+                            schema: Some(serde_json::to_value(&schema)?),
+                            strict: Some(true),
+                        },
+                    })
+                }
+                None => None,
+            },
+            ..Default::default()
+        };
+
+        // Send request and get response
+        let response = self.client.chat().create(request).await?;
+
+        // Extract the response text from the first choice
+        let text = response
+            .choices
+            .into_iter()
+            .next()
+            .map(|choice| choice.message.content)
+            .flatten()
+            .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
+
+        Ok(super::LlmGenerateResponse { text })
+    }
+}
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 64363ebc1..2dfe88a63 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -66,9 +66,10 @@ impl SimpleFunctionExecutor for Executor {
         let req = LlmGenerateRequest {
             system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
             user_prompt: Cow::Borrowed(text),
-            output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
-                &self.output_json_schema,
-            ))),
+            output_format: Some(OutputFormat::JsonSchema {
+                name: Cow::Borrowed("ExtractedData"),
+                schema: Cow::Borrowed(&self.output_json_schema),
+            }),
         };
         let res = self.client.generate(req).await?;
         let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;