1 change: 1 addition & 0 deletions Cargo.toml
@@ -48,3 +48,4 @@ openssl = { version = "0.10.71", features = ["vendored"] }
console-subscriber = "0.4.1"
env_logger = "0.11.7"
reqwest = { version = "0.12.13", features = ["json"] }
async-openai = "0.28.0"
9 changes: 7 additions & 2 deletions examples/manual_extraction/manual_extraction.py
@@ -84,8 +84,13 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
doc["raw_module_info"] = doc["markdown"].transform(
cocoindex.functions.ExtractByLlm(
llm_spec=cocoindex.llm.LlmSpec(
api_type=cocoindex.llm.LlmApiType.OLLAMA,
model="llama3.2:latest"),
api_type=cocoindex.llm.LlmApiType.OLLAMA,
# See the full list of models: https://ollama.com/library
model="llama3.2:latest"
),
# To use an OpenAI model instead of Ollama, replace the llm_spec above with:
# llm_spec=cocoindex.llm.LlmSpec(
# api_type=cocoindex.llm.LlmApiType.OPENAI, model="gpt-4o"),
output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
instruction="Please extract Python module information from the manual."))
doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -3,6 +3,7 @@

class LlmApiType(Enum):
"""The type of LLM API to use."""
OPENAI = "OpenAi"
OLLAMA = "Ollama"

@dataclass
3 changes: 2 additions & 1 deletion src/base/json_schema.rs
@@ -1,6 +1,6 @@
use super::schema;
use schemars::schema::{
ArrayValidation, InstanceType, ObjectValidation, SchemaObject, SingleOrVec,
ArrayValidation, InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
};

pub trait ToJsonSchema {
@@ -82,6 +82,7 @@ impl ToJsonSchema for schema::StructSchema {
.iter()
.filter_map(|f| (!f.value_type.nullable).then(|| f.name.to_string()))
.collect(),
additional_properties: Some(Schema::Bool(false).into()),
..Default::default()
})),
..Default::default()
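The new `additional_properties` line is what makes the generated schemas acceptable to OpenAI's strict structured outputs, which require every object schema to forbid unknown keys. A minimal standalone sketch of the shape this produces, reusing the schemars 0.8 types the file already imports (the `main` wrapper and printed output are illustrative):

```rust
use schemars::schema::{InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec};

fn main() -> anyhow::Result<()> {
    // An object schema that forbids unknown keys, mirroring the
    // `additional_properties` field set in the diff above.
    let schema = SchemaObject {
        instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))),
        object: Some(Box::new(ObjectValidation {
            additional_properties: Some(Schema::Bool(false).into()),
            ..Default::default()
        })),
        ..Default::default()
    };
    // Serializes to roughly: {"type":"object","additionalProperties":false}
    println!("{}", serde_json::to_string(&schema)?);
    Ok(())
}
```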
10 changes: 9 additions & 1 deletion src/llm/mod.rs
@@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmApiType {
Ollama,
OpenAi,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -19,7 +20,10 @@ pub struct LlmSpec {

#[derive(Debug)]
pub enum OutputFormat<'a> {
JsonSchema(Cow<'a, SchemaObject>),
JsonSchema {
name: Cow<'a, str>,
schema: Cow<'a, SchemaObject>,
},
}

#[derive(Debug)]
@@ -43,12 +47,16 @@ pub trait LlmGenerationClient: Send + Sync {
}

mod ollama;
mod openai;

pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
let client = match spec.api_type {
LlmApiType::Ollama => {
Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
}
LlmApiType::OpenAi => {
Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
}
};
Ok(client)
}
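For callers, the visible changes in this module are the new `LlmApiType::OpenAi` variant and the fact that `OutputFormat::JsonSchema` now carries a schema name (OpenAI's structured-output API requires one; the Ollama backend ignores it). A hedged usage sketch, assuming `LlmSpec` and `LlmGenerateRequest` have exactly the fields read elsewhere in this PR and that `document_text` and `output_schema` are already in scope:

```rust
use std::borrow::Cow;

// Sketch only: field lists are inferred from openai.rs and extract_by_llm.rs.
let client = new_llm_generation_client(LlmSpec {
    api_type: LlmApiType::OpenAi,
    model: "gpt-4o".to_string(),
    address: None,
})
.await?;

let response = client
    .generate(LlmGenerateRequest {
        system_prompt: Some(Cow::Borrowed("Extract the fields described by the schema.")),
        user_prompt: Cow::Borrowed(document_text),
        output_format: Some(OutputFormat::JsonSchema {
            name: Cow::Borrowed("ExtractedData"),
            schema: Cow::Borrowed(&output_schema),
        }),
    })
    .await?;
println!("{}", response.text);
```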
2 changes: 1 addition & 1 deletion src/llm/ollama.rs
@@ -56,7 +56,7 @@ impl LlmGenerationClient for Client {
model: &self.model,
prompt: request.user_prompt.as_ref(),
format: match &request.output_format {
Some(super::OutputFormat::JsonSchema(schema)) => {
Some(super::OutputFormat::JsonSchema { schema, .. }) => {
Some(OllamaFormat::JsonSchema(schema.as_ref()))
}
None => None,
101 changes: 101 additions & 0 deletions src/llm/openai.rs
@@ -0,0 +1,101 @@
use crate::api_bail;

use super::LlmGenerationClient;
use anyhow::Result;
use async_openai::{
config::OpenAIConfig,
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest, ResponseFormat,
ResponseFormatJsonSchema,
},
Client as OpenAIClient,
};
use async_trait::async_trait;

pub struct Client {
client: async_openai::Client<OpenAIConfig>,
model: String,
}

impl Client {
pub async fn new(spec: super::LlmSpec) -> Result<Self> {
if let Some(address) = spec.address {
api_bail!("OpenAI doesn't support custom API address: {address}");
}
// Verify API key is set
if std::env::var("OPENAI_API_KEY").is_err() {
api_bail!("OPENAI_API_KEY environment variable must be set");
}
Ok(Self {
// OpenAI client will use OPENAI_API_KEY env variable by default
client: OpenAIClient::new(),
model: spec.model,
})
}
}

#[async_trait]
impl LlmGenerationClient for Client {
async fn generate<'req>(
&self,
request: super::LlmGenerateRequest<'req>,
) -> Result<super::LlmGenerateResponse> {
let mut messages = Vec::new();

// Add system prompt if provided
if let Some(system) = request.system_prompt {
messages.push(ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
..Default::default()
},
));
}

// Add user message
messages.push(ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(
request.user_prompt.into_owned(),
),
..Default::default()
},
));

// Create the chat completion request
let request = CreateChatCompletionRequest {
model: self.model.clone(),
messages,
response_format: match request.output_format {
Some(super::OutputFormat::JsonSchema { name, schema }) => {
Some(ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
name: name.into_owned(),
description: None,
schema: Some(serde_json::to_value(&schema)?),
strict: Some(true),
},
})
}
None => None,
},
..Default::default()
};

// Send request and get response
let response = self.client.chat().create(request).await?;

// Extract the response text from the first choice
let text = response
.choices
.into_iter()
.next()
.map(|choice| choice.message.content)
.flatten()
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;

Ok(super::LlmGenerateResponse { text })
}
}
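One deliberate limitation in this client: a custom `address` is rejected with `api_bail!`, so only the official OpenAI endpoint is reachable. If OpenAI-compatible servers need to be supported later, async-openai's `OpenAIConfig` can override the base URL; a sketch of what that could look like (not part of this PR, helper name is hypothetical):

```rust
use async_openai::{config::OpenAIConfig, Client};

// Hypothetical helper: honor a custom base URL instead of rejecting it.
fn make_client(address: Option<String>) -> Client<OpenAIConfig> {
    match address {
        Some(base) => Client::with_config(OpenAIConfig::new().with_api_base(base)),
        // Default client reads OPENAI_API_KEY and uses the standard endpoint.
        None => Client::new(),
    }
}
```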
7 changes: 4 additions & 3 deletions src/ops/functions/extract_by_llm.rs
@@ -66,9 +66,10 @@ impl SimpleFunctionExecutor for Executor {
let req = LlmGenerateRequest {
system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
user_prompt: Cow::Borrowed(text),
output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
&self.output_json_schema,
))),
output_format: Some(OutputFormat::JsonSchema {
name: Cow::Borrowed("ExtractedData"),
schema: Cow::Borrowed(&self.output_json_schema),
}),
};
let res = self.client.generate(req).await?;
let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
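With the hard-coded "ExtractedData" name above, the `response_format` that the OpenAI backend assembles should serialize to roughly the following shape (illustrative only; the inner `schema` is whatever `output_json_schema` contains, now including `additionalProperties: false`):

```rust
// Illustrative shape, not captured from an API trace.
let response_format = serde_json::json!({
    "type": "json_schema",
    "json_schema": {
        "name": "ExtractedData",
        "strict": true,
        "schema": { "type": "object", "additionalProperties": false /* field schemas... */ }
    }
});
```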