Skip to content

Commit 92fe260

Browse files
committed
add http error handle as well
1 parent 138fa6b commit 92fe260

File tree

3 files changed

+99
-92
lines changed

3 files changed

+99
-92
lines changed

compute/src/reqres/task.rs

Lines changed: 94 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -165,101 +165,108 @@ impl TaskResponder {
165165

166166
/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
167167
fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) -> TaskError {
168-
if let PromptError::CompletionError(CompletionError::ProviderError(err_inner)) = &err {
169-
/// A wrapper for `{ error: T }` to match the provider error format.
170-
#[derive(Clone, serde::Deserialize)]
171-
struct ErrorObject<T> {
172-
error: T,
173-
}
168+
match &err {
169+
// if the error is a provider error, we can try to parse it
170+
PromptError::CompletionError(CompletionError::ProviderError(err_inner)) => {
171+
/// A wrapper for `{ error: T }` to match the provider error format.
172+
#[derive(Clone, serde::Deserialize)]
173+
struct ErrorObject<T> {
174+
error: T,
175+
}
176+
177+
match provider {
178+
ModelProvider::Gemini => {
179+
/// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
180+
#[derive(Clone, serde::Deserialize)]
181+
pub struct GeminiError {
182+
code: u32,
183+
message: String,
184+
status: String,
185+
}
174186

175-
match provider {
176-
ModelProvider::Gemini => {
177-
/// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
178-
#[derive(Clone, serde::Deserialize)]
179-
pub struct GeminiError {
180-
code: u32,
181-
message: String,
182-
status: String,
187+
serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
188+
|ErrorObject {
189+
error: gemini_error,
190+
}| TaskError::ProviderError {
191+
code: format!("{} ({})", gemini_error.code, gemini_error.status),
192+
message: gemini_error.message,
193+
provider: provider.to_string(),
194+
},
195+
)
183196
}
197+
ModelProvider::OpenAI => {
198+
/// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
199+
#[derive(Clone, serde::Deserialize)]
200+
pub struct OpenAIError {
201+
code: String,
202+
message: String,
203+
}
184204

185-
serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
186-
|ErrorObject {
187-
error: gemini_error,
188-
}| TaskError::ProviderError {
189-
code: format!("{} ({})", gemini_error.code, gemini_error.status),
190-
message: gemini_error.message,
191-
provider: provider.to_string(),
192-
},
193-
)
194-
}
195-
ModelProvider::OpenAI => {
196-
/// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
197-
#[derive(Clone, serde::Deserialize)]
198-
pub struct OpenAIError {
199-
code: String,
200-
message: String,
205+
serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
206+
|ErrorObject {
207+
error: openai_error,
208+
}| TaskError::ProviderError {
209+
code: openai_error.code,
210+
message: openai_error.message,
211+
provider: provider.to_string(),
212+
},
213+
)
201214
}
215+
ModelProvider::OpenRouter => {
216+
/// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
217+
#[derive(Clone, serde::Deserialize)]
218+
pub struct OpenRouterError {
219+
code: u32,
220+
message: String,
221+
}
202222

203-
serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
204-
|ErrorObject {
205-
error: openai_error,
206-
}| TaskError::ProviderError {
207-
code: openai_error.code,
208-
message: openai_error.message,
209-
provider: provider.to_string(),
210-
},
211-
)
212-
}
213-
ModelProvider::OpenRouter => {
214-
/// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
215-
#[derive(Clone, serde::Deserialize)]
216-
pub struct OpenRouterError {
217-
code: u32,
218-
message: String,
223+
serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
224+
|ErrorObject {
225+
error: openrouter_error,
226+
}| {
227+
TaskError::ProviderError {
228+
code: openrouter_error.code.to_string(),
229+
message: openrouter_error.message,
230+
provider: provider.to_string(),
231+
}
232+
},
233+
)
219234
}
235+
ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner)
236+
.map(
237+
// Ollama just returns a string error message
238+
|ErrorObject {
239+
error: ollama_error,
240+
}| {
241+
// based on the error message, we can come up with out own "dummy" codes
242+
let code = if ollama_error.contains("server busy, please try again.") {
243+
"server_busy"
244+
} else if ollama_error.contains("model requires more system memory") {
245+
"model_requires_more_memory"
246+
} else if ollama_error.contains("cudaMalloc failed: out of memory") {
247+
"cuda_malloc_failed"
248+
} else if ollama_error.contains("CUDA error: out of memory") {
249+
"cuda_oom"
250+
} else {
251+
"unknown"
252+
};
220253

221-
serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
222-
|ErrorObject {
223-
error: openrouter_error,
224-
}| {
225-
TaskError::ProviderError {
226-
code: openrouter_error.code.to_string(),
227-
message: openrouter_error.message,
228-
provider: provider.to_string(),
229-
}
230-
},
231-
)
254+
TaskError::ProviderError {
255+
code: code.to_string(),
256+
message: ollama_error,
257+
provider: provider.to_string(),
258+
}
259+
},
260+
),
232261
}
233-
ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner).map(
234-
// Ollama just returns a string error message
235-
|ErrorObject {
236-
error: ollama_error,
237-
}| {
238-
// based on the error message, we can come up with out own "dummy" codes
239-
let code = if ollama_error.contains("server busy, please try again.") {
240-
"server_busy"
241-
} else if ollama_error.contains("model requires more system memory") {
242-
"model_requires_more_memory"
243-
} else if ollama_error.contains("cudaMalloc failed: out of memory") {
244-
"cuda_malloc_failed"
245-
} else if ollama_error.contains("CUDA error: out of memory") {
246-
"cuda_oom"
247-
} else {
248-
"unknown"
249-
};
250-
251-
TaskError::ProviderError {
252-
code: code.to_string(),
253-
message: ollama_error,
254-
provider: provider.to_string(),
255-
}
256-
},
257-
),
262+
// if we couldn't parse it, just return a generic prompt error
263+
.unwrap_or(TaskError::ExecutorError(err_inner.clone()))
264+
}
265+
// if its a http error, we can try to parse it as well
266+
PromptError::CompletionError(CompletionError::HttpError(err_inner)) => {
267+
TaskError::HttpError(err_inner.to_string())
258268
}
259-
// if we couldn't parse it, just return a generic prompt error
260-
.unwrap_or(TaskError::ExecutorError(err_inner.clone()))
261-
} else {
262-
// not a provider error, fallback to generic prompt error
263-
TaskError::Other(err.to_string())
269+
// if it's not a completion error, we just return the error as is
270+
err => TaskError::Other(err.to_string()),
264271
}
265272
}

