@@ -1,10 +1,11 @@
-use crate::api_bail;
+use crate::prelude::*;
+use base64::prelude::*;
 
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
-use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
     config::OpenAIConfig,
+    error::OpenAIError,
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -14,8 +15,6 @@ use async_openai::{
         ResponseFormat, ResponseFormatJsonSchema,
     },
 };
-use async_trait::async_trait;
-use base64::prelude::*;
 use phf::phf_map;
 
 static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
@@ -62,77 +61,99 @@ impl Client {
     }
 }
 
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let mut messages = Vec::new();
-
-        // Add system prompt if provided
-        if let Some(system) = request.system_prompt {
-            messages.push(ChatCompletionRequestMessage::System(
-                ChatCompletionRequestSystemMessage {
-                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
-                    ..Default::default()
-                },
-            ));
+impl utils::retryable::IsRetryable for OpenAIError {
+    fn is_retryable(&self) -> bool {
+        match self {
+            OpenAIError::Reqwest(e) => e.is_retryable(),
+            _ => false,
         }
+    }
+}
 
-        // Add user message
-        let user_message_content = match request.image {
-            Some(img_bytes) => {
-                let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
-                let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
-                let image_url = format!("data:{mime_type};base64,{base64_image}");
-                ChatCompletionRequestUserMessageContent::Array(vec![
-                    ChatCompletionRequestUserMessageContentPart::Text(
-                        ChatCompletionRequestMessageContentPartText {
-                            text: request.user_prompt.into_owned(),
-                        },
-                    ),
-                    ChatCompletionRequestUserMessageContentPart::ImageUrl(
-                        ChatCompletionRequestMessageContentPartImage {
-                            image_url: async_openai::types::ImageUrl {
-                                url: image_url,
-                                detail: Some(ImageDetail::Auto),
-                            },
-                        },
-                    ),
-                ])
-            }
-            None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
-        };
-        messages.push(ChatCompletionRequestMessage::User(
-            ChatCompletionRequestUserMessage {
-                content: user_message_content,
+fn create_llm_generation_request(
+    request: &super::LlmGenerateRequest,
+) -> Result<CreateChatCompletionRequest> {
+    let mut messages = Vec::new();
+
+    // Add system prompt if provided
+    if let Some(system) = &request.system_prompt {
+        messages.push(ChatCompletionRequestMessage::System(
+            ChatCompletionRequestSystemMessage {
+                content: ChatCompletionRequestSystemMessageContent::Text(system.to_string()),
                 ..Default::default()
             },
         ));
+    }
 
-        // Create the chat completion request
-        let request = CreateChatCompletionRequest {
-            model: request.model.to_string(),
-            messages,
-            response_format: match request.output_format {
-                Some(super::OutputFormat::JsonSchema { name, schema }) => {
-                    Some(ResponseFormat::JsonSchema {
-                        json_schema: ResponseFormatJsonSchema {
-                            name: name.into_owned(),
-                            description: None,
-                            schema: Some(serde_json::to_value(&schema)?),
-                            strict: Some(true),
+    // Add user message
+    let user_message_content = match &request.image {
+        Some(img_bytes) => {
+            let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
+            let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
+            let image_url = format!("data:{mime_type};base64,{base64_image}");
+            ChatCompletionRequestUserMessageContent::Array(vec![
+                ChatCompletionRequestUserMessageContentPart::Text(
+                    ChatCompletionRequestMessageContentPartText {
+                        text: request.user_prompt.to_string(),
+                    },
+                ),
+                ChatCompletionRequestUserMessageContentPart::ImageUrl(
+                    ChatCompletionRequestMessageContentPartImage {
+                        image_url: async_openai::types::ImageUrl {
+                            url: image_url,
+                            detail: Some(ImageDetail::Auto),
                         },
-                    })
-                }
-                None => None,
-            },
+                    },
+                ),
+            ])
+        }
+        None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.to_string()),
+    };
+    messages.push(ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: user_message_content,
             ..Default::default()
-        };
+        },
+    ));
+    // Create the chat completion request
+    let request = CreateChatCompletionRequest {
+        model: request.model.to_string(),
+        messages,
+        response_format: match &request.output_format {
+            Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                Some(ResponseFormat::JsonSchema {
+                    json_schema: ResponseFormatJsonSchema {
+                        name: name.to_string(),
+                        description: None,
+                        schema: Some(serde_json::to_value(&schema)?),
+                        strict: Some(true),
+                    },
+                })
+            }
+            None => None,
+        },
+        ..Default::default()
+    };
 
-        // Send request and get response
-        let response = self.client.chat().create(request).await?;
+    Ok(request)
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let request = &request;
+        let response = retryable::run(
+            || async {
+                let req = create_llm_generation_request(request)?;
+                let response = self.client.chat().create(req).await?;
+                retryable::Ok(response)
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
 
         // Extract the response text from the first choice
         let text = response
@@ -161,16 +182,21 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let response = self
-            .client
-            .embeddings()
-            .create(CreateEmbeddingRequest {
-                model: request.model.to_string(),
-                input: EmbeddingInput::String(request.text.to_string()),
-                dimensions: request.output_dimension,
-                ..Default::default()
-            })
-            .await?;
+        let response = retryable::run(
+            || async {
+                self.client
+                    .embeddings()
+                    .create(CreateEmbeddingRequest {
+                        model: request.model.to_string(),
+                        input: EmbeddingInput::String(request.text.to_string()),
+                        dimensions: request.output_dimension,
+                        ..Default::default()
+                    })
+                    .await
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
         Ok(super::LlmEmbeddingResponse {
             embedding: response
                 .data
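
For readers unfamiliar with the pattern, here is a minimal sketch of the kind of retry helper this commit wires the OpenAI client into. The real `utils::retryable` module is internal to the crate and not shown in this diff, so the trait shape, the `RetryOptions` fields, and the exponential-backoff policy below are illustrative assumptions, not the crate's actual API; the diff's `retryable::Ok(...)` also suggests the real helper uses its own result type, which this sketch simplifies to a plain `Result`.

```rust
// Hypothetical, simplified stand-in for the crate-internal `utils::retryable`.
use std::time::Duration;

/// Classifies errors as transient (worth retrying) or permanent, mirroring
/// the `IsRetryable` impl this commit adds for `OpenAIError`: only
/// transport-level `Reqwest` failures are considered retryable.
trait IsRetryable {
    fn is_retryable(&self) -> bool;
}

/// Assumed knobs: a bounded attempt count with exponential backoff.
struct RetryOptions {
    max_attempts: u32,
    initial_backoff: Duration,
}

impl Default for RetryOptions {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
        }
    }
}

/// Re-invokes `f` until it succeeds, returns a non-retryable error, or
/// exhausts the attempt budget, doubling the sleep between attempts.
async fn run<T, E, F, Fut>(f: F, options: &RetryOptions) -> Result<T, E>
where
    E: IsRetryable,
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut backoff = options.initial_backoff;
    let mut attempt = 1;
    loop {
        match f().await {
            Ok(value) => return Ok(value),
            Err(e) if e.is_retryable() && attempt < options.max_attempts => {
                // Transient failure: back off, then rebuild and resend.
                tokio::time::sleep(backoff).await;
                backoff *= 2;
                attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
```

This also explains why the commit factors `create_llm_generation_request` out of `generate`: `CreateChatCompletionRequest` is consumed by `.create(req)`, so each retry attempt must construct a fresh request value, and building it inside the closure keeps every attempt independent.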