Skip to content

Commit 138fa6b

Browse files
committed
add thiserror impls
1 parent f66178a commit 138fa6b

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

compute/src/reqres/task.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,10 @@ impl TaskResponder {
134134
// prepare error payload
135135
let error_payload = TaskResponsePayload {
136136
result: None,
137-
error: Some(map_prompt_error(task_metadata.model.provider(), err)),
137+
error: Some(map_prompt_error_to_task_error(
138+
task_metadata.model.provider(),
139+
err,
140+
)),
138141
row_id: task_output.row_id,
139142
file_id: task_metadata.file_id,
140143
task_id: task_metadata.task_id,
@@ -160,8 +163,8 @@ impl TaskResponder {
160163
}
161164
}
162165

163-
/// Maps a [`PromptError`] to a [`DriaExecutorError`] with respect to the given provider.
164-
fn map_prompt_error(provider: ModelProvider, err: PromptError) -> TaskError {
166+
/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
167+
fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) -> TaskError {
165168
if let PromptError::CompletionError(CompletionError::ProviderError(err_inner)) = &err {
166169
/// A wrapper for `{ error: T }` to match the provider error format.
167170
#[derive(Clone, serde::Deserialize)]
@@ -254,7 +257,7 @@ fn map_prompt_error(provider: ModelProvider, err: PromptError) -> TaskError {
254257
),
255258
}
256259
// if we couldn't parse it, just return a generic prompt error
257-
.unwrap_or(TaskError::Other(err.to_string()))
260+
.unwrap_or(TaskError::ExecutorError(err_inner.clone()))
258261
} else {
259262
// not a provider error, fallback to generic prompt error
260263
TaskError::Other(err.to_string())

utils/src/payloads/tasks.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ pub struct TaskRequestPayload<T> {
5757
pub input: T,
5858
}
5959

60-
#[derive(Debug, Clone, Serialize, Deserialize)]
60+
#[derive(Debug, Clone, Serialize, Deserialize, thiserror::Error)]
6161
pub enum TaskError {
6262
/// A parse error occurred while parsing the task request or response.
63+
#[error("Parse error: {0}")]
6364
ParseError(String),
6465
/// An error returned from the model provider.
66+
#[error("Provider error: {code} - {message} (source: {provider})")]
6567
ProviderError {
6668
/// Not necessarily an HTTP status code, but a code that the provider uses to identify the error.
6769
///
@@ -76,13 +78,18 @@ pub enum TaskError {
7678
/// Can be a provider name, or RPC etc.
7779
provider: String,
7880
},
81+
/// Any other executor error that is not a provider error.
82+
#[error("Executor error: {0}")]
83+
ExecutorError(String),
7984
/// The task request had failed for some network reason.
85+
#[error("Outbound request error: {code} - {message}")]
8086
OutboundRequestError {
8187
code: String,
8288
/// The error message returned by the network.
8389
message: String,
8490
},
85-
/// An error that returned by executor.
91+
/// Any other error
92+
#[error("Other error: {0}")]
8693
Other(String),
8794
}
8895

0 commit comments

Comments
 (0)