1
1
use colored:: Colorize ;
2
- use dkn_executor:: TaskBody ;
2
+ use dkn_executor:: { CompletionError , ModelProvider , PromptError , TaskBody } ;
3
3
use dkn_p2p:: libp2p:: request_response:: ResponseChannel ;
4
- use dkn_utils:: payloads:: { TaskRequestPayload , TaskResponsePayload , TaskStats , TASK_RESULT_TOPIC } ;
4
+ use dkn_utils:: payloads:: {
5
+ TaskError , TaskRequestPayload , TaskResponsePayload , TaskStats , TASK_RESULT_TOPIC ,
6
+ } ;
5
7
use dkn_utils:: DriaMessage ;
6
8
use eyre:: { Context , Result } ;
7
9
@@ -25,27 +27,23 @@ impl TaskResponder {
25
27
let task = compute_message
26
28
. parse_payload :: < TaskRequestPayload < serde_json:: Value > > ( )
27
29
. wrap_err ( "could not parse task request payload" ) ?;
28
- let task_body = match serde_json:: from_value :: < TaskBody > ( task. input )
29
- . wrap_err ( "could not parse task body" )
30
- {
30
+ let task_body = match serde_json:: from_value :: < TaskBody > ( task. input ) {
31
31
Ok ( task_body) => task_body,
32
32
Err ( err) => {
33
- let err_string = format ! ( "{:#}" , err) ;
34
33
log:: error!(
35
- "Task {}/{} failed due to parsing error: {}" ,
34
+ "Task {}/{} failed due to parsing error: {err }" ,
36
35
task. file_id,
37
36
task. row_id,
38
- err_string
39
37
) ;
40
38
41
39
// prepare error payload
42
40
let error_payload = TaskResponsePayload {
43
41
result : None ,
44
- error : Some ( err_string ) ,
42
+ error : Some ( TaskError :: ParseError ( err . to_string ( ) ) ) ,
45
43
row_id : task. row_id ,
46
44
file_id : task. file_id ,
47
45
task_id : task. task_id ,
48
- model : Default :: default ( ) ,
46
+ model : "<n/a>" . to_string ( ) , // no model available due to parsing error
49
47
stats : TaskStats :: new ( ) ,
50
48
} ;
51
49
@@ -56,7 +54,8 @@ impl TaskResponder {
56
54
let response = node. new_message ( error_payload_str, TASK_RESULT_TOPIC ) ;
57
55
node. p2p . respond ( response. into ( ) , channel) . await ?;
58
56
59
- return Err ( err) ;
57
+ // return with error
58
+ eyre:: bail!( "could not parse task body: {err}" )
60
59
}
61
60
} ;
62
61
@@ -75,7 +74,7 @@ impl TaskResponder {
75
74
let task_metadata = TaskWorkerMetadata {
76
75
task_id : task. task_id ,
77
76
file_id : task. file_id ,
78
- model_name : task_body. model . to_string ( ) ,
77
+ model : task_body. model ,
79
78
channel,
80
79
} ;
81
80
let task_input = TaskWorkerInput {
@@ -112,7 +111,7 @@ impl TaskResponder {
112
111
file_id : task_metadata. file_id ,
113
112
task_id : task_metadata. task_id ,
114
113
row_id : task_output. row_id ,
115
- model : task_metadata. model_name ,
114
+ model : task_metadata. model . to_string ( ) ,
116
115
stats : task_output
117
116
. stats
118
117
. record_published_at ( )
@@ -125,22 +124,21 @@ impl TaskResponder {
125
124
}
126
125
Err ( err) => {
127
126
// use pretty display string for error logging with causes
128
- let err_string = format ! ( "{:#}" , err) ;
129
127
log:: error!(
130
- "Task {}/{} failed: {}" ,
128
+ "Task {}/{} failed: {:# }" ,
131
129
task_metadata. file_id,
132
130
task_output. row_id,
133
- err_string
131
+ err
134
132
) ;
135
133
136
134
// prepare error payload
137
135
let error_payload = TaskResponsePayload {
138
136
result : None ,
139
- error : Some ( err_string ) ,
137
+ error : Some ( map_prompt_error ( task_metadata . model . provider ( ) , err ) ) ,
140
138
row_id : task_output. row_id ,
141
139
file_id : task_metadata. file_id ,
142
140
task_id : task_metadata. task_id ,
143
- model : task_metadata. model_name ,
141
+ model : task_metadata. model . to_string ( ) ,
144
142
stats : task_output
145
143
. stats
146
144
. record_published_at ( )
@@ -161,3 +159,104 @@ impl TaskResponder {
161
159
Ok ( ( ) )
162
160
}
163
161
}
162
+
163
+ /// Maps a [`PromptError`] to a [`DriaExecutorError`] with respect to the given provider.
164
+ fn map_prompt_error ( provider : ModelProvider , err : PromptError ) -> TaskError {
165
+ if let PromptError :: CompletionError ( CompletionError :: ProviderError ( err_inner) ) = & err {
166
+ /// A wrapper for `{ error: T }` to match the provider error format.
167
+ #[ derive( Clone , serde:: Deserialize ) ]
168
+ struct ErrorObject < T > {
169
+ error : T ,
170
+ }
171
+
172
+ match provider {
173
+ ModelProvider :: Gemini => {
174
+ /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
175
+ #[ derive( Clone , serde:: Deserialize ) ]
176
+ pub struct GeminiError {
177
+ code : u32 ,
178
+ message : String ,
179
+ status : String ,
180
+ }
181
+
182
+ serde_json:: from_str :: < ErrorObject < GeminiError > > ( err_inner) . map (
183
+ |ErrorObject {
184
+ error : gemini_error,
185
+ } | TaskError :: ProviderError {
186
+ code : format ! ( "{} ({})" , gemini_error. code, gemini_error. status) ,
187
+ message : gemini_error. message ,
188
+ provider : provider. to_string ( ) ,
189
+ } ,
190
+ )
191
+ }
192
+ ModelProvider :: OpenAI => {
193
+ /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
194
+ #[ derive( Clone , serde:: Deserialize ) ]
195
+ pub struct OpenAIError {
196
+ code : String ,
197
+ message : String ,
198
+ }
199
+
200
+ serde_json:: from_str :: < ErrorObject < OpenAIError > > ( err_inner) . map (
201
+ |ErrorObject {
202
+ error : openai_error,
203
+ } | TaskError :: ProviderError {
204
+ code : openai_error. code ,
205
+ message : openai_error. message ,
206
+ provider : provider. to_string ( ) ,
207
+ } ,
208
+ )
209
+ }
210
+ ModelProvider :: OpenRouter => {
211
+ /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
212
+ #[ derive( Clone , serde:: Deserialize ) ]
213
+ pub struct OpenRouterError {
214
+ code : u32 ,
215
+ message : String ,
216
+ }
217
+
218
+ serde_json:: from_str :: < ErrorObject < OpenRouterError > > ( err_inner) . map (
219
+ |ErrorObject {
220
+ error : openrouter_error,
221
+ } | {
222
+ TaskError :: ProviderError {
223
+ code : openrouter_error. code . to_string ( ) ,
224
+ message : openrouter_error. message ,
225
+ provider : provider. to_string ( ) ,
226
+ }
227
+ } ,
228
+ )
229
+ }
230
+ ModelProvider :: Ollama => serde_json:: from_str :: < ErrorObject < String > > ( err_inner) . map (
231
+ // Ollama just returns a string error message
232
+ |ErrorObject {
233
+ error : ollama_error,
234
+ } | {
235
+ // based on the error message, we can come up with out own "dummy" codes
236
+ let code = if ollama_error. contains ( "server busy, please try again." ) {
237
+ "server_busy"
238
+ } else if ollama_error. contains ( "model requires more system memory" ) {
239
+ "model_requires_more_memory"
240
+ } else if ollama_error. contains ( "cudaMalloc failed: out of memory" ) {
241
+ "cuda_malloc_failed"
242
+ } else if ollama_error. contains ( "CUDA error: out of memory" ) {
243
+ "cuda_oom"
244
+ } else {
245
+ "unknown"
246
+ } ;
247
+
248
+ TaskError :: ProviderError {
249
+ code : code. to_string ( ) ,
250
+ message : ollama_error,
251
+ provider : provider. to_string ( ) ,
252
+ }
253
+ } ,
254
+ ) ,
255
+ }
256
+ // if we couldn't parse it, just return a generic prompt error
257
+ . unwrap_or ( TaskError :: Other ( err. to_string ( ) ) )
258
+ } else {
259
+ // not a provider error, fallback to generic prompt error
260
+ TaskError :: Other ( err. to_string ( ) )
261
+ }
262
+ }
0 commit comments