Skip to content

Commit 36acf22

Browse files
committed
add better error checks, not yet impl [skip ci]
1 parent 8bd8f1b commit 36acf22

File tree

2 files changed

+137
-12
lines changed

2 files changed

+137
-12
lines changed

executor/src/executors/errors.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#![allow(unused)]
2+
3+
use crate::{Model, ModelProvider};
4+
use rig::completion::{CompletionError, PromptError};
5+
6+
#[derive(Debug, thiserror::Error, serde::Serialize, serde::Deserialize)]
7+
pub enum DriaExecutorError {
8+
#[error("Model {0} is not a valid model.")]
9+
InvalidModel(String),
10+
#[error("Model {0} not found in your configuration.")]
11+
ModelNotSupported(Model),
12+
#[error("Provider {0} not found in your configuration")]
13+
ProviderNotSupported(ModelProvider),
14+
15+
/// A generic error that wraps a [`rig::completion::PromptError`] in string form.
16+
#[error("Rig error: {0}")]
17+
RigError(String),
18+
/// A sub-type of `PromptError` that succesfully parses the error from the provider.
19+
///
20+
/// It is parsed from `PrompError(ProviderError(String))`.
21+
#[error("{provider} error ({code}): {message}")]
22+
ProviderError {
23+
/// Not necessarily an HTTP status code, but a code that the provider uses to identify the error.
24+
///
25+
/// For example, OpenAI uses a string code like "invalid_request_error".
26+
code: String,
27+
/// The error message returned by the provider.
28+
///
29+
/// May contain additional information about the error.
30+
message: String,
31+
/// The provider that returned the error.
32+
///
33+
/// Do we need it?
34+
provider: ModelProvider,
35+
},
36+
}
37+
38+
/// Maps a [`PromptError`] to a [`DriaExecutorError`] with respect to the given provider.
39+
pub fn map_prompt_error(provider: &ModelProvider, err: PromptError) -> DriaExecutorError {
40+
if let PromptError::CompletionError(CompletionError::ProviderError(err_inner)) = &err {
41+
// all the body's below have an `error` field
42+
#[derive(Clone, serde::Deserialize)]
43+
struct ErrorObject<T> {
44+
error: T,
45+
}
46+
47+
match provider {
48+
ModelProvider::Gemini => {
49+
/// A Gemini API error object.
50+
///
51+
/// See their Go [client for reference](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
52+
#[derive(Clone, serde::Deserialize)]
53+
pub struct GeminiError {
54+
code: u32,
55+
message: String,
56+
status: String,
57+
}
58+
59+
serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
60+
|ErrorObject {
61+
error: gemini_error,
62+
}| DriaExecutorError::ProviderError {
63+
code: format!("{} ({})", gemini_error.code, gemini_error.status),
64+
message: gemini_error.message,
65+
provider: ModelProvider::Gemini,
66+
},
67+
)
68+
}
69+
ModelProvider::OpenAI => {
70+
/// An OpenAI error object.
71+
///
72+
/// See their Go [client for reference](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
73+
#[derive(Clone, serde::Deserialize)]
74+
pub struct OpenAIError {
75+
code: String,
76+
message: String,
77+
}
78+
79+
serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
80+
|ErrorObject {
81+
error: openai_error,
82+
}| DriaExecutorError::ProviderError {
83+
code: openai_error.code,
84+
message: openai_error.message,
85+
provider: ModelProvider::OpenAI,
86+
},
87+
)
88+
}
89+
ModelProvider::OpenRouter => {
90+
/// An OpenRouter error object.
91+
///
92+
/// See [their documentation](https://openrouter.ai/docs/api-reference/errors).
93+
#[derive(Clone, serde::Deserialize)]
94+
pub struct OpenRouterError {
95+
code: u32,
96+
message: String,
97+
}
98+
99+
serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
100+
|ErrorObject {
101+
error: openrouter_error,
102+
}| {
103+
DriaExecutorError::ProviderError {
104+
code: openrouter_error.code.to_string(),
105+
message: openrouter_error.message,
106+
provider: ModelProvider::OpenRouter,
107+
}
108+
},
109+
)
110+
}
111+
ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner).map(
112+
|ErrorObject {
113+
error: ollama_error,
114+
}| {
115+
DriaExecutorError::ProviderError {
116+
code: "ollama".to_string(),
117+
message: ollama_error,
118+
provider: ModelProvider::Ollama,
119+
}
120+
},
121+
),
122+
}
123+
// if we couldn't parse it, just return a generic prompt error
124+
.unwrap_or(DriaExecutorError::RigError(err.to_string()))
125+
} else {
126+
// not a provider error, fallback to generic prompt error
127+
DriaExecutorError::RigError(err.to_string())
128+
}
129+
}

executor/src/executors/mod.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
use crate::ModelProvider;
12
use rig::completion::PromptError;
23
use std::collections::HashSet;
34

5+
mod errors;
6+
pub use errors::DriaExecutorError;
7+
48
mod ollama;
59
use ollama::OllamaClient;
610

@@ -13,18 +17,6 @@ use gemini::GeminiClient;
1317
mod openrouter;
1418
use openrouter::OpenRouterClient;
1519

16-
use crate::{Model, ModelProvider};
17-
18-
#[derive(Debug, thiserror::Error)]
19-
pub enum DriaExecutorError {
20-
#[error("Model {0} is not a valid model.")]
21-
InvalidModel(String),
22-
#[error("Model {0} not found in your configuration.")]
23-
ModelNotSupported(Model),
24-
#[error("Provider {0} not found in your configuration")]
25-
ProviderNotSupported(ModelProvider),
26-
}
27-
2820
/// A wrapper enum for all model providers.
2921
#[derive(Clone)]
3022
pub enum DriaExecutor {
@@ -49,9 +41,13 @@ impl DriaExecutor {
4941
pub async fn execute(&self, task: crate::TaskBody) -> Result<String, PromptError> {
5042
match self {
5143
DriaExecutor::Ollama(provider) => provider.execute(task).await,
44+
// .map_err(|e| map_prompt_error(&ModelProvider::Ollama, e)),
5245
DriaExecutor::OpenAI(provider) => provider.execute(task).await,
46+
// .map_err(|e| map_prompt_error(&ModelProvider::OpenAI, e)),
5347
DriaExecutor::Gemini(provider) => provider.execute(task).await,
48+
// .map_err(|e| map_prompt_error(&ModelProvider::Gemini, e)),
5449
DriaExecutor::OpenRouter(provider) => provider.execute(task).await,
50+
// .map_err(|e| map_prompt_error(&ModelProvider::OpenRouter, e)),
5551
}
5652
}
5753

0 commit comments

Comments
 (0)