executor/src/executors/mod.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,9 @@ impl DriaExecutor {
3838
pub async fn execute(&self, task: crate::TaskBody) -> Result<String, PromptError> {
3939
match self {
4040
DriaExecutor::Ollama(provider) => provider.execute(task).await,
41-
// .map_err(|e| map_prompt_error(&ModelProvider::Ollama, e)),
4241
DriaExecutor::OpenAI(provider) => provider.execute(task).await,
43-
// .map_err(|e| map_prompt_error(&ModelProvider::OpenAI, e)),
4442
DriaExecutor::Gemini(provider) => provider.execute(task).await,
45-
// .map_err(|e| map_prompt_error(&ModelProvider::Gemini, e)),
4643
DriaExecutor::OpenRouter(provider) => provider.execute(task).await,
47-
// .map_err(|e| map_prompt_error(&ModelProvider::OpenRouter, e)),
4844
}
4945
}
5046

utils/src/payloads/tasks.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub enum TaskError {
6363
#[error("Parse error: {0}")]
6464
ParseError(String),
6565
/// An error returned from the model provider.
66-
#[error("Provider error: {code} - {message} (source: {provider})")]
66+
#[error("{provider} error ({code}): {message}")]
6767
ProviderError {
6868
/// Not necessarily an HTTP status code, but a code that the provider uses to identify the error.
6969
///
@@ -78,6 +78,10 @@ pub enum TaskError {
7878
/// Can be a provider name, or RPC etc.
7979
provider: String,
8080
},
81+
/// A network-related error from the client.
82+
#[error("HTTP error: {0}")]
83+
/// This is a generic HTTP error, not necessarily related to the provider.
84+
HttpError(String),
8185
/// Any other executor error that is not a provider error.
8286
#[error("Executor error: {0}")]
8387
ExecutorError(String),

0 commit comments

Comments
 (0)