Commit 546e2ae

Switch mistral.rs to remote LLM APIs and support ollama API. #27 (#94)
1 parent 42cead6 commit 546e2ae

12 files changed: +164 -61 lines changed


Cargo.toml
Lines changed: 1 addition & 0 deletions

@@ -47,3 +47,4 @@ schemars = "0.8.22"
 openssl = { version = "0.10.71", features = ["vendored"] }
 console-subscriber = "0.4.1"
 env_logger = "0.11.7"
+reqwest = { version = "0.12.13", features = ["json"] }

examples/manual_extraction/manual_extraction.py
Lines changed: 5 additions & 5 deletions

@@ -82,12 +82,12 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
     with data_scope["documents"].row() as doc:
         doc["markdown"] = doc["content"].transform(PdfToMarkdown())
         doc["raw_module_info"] = doc["markdown"].transform(
-            cocoindex.functions.ExtractByMistral(
-                model=cocoindex.functions.MistralModelSpec(
-                    model_id="microsoft/Phi-3.5-mini-instruct",
-                    isq_type="Q8_0"),
+            cocoindex.functions.ExtractByLlm(
+                llm_spec=cocoindex.llm.LlmSpec(
+                    api_type=cocoindex.llm.LlmApiType.OLLAMA,
+                    model="llama3.2:latest"),
                 output_type=cocoindex.typing.encode_enriched_type(ModuleInfo),
-                instructions="Please extract Python module information from the manual."))
+                instruction="Please extract Python module information from the manual."))
         doc["module_info"] = doc["raw_module_info"].transform(CleanUpManual())
         manual_infos.collect(filename=doc["filename"], module_info=doc["module_info"])
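
The ModuleInfo passed to encode_enriched_type is defined elsewhere in the example; a hypothetical sketch of a dataclass along those lines, just to show where output_type comes from (the field names here are illustrative, not from the example):

import dataclasses
import cocoindex

@dataclasses.dataclass
class ModuleInfo:
    # Hypothetical fields; the real example defines its own ModuleInfo.
    name: str
    description: str

# encode_enriched_type() produces the dict that ExtractByLlm expects for
# output_type; the engine derives the LLM's JSON output schema from it.
output_type = cocoindex.typing.encode_enriched_type(ModuleInfo)

Running the example also assumes an Ollama server is reachable at its default address with the llama3.2 model already pulled.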

python/cocoindex/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 """
 Cocoindex is a framework for building and running indexing pipelines.
 """
-from . import flow, functions, query, sources, storages, cli
+from . import flow, functions, query, sources, storages, cli, llm
 from .flow import FlowBuilder, DataScope, DataSlice, Flow, flow_def
 from .vector import VectorSimilarityMetric
 from .lib import *

python/cocoindex/flow.py
Lines changed: 4 additions & 1 deletion

@@ -8,6 +8,7 @@
 import inspect
 from typing import Any, Callable, Sequence, TypeVar
 from threading import Lock
+from enum import Enum
 
 from . import _engine
 from . import vector
@@ -62,7 +63,9 @@ def _spec_kind(spec: Any) -> str:
 
 def _spec_value_dump(spec: Any) -> Any:
     """Recursively dump a spec object and its nested attributes to a dictionary."""
-    if hasattr(spec, '__dict__'):
+    if isinstance(spec, Enum):
+        return spec.value
+    elif hasattr(spec, '__dict__'):
         return {k: _spec_value_dump(v) for k, v in spec.__dict__.items()}
     elif isinstance(spec, (list, tuple)):
         return [_spec_value_dump(item) for item in spec]
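
The new Enum branch is what lets a Python-side LlmApiType member serialize to the plain string the Rust engine expects: "Ollama" matches serde's default name for the LlmApiType::Ollama variant. A minimal sketch of the resulting dump, using stand-in copies of the spec classes:

from dataclasses import dataclass
from enum import Enum

# Stand-ins mirroring cocoindex.llm.LlmApiType / LlmSpec, for illustration only.
class LlmApiType(Enum):
    OLLAMA = "Ollama"

@dataclass
class LlmSpec:
    api_type: LlmApiType
    model: str
    address: str | None = None

def _spec_value_dump(spec):
    """Mirrors the updated helper: Enum members dump to their value."""
    if isinstance(spec, Enum):
        return spec.value
    elif hasattr(spec, '__dict__'):
        return {k: _spec_value_dump(v) for k, v in spec.__dict__.items()}
    elif isinstance(spec, (list, tuple)):
        return [_spec_value_dump(item) for item in spec]
    return spec

print(_spec_value_dump(LlmSpec(api_type=LlmApiType.OLLAMA, model="llama3.2:latest")))
# {'api_type': 'Ollama', 'model': 'llama3.2:latest', 'address': None}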

python/cocoindex/functions.py
Lines changed: 5 additions & 11 deletions

@@ -4,27 +4,21 @@
 
 import sentence_transformers
 from .typing import Float32, Vector, TypeAttr
-from . import op
+from . import op, llm
 
 class SplitRecursively(op.FunctionSpec):
     """Split a document (in string) recursively."""
     chunk_size: int
     chunk_overlap: int
     language: str | None = None
 
-@dataclass
-class MistralModelSpec:
-    """A specification for a Mistral model."""
-    model_id: str
-    isq_type: str
+class ExtractByLlm(op.FunctionSpec):
+    """Extract information from a text using a LLM."""
 
-class ExtractByMistral(op.FunctionSpec):
-    """Extract information from a text using a Mistral model."""
-
-    model: MistralModelSpec
+    llm_spec: llm.LlmSpec
     # Expected to be generated by `cocoindex.typing.encode_enriched_type()`
     output_type: dict[str, Any]
-    instructions: str | None = None
+    instruction: str | None = None
 
 class SentenceTransformerEmbed(op.FunctionSpec):
     """

python/cocoindex/llm.py
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+from dataclasses import dataclass
+from enum import Enum
+
+class LlmApiType(Enum):
+    """The type of LLM API to use."""
+    OLLAMA = "Ollama"
+
+@dataclass
+class LlmSpec:
+    """A specification for a LLM."""
+    api_type: LlmApiType
+    model: str
+    address: str | None = None
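
A minimal usage sketch, assuming a local Ollama server; address is optional and the Rust client falls back to http://localhost:11434 when it is omitted (the model value here is illustrative):

import cocoindex

spec = cocoindex.llm.LlmSpec(
    api_type=cocoindex.llm.LlmApiType.OLLAMA,
    model="llama3.2:latest",
    address="http://127.0.0.1:11434",  # optional; omit to use the client's default
)

The spec is then passed to cocoindex.functions.ExtractByLlm as llm_spec, as in the example diff above.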

src/lib.rs
Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ mod base;
 mod builder;
 mod execution;
 mod lib_context;
+mod llm;
 mod ops;
 mod py;
 mod server;

src/llm/client.rs
Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+use anyhow::Result;
+use schemars::schema::SchemaObject;
+use serde::{Deserialize, Serialize};
+
+pub struct LlmClient {
+    generate_url: String,
+    model: String,
+    reqwest_client: reqwest::Client,
+}
+
+#[derive(Debug, Serialize)]
+enum OllamaFormat<'a> {
+    #[serde(untagged)]
+    JsonSchema(&'a SchemaObject),
+}
+
+#[derive(Debug, Serialize)]
+struct OllamaRequest<'a> {
+    pub model: &'a str,
+    pub prompt: &'a str,
+    pub format: Option<OllamaFormat<'a>>,
+    pub system: Option<&'a str>,
+    pub stream: Option<bool>,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaResponse {
+    pub response: String,
+}
+
+const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
+
+impl LlmClient {
+    pub async fn new(spec: super::LlmSpec) -> Result<Self> {
+        let address = match &spec.address {
+            Some(addr) => addr.trim_end_matches('/'),
+            None => OLLAMA_DEFAULT_ADDRESS,
+        };
+        Ok(Self {
+            generate_url: format!("{}/api/generate", address),
+            model: spec.model,
+            reqwest_client: reqwest::Client::new(),
+        })
+    }
+
+    pub async fn generate<'a>(
+        &self,
+        request: super::LlmGenerateRequest<'a>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let req = OllamaRequest {
+            model: &self.model,
+            prompt: request.user_prompt.as_ref(),
+            format: match &request.output_format {
+                Some(super::OutputFormat::JsonSchema(schema)) => {
+                    Some(OllamaFormat::JsonSchema(schema.as_ref()))
+                }
+                None => None,
+            },
+            system: request.system_prompt.as_ref().map(|s| s.as_ref()),
+            stream: Some(false),
+        };
+        let res = self
+            .reqwest_client
+            .post(self.generate_url.as_str())
+            .json(&req)
+            .send()
+            .await?;
+        let body = res.text().await?;
+        let json: OllamaResponse = serde_json::from_str(&body)?;
+        Ok(super::LlmGenerateResponse {
+            text: json.response,
+        })
+    }
+}
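
The client targets Ollama's /api/generate endpoint with streaming disabled and the output JSON schema passed in the format field. A rough Python equivalent of one such call, with illustrative prompt and schema values, assuming a local Ollama server with the model already pulled:

import json
import urllib.request

# Illustrative payload mirroring the fields of OllamaRequest in client.rs.
payload = {
    "model": "llama3.2:latest",
    "system": "You are a helpful assistant that extracts structured information from text.",
    "prompt": "Extract module information from: <manual text here>",
    "format": {  # a JSON schema; the Rust client derives it from output_type
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    },
    "stream": False,
}

req = urllib.request.Request(
    "http://localhost:11434/api/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())

# Ollama puts the generated text in "response"; the Rust executor then parses
# it as JSON and converts it into the flow's output value type.
print(json.loads(body["response"]))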

src/llm/mod.rs
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+use std::borrow::Cow;
+
+use schemars::schema::SchemaObject;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum LlmApiType {
+    Ollama,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LlmSpec {
+    api_type: LlmApiType,
+    address: Option<String>,
+    model: String,
+}
+
+#[derive(Debug)]
+pub enum OutputFormat<'a> {
+    JsonSchema(Cow<'a, SchemaObject>),
+}
+
+#[derive(Debug)]
+pub struct LlmGenerateRequest<'a> {
+    pub system_prompt: Option<Cow<'a, str>>,
+    pub user_prompt: Cow<'a, str>,
+    pub output_format: Option<OutputFormat<'a>>,
+}
+
+#[derive(Debug)]
+pub struct LlmGenerateResponse {
+    pub text: String,
+}
+
+mod client;
+pub use client::LlmClient;

src/ops/functions/extract_by_mistral.rs renamed to src/ops/functions/extract_by_llm.rs
Lines changed: 22 additions & 41 deletions

@@ -1,32 +1,28 @@
+use std::borrow::Cow;
 use std::sync::Arc;
 
-use anyhow::anyhow;
-use mistralrs::{self, TextMessageRole};
+use schemars::schema::SchemaObject;
 use serde::Serialize;
 
 use crate::base::json_schema::ToJsonSchema;
+use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat};
 use crate::ops::sdk::*;
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct MistralModelSpec {
-    model_id: String,
-    isq_type: mistralrs::IsqType,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Spec {
-    model: MistralModelSpec,
+    llm_spec: LlmSpec,
     output_type: EnrichedValueType,
-    instructions: Option<String>,
+    instruction: Option<String>,
 }
 
 struct Executor {
-    model: mistralrs::Model,
+    client: LlmClient,
+    output_json_schema: SchemaObject,
     output_type: EnrichedValueType,
-    request_base: mistralrs::RequestBuilder,
+    system_prompt: String,
 }
 
-fn get_system_message(instructions: &Option<String>) -> String {
+fn get_system_prompt(instructions: &Option<String>) -> String {
     let mut message =
         "You are a helpful assistant that extracts structured information from text. \
 Your task is to analyze the input text and output valid JSON that matches the specified schema. \
@@ -44,24 +40,11 @@ Output only the JSON without any additional messages or explanations."
 
 impl Executor {
     async fn new(spec: Spec) -> Result<Self> {
-        let model = mistralrs::TextModelBuilder::new(spec.model.model_id)
-            .with_isq(spec.model.isq_type)
-            .with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())?
-            .build()
-            .await?;
-        let request_base = mistralrs::RequestBuilder::new()
-            .set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value(
-                spec.output_type.to_json_schema(),
-            )?))
-            .set_deterministic_sampler()
-            .add_message(
-                TextMessageRole::System,
-                get_system_message(&spec.instructions),
-            );
         Ok(Self {
-            model,
+            client: LlmClient::new(spec.llm_spec).await?,
+            output_json_schema: spec.output_type.to_json_schema(),
             output_type: spec.output_type,
-            request_base,
+            system_prompt: get_system_prompt(&spec.instruction),
         })
     }
 }
@@ -78,17 +61,15 @@ impl SimpleFunctionExecutor for Executor {
 
     async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
         let text = input.iter().next().unwrap().as_str()?;
-        let request = self
-            .request_base
-            .clone()
-            .add_message(TextMessageRole::User, text);
-        let response = self.model.send_chat_request(request).await?;
-        let response_text = response.choices[0]
-            .message
-            .content
-            .as_ref()
-            .ok_or_else(|| anyhow!("No content in response"))?;
-        let json_value: serde_json::Value = serde_json::from_str(response_text)?;
+        let req = LlmGenerateRequest {
+            system_prompt: Some(Cow::Borrowed(&self.system_prompt)),
+            user_prompt: Cow::Borrowed(text),
+            output_format: Some(OutputFormat::JsonSchema(Cow::Borrowed(
+                &self.output_json_schema,
+            ))),
+        };
+        let res = self.client.generate(req).await?;
+        let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
         let value = Value::from_json(json_value, &self.output_type.typ)?;
         Ok(value)
     }
@@ -101,7 +82,7 @@ impl SimpleFunctionFactoryBase for Factory {
     type Spec = Spec;
 
     fn name(&self) -> &str {
-        "ExtractByMistral"
+        "ExtractByLlm"
     }
 
     fn get_output_schema(
