1
1
use dkn_p2p:: libp2p:: gossipsub:: MessageAcceptance ;
2
- use dkn_workflows:: { Entry , Executor , ModelProvider , ProgramMemory , Workflow } ;
3
- use eyre:: { eyre , Context , Result } ;
2
+ use dkn_workflows:: { Entry , Executor , ModelProvider , Workflow } ;
3
+ use eyre:: { Context , Result } ;
4
4
use libsecp256k1:: PublicKey ;
5
5
use serde:: Deserialize ;
6
- use std :: time :: Instant ;
6
+ use tokio_util :: either :: Either ;
7
7
8
8
use crate :: payloads:: { TaskErrorPayload , TaskRequestPayload , TaskResponsePayload , TaskStats } ;
9
9
use crate :: utils:: { get_current_time_nanos, DKNMessage } ;
10
+ use crate :: workers:: workflow:: * ;
10
11
use crate :: DriaComputeNode ;
11
12
12
13
pub struct WorkflowHandler ;
@@ -31,11 +32,13 @@ impl WorkflowHandler {
31
32
pub ( crate ) async fn handle_compute (
32
33
node : & mut DriaComputeNode ,
33
34
message : DKNMessage ,
34
- ) -> Result < MessageAcceptance > {
35
+ ) -> Result < Either < MessageAcceptance , WorkflowsWorkerInput > > {
35
36
let task = message
36
37
. parse_payload :: < TaskRequestPayload < WorkflowPayload > > ( true )
37
38
. wrap_err ( "Could not parse workflow task" ) ?;
38
- let mut task_stats = TaskStats :: default ( ) . record_received_at ( ) ;
39
+
40
+ // TODO: !!!
41
+ let task_stats = TaskStats :: default ( ) . record_received_at ( ) ;
39
42
40
43
// check if deadline is past or not
41
44
let current_time = get_current_time_nanos ( ) ;
@@ -48,7 +51,7 @@ impl WorkflowHandler {
48
51
) ;
49
52
50
53
// ignore the message
51
- return Ok ( MessageAcceptance :: Ignore ) ;
54
+ return Ok ( Either :: Left ( MessageAcceptance :: Ignore ) ) ;
52
55
}
53
56
54
57
// check task inclusion via the bloom filter
@@ -59,9 +62,15 @@ impl WorkflowHandler {
59
62
) ;
60
63
61
64
// accept the message, someone else may be included in filter
62
- return Ok ( MessageAcceptance :: Accept ) ;
65
+ return Ok ( Either :: Left ( MessageAcceptance :: Accept ) ) ;
63
66
}
64
67
68
+ // obtain public key from the payload
69
+ // do this early to avoid unnecessary processing
70
+ let task_public_key_bytes =
71
+ hex:: decode ( & task. public_key ) . wrap_err ( "could not decode public key" ) ?;
72
+ let task_public_key = PublicKey :: parse_slice ( & task_public_key_bytes, None ) ?;
73
+
65
74
// read model / provider from the task
66
75
let ( model_provider, model) = node
67
76
. config
@@ -80,65 +89,66 @@ impl WorkflowHandler {
80
89
} else {
81
90
Executor :: new ( model)
82
91
} ;
92
+
93
+ // prepare entry from prompt
83
94
let entry: Option < Entry > = task
84
95
. input
85
96
. prompt
86
97
. map ( |prompt| Entry :: try_value_or_str ( & prompt) ) ;
87
98
88
- // execute workflow with cancellation
89
- let mut memory = ProgramMemory :: new ( ) ;
90
-
91
- let exec_started_at = Instant :: now ( ) ;
92
- let exec_result = executor
93
- . execute ( entry. as_ref ( ) , & task. input . workflow , & mut memory)
94
- . await
95
- . map_err ( |e| eyre ! ( "Execution error: {}" , e. to_string( ) ) ) ;
96
- task_stats = task_stats. record_execution_time ( exec_started_at) ;
97
-
98
- Ok ( MessageAcceptance :: Accept )
99
+ // get workflow as well
100
+ let workflow = task. input . workflow ;
101
+
102
+ Ok ( Either :: Right ( WorkflowsWorkerInput {
103
+ entry,
104
+ executor,
105
+ workflow,
106
+ model_name,
107
+ task_id : task. task_id ,
108
+ public_key : task_public_key,
109
+ stats : task_stats,
110
+ } ) )
99
111
}
100
112
101
- async fn handle_publish (
113
+ pub ( crate ) async fn handle_publish (
102
114
node : & mut DriaComputeNode ,
103
- result : String ,
104
- task_id : String ,
105
- ) -> Result < ( ) > {
106
- let ( message, acceptance) = match exec_result {
115
+ task : WorkflowsWorkerOutput ,
116
+ ) -> Result < MessageAcceptance > {
117
+ let ( message, acceptance) = match task. result {
107
118
Ok ( result) => {
108
- // obtain public key from the payload
109
- let task_public_key_bytes =
110
- hex:: decode ( & task. public_key ) . wrap_err ( "Could not decode public key" ) ?;
111
- let task_public_key = PublicKey :: parse_slice ( & task_public_key_bytes, None ) ?;
112
-
113
119
// prepare signed and encrypted payload
114
120
let payload = TaskResponsePayload :: new (
115
121
result,
116
- & task_id,
117
- & task_public_key ,
122
+ & task . task_id ,
123
+ & task . public_key ,
118
124
& node. config . secret_key ,
119
- model_name,
120
- task_stats . record_published_at ( ) ,
125
+ task . model_name ,
126
+ task . stats . record_published_at ( ) ,
121
127
) ?;
122
128
let payload_str = serde_json:: to_string ( & payload)
123
129
. wrap_err ( "Could not serialize response payload" ) ?;
124
130
125
131
// prepare signed message
126
- log:: debug!( "Publishing result for task {}\n {}" , task_id, payload_str) ;
132
+ log:: debug!(
133
+ "Publishing result for task {}\n {}" ,
134
+ task. task_id,
135
+ payload_str
136
+ ) ;
127
137
let message = DKNMessage :: new ( payload_str, Self :: RESPONSE_TOPIC ) ;
128
138
// accept so that if there are others included in filter they can do the task
129
139
( message, MessageAcceptance :: Accept )
130
140
}
131
141
Err ( err) => {
132
142
// use pretty display string for error logging with causes
133
143
let err_string = format ! ( "{:#}" , err) ;
134
- log:: error!( "Task {} failed: {}" , task_id, err_string) ;
144
+ log:: error!( "Task {} failed: {}" , task . task_id, err_string) ;
135
145
136
146
// prepare error payload
137
147
let error_payload = TaskErrorPayload {
138
- task_id,
148
+ task_id : task . task_id . clone ( ) ,
139
149
error : err_string,
140
- model : model_name,
141
- stats : task_stats . record_published_at ( ) ,
150
+ model : task . model_name ,
151
+ stats : task . stats . record_published_at ( ) ,
142
152
} ;
143
153
let error_payload_str = serde_json:: to_string ( & error_payload)
144
154
. wrap_err ( "Could not serialize error payload" ) ?;
@@ -160,7 +170,7 @@ impl WorkflowHandler {
160
170
log:: error!( "{}" , err_msg) ;
161
171
162
172
let payload = serde_json:: json!( {
163
- "taskId" : task_id,
173
+ "taskId" : task . task_id,
164
174
"error" : err_msg,
165
175
} ) ;
166
176
let message = DKNMessage :: new_signed (
@@ -171,6 +181,6 @@ impl WorkflowHandler {
171
181
node. publish ( message) . await ?;
172
182
} ;
173
183
174
- Ok ( ( ) )
184
+ Ok ( acceptance )
175
185
}
176
186
}
0 commit comments