1 change: 1 addition & 0 deletions Cargo.toml
@@ -47,3 +47,4 @@ schemars = "0.8.22"
openssl = { version = "0.10.71", features = ["vendored"] }
console-subscriber = "0.4.1"
env_logger = "0.11.7"
reqwest = { version = "0.12.13", features = ["json"] }
10 changes: 5 additions & 5 deletions examples/manual_extraction/manual_extraction.py
@@ -82,12 +82,12 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
with data_scope["documents"].row() as doc:
doc["markdown"] = doc["content"].transform(PdfToMarkdown())
doc["raw_module_info"] = doc["markdown"].transform(
cocoindex.functions.ExtractByMistral(
model=cocoindex.functions.MistralModelSpec(
model_id="microsoft/Phi-3.5-mini-instruct",
isq_type="Q8_0"),
cocoindex.functions.ExtractByLlm(
llm_spec=cocoindex.llm.LlmSpec(
api_type=cocoindex.llm.LlmApiType.OLLAMA,
model="llama3.2:latest"),
output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
instructions="Please extract Python module information from the manual."))
instruction="Please extract Python module information from the manual."))
doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())
manual_infos.collect(filename=doc["filename"], module_info=doc["module_info"])

2 changes: 1 addition & 1 deletion python/cocoindex/__init__.py
@@ -1,7 +1,7 @@
"""
Cocoindex is a framework for building and running indexing pipelines.
"""
from . import flow, functions, query, sources, storages, cli
from . import flow, functions, query, sources, storages, cli, llm
from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
from .vector import VectorSimilarityMetric
from .lib import *
5 changes: 4 additions & 1 deletion python/cocoindex/flow.py
@@ -8,6 +8,7 @@
import inspect
from typing import Any, Callable, Sequence, TypeVar
from threading import Lock
from enum import Enum

from . import _engine
from . import vector
@@ -62,7 +63,9 @@ def _spec_kind(spec: Any) -> str:

def _spec_value_dump(spec: Any) -> Any:
"""Recursively dump a spec object and its nested attributes to a dictionary."""
if hasattr(spec, '__dict__'):
if isinstance(spec, Enum):
return spec.value
elif hasattr(spec, '__dict__'):
return {k: _spec_value_dump(v) for k, v in spec.__dict__.items()}
elif isinstance(spec, (list, tuple)):
return [_spec_value_dump(item) for item in spec]
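To see why the new `Enum` branch matters, here is a small Python sketch (hypothetical values; it calls the private helper directly only for illustration) of what `_spec_value_dump` now produces for a spec that carries an `LlmApiType` member:

```python
from cocoindex.flow import _spec_value_dump
from cocoindex.llm import LlmApiType, LlmSpec

spec = LlmSpec(api_type=LlmApiType.OLLAMA, model="llama3.2:latest")

# With the Enum branch, the member is dumped as its plain string value,
# yielding JSON-friendly data for the engine, roughly:
#   {"api_type": "Ollama", "model": "llama3.2:latest", "address": None}
# Without it, the member would fall into the __dict__ case and leak internal
# attributes such as _name_ and _value_ instead.
print(_spec_value_dump(spec))
```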
16 changes: 5 additions & 11 deletions python/cocoindex/functions.py
@@ -4,27 +4,21 @@

import sentence_transformers
from .typing import Float32, Vector, TypeAttr
from . import op
from . import op, llm

class SplitRecursively(op.FunctionSpec):
"""Split a document (in string) recursively."""
chunk_size: int
chunk_overlap: int
language: str | None = None

@dataclass
class MistralModelSpec:
"""A specification for a Mistral model."""
model_id: str
isq_type: str
class ExtractByLlm(op.FunctionSpec):
"""Extract information from a text using a LLM."""

class ExtractByMistral(op.FunctionSpec):
"""Extract information from a text using a Mistral model."""

model: MistralModelSpec
llm_spec: llm.LlmSpec
# Expected to be generated by `cocoindex.typing.encode_enriched_type()`
output_type: dict[str, Any]
instructions: str | None = None
instruction: str | None = None

class SentenceTransformerEmbed(op.FunctionSpec):
"""
13 changes: 13 additions & 0 deletions python/cocoindex/llm.py
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from enum import Enum

class LlmApiType(Enum):
"""The type of LLM API to use."""
OLLAMA = "Ollama"

@dataclass
class LlmSpec:
"""A specification for a LLM."""
api_type: LlmApiType
model: str
address: str | None = None
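A minimal sketch of how this spec is meant to be used with the new `ExtractByLlm` function, mirroring the updated `manual_extraction.py` example; the `ModuleInfo` dataclass and the model tag are illustrative, not part of this change:

```python
import dataclasses

import cocoindex

@dataclasses.dataclass
class ModuleInfo:
    """Illustrative output type; any type accepted by encode_enriched_type works."""
    name: str
    description: str

extract_fn = cocoindex.functions.ExtractByLlm(
    llm_spec=cocoindex.llm.LlmSpec(
        api_type=cocoindex.llm.LlmApiType.OLLAMA,
        model="llama3.2:latest"),  # illustrative Ollama model tag
    output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
    instruction="Extract Python module information from the manual.")

# Inside a flow definition, applied to a text field:
#   doc["module_info"] = doc["markdown"].transform(extract_fn)
```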
1 change: 1 addition & 0 deletions src/lib.rs
@@ -2,6 +2,7 @@ mod base;
mod builder;
mod execution;
mod lib_context;
mod llm;
mod ops;
mod py;
mod server;
74 changes: 74 additions & 0 deletions src/llm/client.rs
@@ -0,0 +1,74 @@
use anyhow::Result;
use schemars::schema::SchemaObject;
use serde::{Deserialize, Serialize};

pub struct LlmClient {
generate_url: String,
model: String,
reqwest_client: reqwest::Client,
}

#[derive(Debug, Serialize)]
enum OllamaFormat<'a> {
#[serde(untagged)]
JsonSchema(&'a SchemaObject),
}

#[derive(Debug, Serialize)]
struct OllamaRequest<'a> {
pub model: &'a str,
pub prompt: &'a str,
pub format: Option<OllamaFormat<'a>>,
pub system: Option<&'a str>,
pub stream: Option<bool>,
}

#[derive(Debug, Deserialize)]
struct OllamaResponse {
pub response: String,
}

const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";

impl LlmClient {
pub async fn new(spec: super::LlmSpec) -> Result<Self> {
let address = match &spec.address {
Some(addr) => addr.trim_end_matches('/'),
None => OLLAMA_DEFAULT_ADDRESS,
};
Ok(Self {
generate_url: format!("{}/api/generate", address),
model: spec.model,
reqwest_client: reqwest::Client::new(),
})
}

pub async fn generate<'a>(
&self,
request: super::LlmGenerateRequest<'a>,
) -> Result<super::LlmGenerateResponse> {
let req = OllamaRequest {
model: &self.model,
prompt: request.user_prompt.as_ref(),
format: match &request.output_format {
Some(super::OutputFormat::JsonSchema(schema)) => {
Some(OllamaFormat::JsonSchema(schema.as_ref()))
}
None => None,
},
system: request.system_prompt.as_ref().map(|s| s.as_ref()),
stream: Some(false),
};
let res = self
.reqwest_client
.post(self.generate_url.as_str())
.json(&req)
.send()
.await?;
let body = res.text().await?;
let json: OllamaResponse = serde_json::from_str(&body)?;
Ok(super::LlmGenerateResponse {
text: json.response,
})
}
}
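For reference, a short Python sketch (assuming a locally running Ollama server and the `requests` package; the model tag and schema are illustrative) of the request shape `LlmClient` sends to `/api/generate`, matching the `OllamaRequest`/`OllamaResponse` structs above:

```python
import requests

OLLAMA_URL = "http://localhost:11434/api/generate"

payload = {
    "model": "llama3.2:latest",                 # illustrative model tag
    "prompt": "Extract module info from: ...",  # user prompt
    "system": "You are a helpful assistant that extracts structured information from text.",
    "format": {                                 # JSON schema constraining the output
        "type": "object",
        "properties": {"name": {"type": "string"}},
    },
    "stream": False,                            # the client always disables streaming
}

resp = requests.post(OLLAMA_URL, json=payload, timeout=120)
resp.raise_for_status()
# Ollama returns the generated text in the "response" field, which the client
# surfaces as LlmGenerateResponse.text.
print(resp.json()["response"])
```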
36 changes: 36 additions & 0 deletions src/llm/mod.rs
@@ -0,0 +1,36 @@
use std::borrow::Cow;

use schemars::schema::SchemaObject;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmApiType {
Ollama,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmSpec {
api_type: LlmApiType,
address: Option<String>,
model: String,
}

#[derive(Debug)]
pub enum OutputFormat<'a> {
JsonSchema(Cow<'a, SchemaObject>),
}

#[derive(Debug)]
pub struct LlmGenerateRequest<'a> {
pub system_prompt: Option<Cow<'a, str>>,
pub user_prompt: Cow<'a, str>,
pub output_format: Option<OutputFormat<'a>>,
}

#[derive(Debug)]
pub struct LlmGenerateResponse {
pub text: String,
}

mod client;
pub use client::LlmClient;
src/ops/functions/extract_by_mistral.rs → src/ops/functions/extract_by_llm.rs (renamed)
@@ -1,32 +1,28 @@
use std::borrow::Cow;
use std::sync::Arc;

use anyhow::anyhow;
use mistralrs::{self, TextMessageRole};
use schemars::schema::SchemaObject;
use serde::Serialize;

use crate::base::json_schema::ToJsonSchema;
use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat};
use crate::ops::sdk::*;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MistralModelSpec {
model_id: String,
isq_type: mistralrs::IsqType,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Spec {
model: MistralModelSpec,
llm_spec: LlmSpec,
output_type: EnrichedValueType,
instructions: Option<String>,
instruction: Option<String>,
}

struct Executor {
model: mistralrs::Model,
client: LlmClient,
output_json_schema: SchemaObject,
output_type: EnrichedValueType,
request_base: mistralrs::RequestBuilder,
system_prompt: String,
}

fn get_system_message(instructions: &Option<String>) -> String {
fn get_system_prompt(instructions: &Option<String>) -> String {
let mut message =
"You are a helpful assistant that extracts structured information from text. \
Your task is to analyze the input text and output valid JSON that matches the specified schema. \
@@ -44,24 +40,11 @@ Output only the JSON without any additional messages or explanations."

impl Executor {
async fn new(spec: Spec) -> Result<Self> {
let model = mistralrs::TextModelBuilder::new(spec.model.model_id)
.with_isq(spec.model.isq_type)
.with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())?
.build()
.await?;
let request_base = mistralrs::RequestBuilder::new()
.set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value(
spec.output_type.to_json_schema(),
)?))
.set_deterministic_sampler()
.add_message(
TextMessageRole::System,
get_system_message(&spec.instructions),
);
Ok(Self {
model,
client: LlmClient::new(spec.llm_spec).await?,
output_json_schema: spec.output_type.to_json_schema(),
output_type: spec.output_type,
request_base,
system_prompt: get_system_prompt(&spec.instruction),
})
}
}
@@ -78,17 +61,15 @@ impl SimpleFunctionExecutor for Executor {

async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
let text = input.iter().next().unwrap().as_str()?;
let request = self
.request_base
.clone()
.add_message(TextMessageRole::User, text);
let response = self.model.send_chat_request(request).await?;
let response_text = response.choices[0]
.message
.content
.as_ref()
.ok_or_else(|| anyhow!("No content in response"))?;
let json_value: serde_json::Value = serde_json::from_str(response_text)?;
let req = LlmGenerateRequest {
system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
user_prompt: Cow::Borrowed(text),
output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
&self.output_json_schema,
))),
};
let res = self.client.generate(req).await?;
let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
let value = Value::from_json(json_value, &self.output_type.typ)?;
Ok(value)
}
@@ -101,7 +82,7 @@ impl SimpleFunctionFactoryBase for Factory {
type Spec = Spec;

fn name(&self) -> &str {
"ExtractByMistral"
"ExtractByLlm"
}

fn get_output_schema(
2 changes: 1 addition & 1 deletion src/ops/functions/mod.rs
@@ -1,2 +1,2 @@
pub mod extract_by_mistral;
pub mod extract_by_llm;
pub mod split_recursively;
2 changes: 1 addition & 1 deletion src/ops/registration.rs
@@ -8,7 +8,7 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
sources::local_file::Factory.register(registry)?;
functions::split_recursively::Factory.register(registry)?;
functions::extract_by_mistral::Factory.register(registry)?;
functions::extract_by_llm::Factory.register(registry)?;
Arc::new(storages::postgres::Factory::default()).register(registry)?;

Ok(())