From ef241ab5bb618dd849f50bba79203beab97c187d Mon Sep 17 00:00:00 2001
From: LJ
Date: Wed, 12 Mar 2025 00:12:15 -0700
Subject: [PATCH] Switch from `mistral.rs` to remote LLM APIs and support the
 Ollama API.

---
 Cargo.toml                                    |  1 +
 .../manual_extraction/manual_extraction.py    | 10 +--
 python/cocoindex/__init__.py                  |  2 +-
 python/cocoindex/flow.py                      |  5 +-
 python/cocoindex/functions.py                 | 16 ++--
 python/cocoindex/llm.py                       | 13 ++++
 src/lib.rs                                    |  1 +
 src/llm/client.rs                             | 74 +++++++++++++++++++
 src/llm/mod.rs                                | 36 +++++++++
 ...xtract_by_mistral.rs => extract_by_llm.rs} | 63 ++++++----------
 src/ops/functions/mod.rs                      |  2 +-
 src/ops/registration.rs                       |  2 +-
 12 files changed, 164 insertions(+), 61 deletions(-)
 create mode 100644 python/cocoindex/llm.py
 create mode 100644 src/llm/client.rs
 create mode 100644 src/llm/mod.rs
 rename src/ops/functions/{extract_by_mistral.rs => extract_by_llm.rs} (60%)

diff --git a/Cargo.toml b/Cargo.toml
index d9b157ed0..b452d87eb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,3 +47,4 @@ schemars = "0.8.22"
 openssl = { version = "0.10.71", features = ["vendored"] }
 console-subscriber = "0.4.1"
 env_logger = "0.11.7"
+reqwest = { version = "0.12.13", features = ["json"] }
diff --git a/examples/manual_extraction/manual_extraction.py b/examples/manual_extraction/manual_extraction.py
index 39018a54e..96b370aed 100644
--- a/examples/manual_extraction/manual_extraction.py
+++ b/examples/manual_extraction/manual_extraction.py
@@ -82,12 +82,12 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
     with data_scope["documents"].row() as doc:
         doc["markdown"] = doc["content"].transform(PdfToMarkdown())
         doc["raw_module_info"] = doc["markdown"].transform(
-            cocoindex.functions.ExtractByMistral(
-                model=cocoindex.functions.MistralModelSpec(
-                    model_id="microsoft/Phi-3.5-mini-instruct",
-                    isq_type="Q8_0"),
+            cocoindex.functions.ExtractByLlm(
+                llm_spec=cocoindex.llm.LlmSpec(
+                    api_type=cocoindex.llm.LlmApiType.OLLAMA,
+                    model="llama3.2:latest"),
                 output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
-                instructions="Please extract Python module information from the manual."))
+                instruction="Please extract Python module information from the manual."))
         doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())
 
         manual_infos.collect(filename=doc["filename"], module_info=doc["module_info"])
diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py
index 1501a2777..9cd86c8d9 100644
--- a/python/cocoindex/__init__.py
+++ b/python/cocoindex/__init__.py
@@ -1,7 +1,7 @@
 """
 Cocoindex is a framework for building and running indexing pipelines.
 """
-from . import flow, functions, query, sources, storages, cli
+from . import flow, functions, query, sources, storages, cli, llm
 from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
 from .vector import VectorSimilarityMetric
 from .lib import *
diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py
index d0296802b..7b917f287 100644
--- a/python/cocoindex/flow.py
+++ b/python/cocoindex/flow.py
@@ -8,6 +8,7 @@
 import inspect
 from typing import Any, Callable, Sequence, TypeVar
 from threading import Lock
+from enum import Enum
 
 from . import _engine
 from . import vector
@@ -62,7 +63,9 @@ def _spec_kind(spec: Any) -> str:
 
 def _spec_value_dump(spec: Any) -> Any:
     """Recursively dump a spec object and its nested attributes to a dictionary."""
-    if hasattr(spec, '__dict__'):
+    if isinstance(spec, Enum):
+        return spec.value
+    elif hasattr(spec, '__dict__'):
         return {k: _spec_value_dump(v) for k, v in spec.__dict__.items()}
     elif isinstance(spec, (list, tuple)):
         return [_spec_value_dump(item) for item in spec]
diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py
index 150f0a7ca..d97e7ab14 100644
--- a/python/cocoindex/functions.py
+++ b/python/cocoindex/functions.py
@@ -4,7 +4,7 @@
 import sentence_transformers
 
 from .typing import Float32, Vector, TypeAttr
-from . import op
+from . import op, llm
 
 class SplitRecursively(op.FunctionSpec):
     """Split a document (in string) recursively."""
@@ -12,19 +12,13 @@ class SplitRecursively(op.FunctionSpec):
     chunk_overlap: int
     language: str | None = None
 
-@dataclass
-class MistralModelSpec:
-    """A specification for a Mistral model."""
-    model_id: str
-    isq_type: str
+class ExtractByLlm(op.FunctionSpec):
+    """Extract information from a text using an LLM."""
 
-class ExtractByMistral(op.FunctionSpec):
-    """Extract information from a text using a Mistral model."""
-
-    model: MistralModelSpec
+    llm_spec: llm.LlmSpec
     # Expected to be generated by `cocoindex.typing.encode_enriched_type()`
     output_type: dict[str, Any]
-    instructions: str | None = None
+    instruction: str | None = None
 
 class SentenceTransformerEmbed(op.FunctionSpec):
     """
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
new file mode 100644
index 000000000..e26c688f5
--- /dev/null
+++ b/python/cocoindex/llm.py
@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from enum import Enum
+
+class LlmApiType(Enum):
+    """The type of LLM API to use."""
+    OLLAMA = "Ollama"
+
+@dataclass
+class LlmSpec:
+    """A specification for an LLM."""
+    api_type: LlmApiType
+    model: str
+    address: str | None = None
diff --git a/src/lib.rs b/src/lib.rs
index b682213d5..8ccb271b3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@ mod base;
 mod builder;
 mod execution;
 mod lib_context;
+mod llm;
 mod ops;
 mod py;
 mod server;
diff --git a/src/llm/client.rs b/src/llm/client.rs
new file mode 100644
index 000000000..28ada294f
--- /dev/null
+++ b/src/llm/client.rs
@@ -0,0 +1,74 @@
+use anyhow::Result;
+use schemars::schema::SchemaObject;
+use serde::{Deserialize, Serialize};
+
+pub struct LlmClient {
+    generate_url: String,
+    model: String,
+    reqwest_client: reqwest::Client,
+}
+
+#[derive(Debug, Serialize)]
+enum OllamaFormat<'a> {
+    #[serde(untagged)]
+    JsonSchema(&'a SchemaObject),
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaRequest<'a> {
+    pub model: &'a str,
+    pub prompt: &'a str,
+    pub format: Option<OllamaFormat<'a>>,
+    pub system: Option<&'a str>,
+    pub stream: Option<bool>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaResponse {
+    pub response: String,
+}
+
+const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
+
+impl LlmClient {
+    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
+        let address = match &spec.address {
+            Some(addr) => addr.trim_end_matches('/'),
+            None => OLLAMA_DEFAULT_ADDRESS,
+        };
+        Ok(Self {
+            generate_url: format!("{}/api/generate", address),
+            model: spec.model,
+            reqwest_client: reqwest::Client::new(),
+        })
+    }
+
+    pub async fn generate<'a>(
+        &self,
+        request: super::LlmGenerateRequest<'a>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let req = OllamaRequest {
+            model: &self.model,
+            prompt: request.user_prompt.as_ref(),
+            format: match &request.output_format {
+                Some(super::OutputFormat::JsonSchema(schema)) => {
+                    Some(OllamaFormat::JsonSchema(schema.as_ref()))
+                }
+                None => None,
+            },
+            system: request.system_prompt.as_ref().map(|s| s.as_ref()),
+            stream: Some(false),
+        };
+        let res = self
+            .reqwest_client
+            .post(self.generate_url.as_str())
+            .json(&req)
+            .send()
+            .await?;
+        let body = res.text().await?;
+        let json: OllamaResponse = serde_json::from_str(&body)?;
+        Ok(super::LlmGenerateResponse {
+            text: json.response,
+        })
+    }
+}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
new file mode 100644
index 000000000..aaa412abd
--- /dev/null
+++ b/src/llm/mod.rs
@@ -0,0 +1,36 @@
+use std::borrow::Cow;
+
+use schemars::schema::SchemaObject;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum LlmApiType {
+    Ollama,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LlmSpec {
+    api_type: LlmApiType,
+    address: Option<String>,
+    model: String,
+}
+
+#[derive(Debug)]
+pub enum OutputFormat<'a> {
+    JsonSchema(Cow<'a, SchemaObject>),
+}
+
+#[derive(Debug)]
+pub struct LlmGenerateRequest<'a> {
+    pub system_prompt: Option<Cow<'a, str>>,
+    pub user_prompt: Cow<'a, str>,
+    pub output_format: Option<OutputFormat<'a>>,
+}
+
+#[derive(Debug)]
+pub struct LlmGenerateResponse {
+    pub text: String,
+}
+
+mod client;
+pub use client::LlmClient;
diff --git a/src/ops/functions/extract_by_mistral.rs b/src/ops/functions/extract_by_llm.rs
similarity index 60%
rename from src/ops/functions/extract_by_mistral.rs
rename to src/ops/functions/extract_by_llm.rs
index 52201a46e..3b3989dbc 100644
--- a/src/ops/functions/extract_by_mistral.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -1,32 +1,28 @@
+use std::borrow::Cow;
 use std::sync::Arc;
 
-use anyhow::anyhow;
-use mistralrs::{self, TextMessageRole};
+use schemars::schema::SchemaObject;
 use serde::Serialize;
 
 use crate::base::json_schema::ToJsonSchema;
+use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat};
 use crate::ops::sdk::*;
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct MistralModelSpec {
-    model_id: String,
-    isq_type: mistralrs::IsqType,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Spec {
-    model: MistralModelSpec,
+    llm_spec: LlmSpec,
     output_type: EnrichedValueType,
-    instructions: Option<String>,
+    instruction: Option<String>,
 }
 
 struct Executor {
-    model: mistralrs::Model,
+    client: LlmClient,
+    output_json_schema: SchemaObject,
     output_type: EnrichedValueType,
-    request_base: mistralrs::RequestBuilder,
+    system_prompt: String,
 }
 
-fn get_system_message(instructions: &Option<String>) -> String {
+fn get_system_prompt(instructions: &Option<String>) -> String {
     let mut message =
         "You are a helpful assistant that extracts structured information from text. \
 Your task is to analyze the input text and output valid JSON that matches the specified schema. \
@@ -44,24 +40,11 @@ Output only the JSON without any additional messages or explanations."
 
 impl Executor {
     async fn new(spec: Spec) -> Result<Self> {
-        let model = mistralrs::TextModelBuilder::new(spec.model.model_id)
-            .with_isq(spec.model.isq_type)
-            .with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())?
-            .build()
-            .await?;
-        let request_base = mistralrs::RequestBuilder::new()
-            .set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value(
-                spec.output_type.to_json_schema(),
-            )?))
-            .set_deterministic_sampler()
-            .add_message(
-                TextMessageRole::System,
-                get_system_message(&spec.instructions),
-            );
         Ok(Self {
-            model,
+            client: LlmClient::new(spec.llm_spec).await?,
+            output_json_schema: spec.output_type.to_json_schema(),
             output_type: spec.output_type,
-            request_base,
+            system_prompt: get_system_prompt(&spec.instruction),
         })
     }
 }
@@ -78,17 +61,15 @@ impl SimpleFunctionExecutor for Executor {
 
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
         let text = input.iter().next().unwrap().as_str()?;
-        let request = self
-            .request_base
-            .clone()
-            .add_message(TextMessageRole::User, text);
-        let response = self.model.send_chat_request(request).await?;
-        let response_text = response.choices[0]
-            .message
-            .content
-            .as_ref()
-            .ok_or_else(|| anyhow!("No content in response"))?;
-        let json_value: serde_json::Value = serde_json::from_str(response_text)?;
+        let req = LlmGenerateRequest {
+            system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
+            user_prompt: Cow::Borrowed(text),
+            output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
+                &self.output_json_schema,
+            ))),
+        };
+        let res = self.client.generate(req).await?;
+        let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
         let value = Value::from_json(json_value, &self.output_type.typ)?;
         Ok(value)
     }
 }
@@ -101,7 +82,7 @@ impl SimpleFunctionFactoryBase for Factory {
     type Spec = Spec;
 
     fn name(&self) -> &str {
-        "ExtractByMistral"
+        "ExtractByLlm"
     }
 
     fn get_output_schema(
diff --git a/src/ops/functions/mod.rs b/src/ops/functions/mod.rs
index ac82c6cf9..254c29988 100644
--- a/src/ops/functions/mod.rs
+++ b/src/ops/functions/mod.rs
@@ -1,2 +1,2 @@
-pub mod extract_by_mistral;
+pub mod extract_by_llm;
 pub mod split_recursively;
diff --git a/src/ops/registration.rs b/src/ops/registration.rs
index 33fb803d6..1a258fc94 100644
--- a/src/ops/registration.rs
+++ b/src/ops/registration.rs
@@ -8,7 +8,7 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
 fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
     sources::local_file::Factory.register(registry)?;
     functions::split_recursively::Factory.register(registry)?;
-    functions::extract_by_mistral::Factory.register(registry)?;
+    functions::extract_by_llm::Factory.register(registry)?;
     Arc::new(storages::postgres::Factory::default()).register(registry)?;
     Ok(())
 }
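
Reviewer note (illustrative, not part of the patch): the new `LlmClient`
targets Ollama's `/api/generate` endpoint. With `"stream": false` the server
replies with a single JSON object whose `response` field holds the generated
text, and `format` may carry a JSON schema object that constrains the output;
that is what `OllamaFormat::JsonSchema` serializes to. Below is a minimal
standalone sketch for checking that wire format against a local server. It
assumes Ollama is running on localhost:11434 with `llama3.2:latest` pulled,
plus `reqwest` (with the "json" feature), `serde_json`, and `tokio` as
dependencies; the prompt and schema are made-up placeholders.

use serde_json::{json, Value};

// Probe Ollama's /api/generate wire format: send one non-streaming request
// with a JSON-schema `format` constraint and print the `response` field.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let req = json!({
        "model": "llama3.2:latest",
        "system": "You extract structured information and reply with JSON only.",
        "prompt": "CocoIndex is a framework for building indexing pipelines.",
        // A schema here plays the role of OllamaFormat::JsonSchema in client.rs.
        "format": {
            "type": "object",
            "properties": { "name": { "type": "string" } },
            "required": ["name"]
        },
        "stream": false
    });
    let res: Value = reqwest::Client::new()
        .post("http://localhost:11434/api/generate")
        .json(&req)
        .send()
        .await?
        .json()
        .await?;
    // With stream=false the whole generation arrives as one object; the
    // `response` field is a string containing the model's JSON output.
    println!("{}", res["response"]);
    Ok(())
}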