Skip to content

Commit a8a7ae4

Browse files
committed
Implement function ExtractByMistral to extract structured data from LLM
1 parent b889564 commit a8a7ae4

File tree

4 files changed

+134
-0
lines changed

4 files changed

+134
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ derivative = "2.2.0"
4040
async-lock = "3.4.0"
4141
hex = "0.4.3"
4242
pythonize = "0.23.0"
43+
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", tag = "v0.4.0" }
4344
schemars = "0.8.22"
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
use std::sync::Arc;
2+
3+
use anyhow::anyhow;
4+
use mistralrs::{self, TextMessageRole};
5+
use serde::Serialize;
6+
7+
use crate::base::json_schema::ToJsonSchema;
8+
use crate::ops::sdk::*;
9+
10+
/// Describes which model the executor loads through `mistralrs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
11+
pub struct MistralModelSpec {
12+
/// Model identifier passed to `mistralrs::TextModelBuilder::new`.
model_id: String,
13+
/// ISQ type passed to `TextModelBuilder::with_isq` when the model is built.
isq_type: mistralrs::IsqType,
14+
}
15+
16+
/// User-facing configuration of the `ExtractByMistral` function.
#[derive(Debug, Clone, Serialize, Deserialize)]
17+
pub struct Spec {
18+
/// Model to load; consumed by `Executor::new`.
model: MistralModelSpec,
19+
/// Declared output type. Its JSON schema constrains the model's response,
/// and the response is converted back into a value of this type.
output_type: EnrichedValueType,
20+
/// Optional extra instructions appended to the built-in system message.
instructions: Option<String>,
21+
}
22+
23+
/// Runtime state of the function: the loaded model plus a request template.
struct Executor {
24+
/// Model built once in `Executor::new` and used for every chat request.
model: mistralrs::Model,
25+
/// Output type used to convert the model's JSON response into a `Value`.
output_type: EnrichedValueType,
26+
/// Request template (JSON-schema constraint, deterministic sampler and
/// system message); cloned for each evaluation before the user text is added.
request_base: mistralrs::RequestBuilder,
27+
}
28+
29+
/// Builds the system prompt sent to the model ahead of the user text.
///
/// The prompt always starts with a fixed instruction block asking for JSON
/// that matches the schema. When `instructions` holds non-blank text, that
/// text is appended verbatim after an empty line.
///
/// Blank instructions (`Some("")` or whitespace only) are treated like `None`,
/// so the prompt never ends with a dangling separator.
fn get_system_message(instructions: &Option<String>) -> String {
    let mut message =
        "You are a helpful assistant that extracts structured information from text. \
         Your task is to analyze the input text and output valid JSON that matches the specified schema. \
         Be precise and only include information that is explicitly stated in the text. \
         Output only the JSON without any additional messages or explanations."
            .to_string();

    // Only append instructions that actually carry content; the text itself is
    // kept untrimmed so non-blank input is forwarded exactly as provided.
    let custom_instructions = instructions
        .as_deref()
        .filter(|text| !text.trim().is_empty());
    if let Some(custom_instructions) = custom_instructions {
        message.push_str("\n\n");
        message.push_str(custom_instructions);
    }

    message
}
44+
45+
impl Executor {
46+
async fn new(spec: Spec) -> Result<Self> {
47+
let model = mistralrs::TextModelBuilder::new(spec.model.model_id)
48+
.with_isq(spec.model.isq_type)
49+
.with_logging()
50+
.with_paged_attn(|| mistralrs::PagedAttentionMetaBuilder::default().build())?
51+
.build()
52+
.await?;
53+
let request_base = mistralrs::RequestBuilder::new()
54+
.set_constraint(mistralrs::Constraint::JsonSchema(serde_json::to_value(
55+
spec.output_type.to_json_schema(),
56+
)?))
57+
.set_deterministic_sampler()
58+
.add_message(
59+
TextMessageRole::System,
60+
get_system_message(&spec.instructions),
61+
);
62+
Ok(Self {
63+
model,
64+
output_type: spec.output_type,
65+
request_base,
66+
})
67+
}
68+
}
69+
70+
#[async_trait]
71+
impl SimpleFunctionExecutor for Executor {
72+
fn behavior_version(&self) -> Option<u32> {
73+
Some(1)
74+
}
75+
76+
fn enable_cache(&self) -> bool {
77+
true
78+
}
79+
80+
async fn evaluate(&self, input: Vec<Value>) -> Result<Value> {
81+
let text = input.iter().next().unwrap().as_str()?;
82+
let request = self
83+
.request_base
84+
.clone()
85+
.add_message(TextMessageRole::User, text);
86+
let response = self.model.send_chat_request(request).await?;
87+
let response_text = response.choices[0]
88+
.message
89+
.content
90+
.as_ref()
91+
.ok_or_else(|| anyhow!("No content in response"))?;
92+
let json_value: serde_json::Value = serde_json::from_str(response_text)?;
93+
let value = Value::from_json(json_value, &self.output_type.typ)?;
94+
Ok(value)
95+
}
96+
}
97+
98+
pub struct Factory;
99+
100+
#[async_trait]
101+
impl SimpleFunctionFactoryBase for Factory {
102+
type Spec = Spec;
103+
104+
fn name(&self) -> &str {
105+
"ExtractByMistral"
106+
}
107+
108+
fn get_output_schema(
109+
&self,
110+
spec: &Spec,
111+
input_schema: &Vec<OpArgSchema>,
112+
_context: &FlowInstanceContext,
113+
) -> Result<EnrichedValueType> {
114+
match &expect_input_1(input_schema)?.value_type.typ {
115+
ValueType::Basic(BasicValueType::Str) => {}
116+
t => {
117+
api_bail!("Expect String as input type, got {}", t)
118+
}
119+
}
120+
Ok(spec.output_type.clone())
121+
}
122+
123+
async fn build_executor(
124+
self: Arc<Self>,
125+
spec: Spec,
126+
_input_schema: Vec<OpArgSchema>,
127+
_context: Arc<FlowInstanceContext>,
128+
) -> Result<Box<dyn SimpleFunctionExecutor>> {
129+
Ok(Box::new(Executor::new(spec).await?))
130+
}
131+
}

src/ops/functions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
pub mod extract_by_mistral;
12
pub mod split_recursively;

src/ops/registration.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::sync::{Arc, LazyLock, RwLock, RwLockReadGuard};
88
fn register_executor_factories(registry: &mut ExecutorFactoryRegistry) -> Result<()> {
99
sources::local_file::Factory.register(registry)?;
1010
functions::split_recursively::Factory.register(registry)?;
11+
functions::extract_by_mistral::Factory.register(registry)?;
1112
Arc::new(storages::postgres::Factory::default()).register(registry)?;
1213

1314
Ok(())

0 commit comments

Comments
 (0)