Skip to content

Commit f66178a

Browse files
committed
better error reporting
1 parent 0abd136 commit f66178a

File tree

4 files changed

+151
-23
lines changed

4 files changed

+151
-23
lines changed

compute/src/reqres/task.rs

Lines changed: 117 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use colored::Colorize;
2-
use dkn_executor::TaskBody;
2+
use dkn_executor::{CompletionError, ModelProvider, PromptError, TaskBody};
33
use dkn_p2p::libp2p::request_response::ResponseChannel;
4-
use dkn_utils::payloads::{TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC};
4+
use dkn_utils::payloads::{
5+
TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats, TASK_RESULT_TOPIC,
6+
};
57
use dkn_utils::DriaMessage;
68
use eyre::{Context, Result};
79

@@ -25,27 +27,23 @@ impl TaskResponder {
2527
let task = compute_message
2628
.parse_payload::<TaskRequestPayload<serde_json::Value>>()
2729
.wrap_err("could not parse task request payload")?;
28-
let task_body = match serde_json::from_value::<TaskBody>(task.input)
29-
.wrap_err("could not parse task body")
30-
{
30+
let task_body = match serde_json::from_value::<TaskBody>(task.input) {
3131
Ok(task_body) => task_body,
3232
Err(err) => {
33-
let err_string = format!("{:#}", err);
3433
log::error!(
35-
"Task {}/{} failed due to parsing error: {}",
34+
"Task {}/{} failed due to parsing error: {err}",
3635
task.file_id,
3736
task.row_id,
38-
err_string
3937
);
4038

4139
// prepare error payload
4240
let error_payload = TaskResponsePayload {
4341
result: None,
44-
error: Some(err_string),
42+
error: Some(TaskError::ParseError(err.to_string())),
4543
row_id: task.row_id,
4644
file_id: task.file_id,
4745
task_id: task.task_id,
48-
model: Default::default(),
46+
model: "<n/a>".to_string(), // no model available due to parsing error
4947
stats: TaskStats::new(),
5048
};
5149

@@ -56,7 +54,8 @@ impl TaskResponder {
5654
let response = node.new_message(error_payload_str, TASK_RESULT_TOPIC);
5755
node.p2p.respond(response.into(), channel).await?;
5856

59-
return Err(err);
57+
// return with error
58+
eyre::bail!("could not parse task body: {err}")
6059
}
6160
};
6261

@@ -75,7 +74,7 @@ impl TaskResponder {
7574
let task_metadata = TaskWorkerMetadata {
7675
task_id: task.task_id,
7776
file_id: task.file_id,
78-
model_name: task_body.model.to_string(),
77+
model: task_body.model,
7978
channel,
8079
};
8180
let task_input = TaskWorkerInput {
@@ -112,7 +111,7 @@ impl TaskResponder {
112111
file_id: task_metadata.file_id,
113112
task_id: task_metadata.task_id,
114113
row_id: task_output.row_id,
115-
model: task_metadata.model_name,
114+
model: task_metadata.model.to_string(),
116115
stats: task_output
117116
.stats
118117
.record_published_at()
@@ -125,22 +124,21 @@ impl TaskResponder {
125124
}
126125
Err(err) => {
127126
// use pretty display string for error logging with causes
128-
let err_string = format!("{:#}", err);
129127
log::error!(
130-
"Task {}/{} failed: {}",
128+
"Task {}/{} failed: {:#}",
131129
task_metadata.file_id,
132130
task_output.row_id,
133-
err_string
131+
err
134132
);
135133

136134
// prepare error payload
137135
let error_payload = TaskResponsePayload {
138136
result: None,
139-
error: Some(err_string),
137+
error: Some(map_prompt_error(task_metadata.model.provider(), err)),
140138
row_id: task_output.row_id,
141139
file_id: task_metadata.file_id,
142140
task_id: task_metadata.task_id,
143-
model: task_metadata.model_name,
141+
model: task_metadata.model.to_string(),
144142
stats: task_output
145143
.stats
146144
.record_published_at()
@@ -161,3 +159,104 @@ impl TaskResponder {
161159
Ok(())
162160
}
163161
}
162+
163+
/// Maps a [`PromptError`] to a [`TaskError`] with respect to the given provider.
164+
fn map_prompt_error(provider: ModelProvider, err: PromptError) -> TaskError {
165+
if let PromptError::CompletionError(CompletionError::ProviderError(err_inner)) = &err {
166+
/// A wrapper for `{ error: T }` to match the provider error format.
167+
#[derive(Clone, serde::Deserialize)]
168+
struct ErrorObject<T> {
169+
error: T,
170+
}
171+
172+
match provider {
173+
ModelProvider::Gemini => {
174+
/// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
175+
#[derive(Clone, serde::Deserialize)]
176+
pub struct GeminiError {
177+
code: u32,
178+
message: String,
179+
status: String,
180+
}
181+
182+
serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
183+
|ErrorObject {
184+
error: gemini_error,
185+
}| TaskError::ProviderError {
186+
code: format!("{} ({})", gemini_error.code, gemini_error.status),
187+
message: gemini_error.message,
188+
provider: provider.to_string(),
189+
},
190+
)
191+
}
192+
ModelProvider::OpenAI => {
193+
/// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
194+
#[derive(Clone, serde::Deserialize)]
195+
pub struct OpenAIError {
196+
code: String,
197+
message: String,
198+
}
199+
200+
serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
201+
|ErrorObject {
202+
error: openai_error,
203+
}| TaskError::ProviderError {
204+
code: openai_error.code,
205+
message: openai_error.message,
206+
provider: provider.to_string(),
207+
},
208+
)
209+
}
210+
ModelProvider::OpenRouter => {
211+
/// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
212+
#[derive(Clone, serde::Deserialize)]
213+
pub struct OpenRouterError {
214+
code: u32,
215+
message: String,
216+
}
217+
218+
serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
219+
|ErrorObject {
220+
error: openrouter_error,
221+
}| {
222+
TaskError::ProviderError {
223+
code: openrouter_error.code.to_string(),
224+
message: openrouter_error.message,
225+
provider: provider.to_string(),
226+
}
227+
},
228+
)
229+
}
230+
ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner).map(
231+
// Ollama just returns a string error message
232+
|ErrorObject {
233+
error: ollama_error,
234+
}| {
235+
// based on the error message, we can come up with our own "dummy" codes
236+
let code = if ollama_error.contains("server busy, please try again.") {
237+
"server_busy"
238+
} else if ollama_error.contains("model requires more system memory") {
239+
"model_requires_more_memory"
240+
} else if ollama_error.contains("cudaMalloc failed: out of memory") {
241+
"cuda_malloc_failed"
242+
} else if ollama_error.contains("CUDA error: out of memory") {
243+
"cuda_oom"
244+
} else {
245+
"unknown"
246+
};
247+
248+
TaskError::ProviderError {
249+
code: code.to_string(),
250+
message: ollama_error,
251+
provider: provider.to_string(),
252+
}
253+
},
254+
),
255+
}
256+
// if we couldn't parse it, just return a generic prompt error
257+
.unwrap_or(TaskError::Other(err.to_string()))
258+
} else {
259+
// not a provider error, fallback to generic prompt error
260+
TaskError::Other(err.to_string())
261+
}
262+
}

compute/src/workers/task.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use colored::Colorize;
2-
use dkn_executor::{DriaExecutor, TaskBody};
2+
use dkn_executor::{DriaExecutor, Model, TaskBody};
33
use dkn_p2p::libp2p::request_response::ResponseChannel;
44
use dkn_utils::payloads::TaskStats;
55
use tokio::sync::mpsc;
@@ -9,7 +9,7 @@ use uuid::Uuid;
99
///
1010
/// This is put into a map before execution, and then removed after the task is done.
1111
pub struct TaskWorkerMetadata {
12-
pub model_name: String,
12+
pub model: Model,
1313
pub task_id: String,
1414
pub file_id: Uuid,
1515
/// If for any reason this object is dropped before `channel` is responded to,

utils/src/payloads/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mod tasks;
2-
pub use tasks::{TaskRequestPayload, TaskResponsePayload, TaskStats};
2+
pub use tasks::{TaskError, TaskRequestPayload, TaskResponsePayload, TaskStats};
33
pub use tasks::{TASK_REQUEST_TOPIC, TASK_RESULT_TOPIC};
44

55
mod heartbeat;

utils/src/payloads/tasks.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ pub struct TaskResponsePayload {
3333
/// If this is `None`, the task failed, and you should check the `error` field.
3434
#[serde(skip_serializing_if = "Option::is_none")]
3535
pub result: Option<String>,
36-
/// An error message, if any.
36+
/// An error, if any.
3737
///
3838
/// If this is `Some`, you can ignore the `result` field.
3939
#[serde(skip_serializing_if = "Option::is_none")]
40-
pub error: Option<String>,
40+
pub error: Option<TaskError>,
4141
}
4242

4343
/// A generic task request, given by Dria.
@@ -57,6 +57,35 @@ pub struct TaskRequestPayload<T> {
5757
pub input: T,
5858
}
5959

60+
#[derive(Debug, Clone, Serialize, Deserialize)]
61+
pub enum TaskError {
62+
/// A parse error that occurred while parsing the task request or response.
63+
ParseError(String),
64+
/// An error returned from the model provider.
65+
ProviderError {
66+
/// Not necessarily an HTTP status code, but a code that the provider uses to identify the error.
67+
///
68+
/// For example, OpenAI uses a string code like "invalid_request_error".
69+
code: String,
70+
/// The error message returned by the provider.
71+
///
72+
/// May contain additional information about the error.
73+
message: String,
74+
/// The source of the error.
75+
///
76+
/// Can be a provider name, or RPC etc.
77+
provider: String,
78+
},
79+
/// The task request failed for some network reason.
80+
OutboundRequestError {
81+
code: String,
82+
/// The error message returned by the network.
83+
message: String,
84+
},
85+
/// An error returned by the executor.
86+
Other(String),
87+
}
88+
6089
/// Task stats for diagnostics.
6190
///
6291
/// Returning this as the payload helps to debug the errors received at client side, and latencies.

0 commit comments

Comments
 (0)