@@ -36,13 +36,11 @@ pub struct TaskPayload {
36
36
37
37
impl TaskResponder {
38
38
/// Handles the compute message for workflows.
39
- ///
40
- /// - FIXME: DOES NOT CHECK FOR FILTER AS IT IS NO LONGER USED
41
- /// - FIXME: GIVES ERROR ON DEADLINE PAST CASE, BUT WE DONT NEED DEADLINE AS WELL
42
- pub ( crate ) async fn handle_compute (
39
+ pub ( crate ) async fn prepare_worker_input (
43
40
node : & mut DriaComputeNode ,
44
41
compute_message : & DriaMessage ,
45
- ) -> Result < TaskWorkerInput > {
42
+ channel : ResponseChannel < Vec < u8 > > ,
43
+ ) -> Result < ( TaskWorkerInput , TaskWorkerMetadata ) > {
46
44
// parse payload
47
45
let task = compute_message
48
46
. parse_payload :: < TaskRequestPayload < TaskPayload > > ( )
@@ -52,7 +50,7 @@ impl TaskResponder {
52
50
let stats = TaskStats :: new ( ) . record_received_at ( ) ;
53
51
54
52
// check if deadline is past or not
55
- // with request-response, we dont expect this to happen much
53
+ // FIXME: with request-response, we dont expect this to happen much
56
54
if get_current_time_nanos ( ) >= task. deadline {
57
55
return Err ( eyre ! (
58
56
"Task {} is past the deadline, ignoring" ,
@@ -97,34 +95,40 @@ impl TaskResponder {
97
95
// get workflow as well
98
96
let workflow = task. input . workflow ;
99
97
100
- Ok ( TaskWorkerInput {
98
+ let task_input = TaskWorkerInput {
101
99
entry,
102
100
executor,
103
101
workflow,
104
- model_name,
105
102
task_id : task. task_id ,
106
- public_key : task_public_key,
107
103
stats,
108
104
batchable,
109
- } )
105
+ } ;
106
+
107
+ let task_metadata = TaskWorkerMetadata {
108
+ model_name,
109
+ public_key : task_public_key,
110
+ channel,
111
+ } ;
112
+
113
+ Ok ( ( task_input, task_metadata) )
110
114
}
111
115
112
116
/// Handles the result of a workflow task.
113
117
pub ( crate ) async fn handle_respond (
114
118
node : & mut DriaComputeNode ,
115
- task : TaskWorkerOutput ,
116
- channel : ResponseChannel < Vec < u8 > > ,
119
+ task_output : TaskWorkerOutput ,
120
+ task_metadata : TaskWorkerMetadata ,
117
121
) -> Result < ( ) > {
118
- let response = match task . result {
122
+ let response = match task_output . result {
119
123
Ok ( result) => {
120
124
// prepare signed and encrypted payload
121
- log:: info!( "Publishing result for task {}" , task . task_id) ;
125
+ log:: info!( "Publishing result for task {}" , task_output . task_id) ;
122
126
let payload = TaskResponsePayload :: new (
123
127
result,
124
- & task . task_id ,
125
- & task . public_key ,
126
- task . model_name ,
127
- task . stats . record_published_at ( ) ,
128
+ & task_output . task_id ,
129
+ & task_metadata . public_key ,
130
+ task_metadata . model_name ,
131
+ task_output . stats . record_published_at ( ) ,
128
132
) ?;
129
133
130
134
// convert payload to message
@@ -135,14 +139,14 @@ impl TaskResponder {
135
139
Err ( err) => {
136
140
// use pretty display string for error logging with causes
137
141
let err_string = format ! ( "{:#}" , err) ;
138
- log:: error!( "Task {} failed: {}" , task . task_id, err_string) ;
142
+ log:: error!( "Task {} failed: {}" , task_output . task_id, err_string) ;
139
143
140
144
// prepare error payload
141
145
let error_payload = TaskErrorPayload {
142
- task_id : task . task_id ,
146
+ task_id : task_output . task_id ,
143
147
error : err_string,
144
- model : task . model_name ,
145
- stats : task . stats . record_published_at ( ) ,
148
+ model : task_metadata . model_name ,
149
+ stats : task_output . stats . record_published_at ( ) ,
146
150
} ;
147
151
let error_payload_str = serde_json:: json!( error_payload) . to_string ( ) ;
148
152
@@ -152,7 +156,7 @@ impl TaskResponder {
152
156
153
157
// respond through the channel
154
158
let data = response. to_bytes ( ) ?;
155
- node. p2p . respond ( data, channel) . await ?;
159
+ node. p2p . respond ( data, task_metadata . channel ) . await ?;
156
160
157
161
Ok ( ( ) )
158
162
}
0 commit comments