1+ use std:: borrow:: Cow ;
12use std:: sync:: Arc ;
23
3- use anyhow:: anyhow;
4- use mistralrs:: { self , TextMessageRole } ;
4+ use schemars:: schema:: SchemaObject ;
55use serde:: Serialize ;
66
77use crate :: base:: json_schema:: ToJsonSchema ;
8+ use crate :: llm:: { LlmClient , LlmGenerateRequest , LlmSpec , OutputFormat } ;
89use crate :: ops:: sdk:: * ;
910
10- #[ derive( Debug , Clone , Serialize , Deserialize ) ]
11- pub struct MistralModelSpec {
12- model_id : String ,
13- isq_type : mistralrs:: IsqType ,
14- }
15-
1611#[ derive( Debug , Clone , Serialize , Deserialize ) ]
1712pub struct Spec {
18- model : MistralModelSpec ,
13+ llm_spec : LlmSpec ,
1914 output_type : EnrichedValueType ,
20- instructions : Option < String > ,
15+ instruction : Option < String > ,
2116}
2217
2318struct Executor {
24- model : mistralrs:: Model ,
19+ client : LlmClient ,
20+ output_json_schema : SchemaObject ,
2521 output_type : EnrichedValueType ,
26- request_base : mistralrs :: RequestBuilder ,
22+ system_prompt : String ,
2723}
2824
29- fn get_system_message ( instructions : & Option < String > ) -> String {
25+ fn get_system_prompt ( instructions : & Option < String > ) -> String {
3026 let mut message =
3127 "You are a helpful assistant that extracts structured information from text. \
3228 Your task is to analyze the input text and output valid JSON that matches the specified schema. \
@@ -44,24 +40,11 @@ Output only the JSON without any additional messages or explanations."
4440
4541impl Executor {
4642 async fn new ( spec : Spec ) -> Result < Self > {
47- let model = mistralrs:: TextModelBuilder :: new ( spec. model . model_id )
48- . with_isq ( spec. model . isq_type )
49- . with_paged_attn ( || mistralrs:: PagedAttentionMetaBuilder :: default ( ) . build ( ) ) ?
50- . build ( )
51- . await ?;
52- let request_base = mistralrs:: RequestBuilder :: new ( )
53- . set_constraint ( mistralrs:: Constraint :: JsonSchema ( serde_json:: to_value (
54- spec. output_type . to_json_schema ( ) ,
55- ) ?) )
56- . set_deterministic_sampler ( )
57- . add_message (
58- TextMessageRole :: System ,
59- get_system_message ( & spec. instructions ) ,
60- ) ;
6143 Ok ( Self {
62- model,
44+ client : LlmClient :: new ( spec. llm_spec ) . await ?,
45+ output_json_schema : spec. output_type . to_json_schema ( ) ,
6346 output_type : spec. output_type ,
64- request_base ,
47+ system_prompt : get_system_prompt ( & spec . instruction ) ,
6548 } )
6649 }
6750}
@@ -78,17 +61,15 @@ impl SimpleFunctionExecutor for Executor {
7861
7962 async fn evaluate ( & self , input : Vec < Value > ) -> Result < Value > {
8063 let text = input. iter ( ) . next ( ) . unwrap ( ) . as_str ( ) ?;
81- let request = self
82- . request_base
83- . clone ( )
84- . add_message ( TextMessageRole :: User , text) ;
85- let response = self . model . send_chat_request ( request) . await ?;
86- let response_text = response. choices [ 0 ]
87- . message
88- . content
89- . as_ref ( )
90- . ok_or_else ( || anyhow ! ( "No content in response" ) ) ?;
91- let json_value: serde_json:: Value = serde_json:: from_str ( response_text) ?;
64+ let req = LlmGenerateRequest {
65+ system_prompt : Some ( Cow :: Borrowed ( & self . system_prompt ) ) ,
66+ user_prompt : Cow :: Borrowed ( text) ,
67+ output_format : Some ( OutputFormat :: JsonSchema ( Cow :: Borrowed (
68+ & self . output_json_schema ,
69+ ) ) ) ,
70+ } ;
71+ let res = self . client . generate ( req) . await ?;
72+ let json_value: serde_json:: Value = serde_json:: from_str ( res. text . as_str ( ) ) ?;
9273 let value = Value :: from_json ( json_value, & self . output_type . typ ) ?;
9374 Ok ( value)
9475 }
@@ -101,7 +82,7 @@ impl SimpleFunctionFactoryBase for Factory {
10182 type Spec = Spec ;
10283
10384 fn name ( & self ) -> & str {
104- "ExtractByMistral "
85+ "ExtractByLlm "
10586 }
10687
10788 fn get_output_schema (
0 commit comments