diff --git a/src/execution/live_updater.rs b/src/execution/live_updater.rs
index e85c1ca72..485d8c4e7 100644
--- a/src/execution/live_updater.rs
+++ b/src/execution/live_updater.rs
@@ -164,7 +164,7 @@ impl SourceUpdateTask {
                         .next()
                         .await
                         .transpose()
-                        .map_err(retryable::Error::always_retryable)
+                        .map_err(retryable::Error::retryable)
                 },
                 &retry_options,
             )
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
index a29a9bce2..68ec64214 100644
--- a/src/llm/openai.rs
+++ b/src/llm/openai.rs
@@ -1,10 +1,11 @@
-use crate::api_bail;
+use crate::prelude::*;
+use base64::prelude::*;
 
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
-use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
     config::OpenAIConfig,
+    error::OpenAIError,
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -14,8 +15,6 @@ use async_openai::{
         ResponseFormat, ResponseFormatJsonSchema,
     },
 };
-use async_trait::async_trait;
-use base64::prelude::*;
 use phf::phf_map;
 
 static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
@@ -62,77 +61,99 @@ impl Client {
     }
 }
 
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let mut messages = Vec::new();
-
-        // Add system prompt if provided
-        if let Some(system) = request.system_prompt {
-            messages.push(ChatCompletionRequestMessage::System(
-                ChatCompletionRequestSystemMessage {
-                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
-                    ..Default::default()
-                },
-            ));
+impl utils::retryable::IsRetryable for OpenAIError {
+    fn is_retryable(&self) -> bool {
+        match self {
+            OpenAIError::Reqwest(e) => e.is_retryable(),
+            _ => false,
         }
+    }
+}
 
-        // Add user message
-        let user_message_content = match request.image {
-            Some(img_bytes) => {
-                let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
-                let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
-                let image_url = format!("data:{mime_type};base64,{base64_image}");
-                ChatCompletionRequestUserMessageContent::Array(vec![
-                    ChatCompletionRequestUserMessageContentPart::Text(
-                        ChatCompletionRequestMessageContentPartText {
-                            text: request.user_prompt.into_owned(),
-                        },
-                    ),
-                    ChatCompletionRequestUserMessageContentPart::ImageUrl(
-                        ChatCompletionRequestMessageContentPartImage {
-                            image_url: async_openai::types::ImageUrl {
-                                url: image_url,
-                                detail: Some(ImageDetail::Auto),
-                            },
-                        },
-                    ),
-                ])
-            }
-            None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
-        };
-        messages.push(ChatCompletionRequestMessage::User(
-            ChatCompletionRequestUserMessage {
-                content: user_message_content,
+fn create_llm_generation_request(
+    request: &super::LlmGenerateRequest,
+) -> Result<CreateChatCompletionRequest> {
+    let mut messages = Vec::new();
+
+    // Add system prompt if provided
+    if let Some(system) = &request.system_prompt {
+        messages.push(ChatCompletionRequestMessage::System(
+            ChatCompletionRequestSystemMessage {
+                content: ChatCompletionRequestSystemMessageContent::Text(system.to_string()),
                 ..Default::default()
             },
         ));
+    }
 
-        // Create the chat completion request
-        let request = CreateChatCompletionRequest {
-            model: request.model.to_string(),
-            messages,
-            response_format: match request.output_format {
-                Some(super::OutputFormat::JsonSchema { name, schema }) => {
-                    Some(ResponseFormat::JsonSchema {
-                        json_schema: ResponseFormatJsonSchema {
-                            name: name.into_owned(),
-                            description: None,
-                            schema: Some(serde_json::to_value(&schema)?),
-                            strict: Some(true),
+    // Add user message
+    let user_message_content = match &request.image {
+        Some(img_bytes) => {
+            let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
+            let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
+            let image_url = format!("data:{mime_type};base64,{base64_image}");
+            ChatCompletionRequestUserMessageContent::Array(vec![
+                ChatCompletionRequestUserMessageContentPart::Text(
+                    ChatCompletionRequestMessageContentPartText {
+                        text: request.user_prompt.to_string(),
+                    },
+                ),
+                ChatCompletionRequestUserMessageContentPart::ImageUrl(
+                    ChatCompletionRequestMessageContentPartImage {
+                        image_url: async_openai::types::ImageUrl {
+                            url: image_url,
+                            detail: Some(ImageDetail::Auto),
                         },
-                    })
-                }
-                None => None,
-            },
+                    },
+                ),
+            ])
+        }
+        None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.to_string()),
+    };
+    messages.push(ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: user_message_content,
             ..Default::default()
-        };
+        },
+    ));
+
+    // Create the chat completion request
+    let request = CreateChatCompletionRequest {
+        model: request.model.to_string(),
+        messages,
+        response_format: match &request.output_format {
+            Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                Some(ResponseFormat::JsonSchema {
+                    json_schema: ResponseFormatJsonSchema {
+                        name: name.to_string(),
+                        description: None,
+                        schema: Some(serde_json::to_value(&schema)?),
+                        strict: Some(true),
+                    },
+                })
+            }
+            None => None,
+        },
+        ..Default::default()
+    };
 
-        // Send request and get response
-        let response = self.client.chat().create(request).await?;
+    Ok(request)
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let request = &request;
+        let response = retryable::run(
+            || async {
+                let req = create_llm_generation_request(request)?;
+                let response = self.client.chat().create(req).await?;
+                retryable::Ok(response)
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
 
         // Extract the response text from the first choice
         let text = response
@@ -161,16 +182,21 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let response = self
-            .client
-            .embeddings()
-            .create(CreateEmbeddingRequest {
-                model: request.model.to_string(),
-                input: EmbeddingInput::String(request.text.to_string()),
-                dimensions: request.output_dimension,
-                ..Default::default()
-            })
-            .await?;
+        let response = retryable::run(
+            || async {
+                self.client
+                    .embeddings()
+                    .create(CreateEmbeddingRequest {
+                        model: request.model.to_string(),
+                        input: EmbeddingInput::String(request.text.to_string()),
+                        dimensions: request.output_dimension,
+                        ..Default::default()
+                    })
+                    .await
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
         Ok(super::LlmEmbeddingResponse {
             embedding: response
                 .data
diff --git a/src/ops/targets/neo4j.rs b/src/ops/targets/neo4j.rs
index 49517dd9c..a5e3559e5 100644
--- a/src/ops/targets/neo4j.rs
+++ b/src/ops/targets/neo4j.rs
@@ -1070,8 +1070,7 @@ impl TargetFactoryBase for Factory {
                 },
                 &retry_options,
             )
-            .await
-            .map_err(Into::<anyhow::Error>::into)?;
+            .await?;
         }
         Ok(())
     }
diff --git a/src/utils/retryable.rs b/src/utils/retryable.rs
index 51fe519a0..3711bf1db 100644
--- a/src/utils/retryable.rs
+++ b/src/utils/retryable.rs
@@ -9,8 +9,8 @@ pub trait IsRetryable {
 }
 
 pub struct Error {
-    error: anyhow::Error,
-    is_retryable: bool,
+    pub error: anyhow::Error,
+    pub is_retryable: bool,
 }
 
 pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(10 * 60);
@@ -40,12 +40,19 @@ impl IsRetryable for reqwest::Error {
 }
 
 impl Error {
-    pub fn always_retryable(error: anyhow::Error) -> Self {
+    pub fn retryable<E: Into<anyhow::Error>>(error: E) -> Self {
         Self {
-            error,
+            error: error.into(),
             is_retryable: true,
         }
     }
+
+    pub fn not_retryable<E: Into<anyhow::Error>>(error: E) -> Self {
+        Self {
+            error: error.into(),
+            is_retryable: false,
+        }
+    }
 }
 
 impl From<anyhow::Error> for Error {
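For reviewers, here is a minimal sketch of how a call site can use the renamed constructors to classify errors explicitly. The `fetch_and_parse` helper and its error split are hypothetical, and the exact bounds of `retryable::run` and `RetryOptions` are assumed from the call sites in this diff rather than taken from the module itself:

```rust
use crate::prelude::*;
use crate::utils::retryable;

// Hypothetical call site: retries transient HTTP failures, but gives up
// immediately on JSON that will never parse, using the two constructors
// introduced in this diff.
async fn fetch_and_parse(url: &str) -> Result<serde_json::Value> {
    let value = retryable::run(
        || async {
            // Transport errors are worth retrying.
            let body = reqwest::get(url)
                .await
                .map_err(retryable::Error::retryable)?
                .text()
                .await
                .map_err(retryable::Error::retryable)?;
            // A parse failure is deterministic; retrying cannot help.
            let parsed: serde_json::Value =
                serde_json::from_str(&body).map_err(retryable::Error::not_retryable)?;
            retryable::Ok(parsed)
        },
        &retryable::RetryOptions::default(),
    )
    .await?;
    Ok(value)
}
```

Since `OpenAIError` now implements `IsRetryable`, the embeddings call site above can return its error unmapped and let the conversion into `retryable::Error` classify it, which is why that closure needs no `map_err` at all.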