1
- use std:: time:: Instant ;
2
-
3
- use async_trait:: async_trait;
4
1
use dkn_p2p:: libp2p:: gossipsub:: MessageAcceptance ;
5
2
use dkn_workflows:: { Entry , Executor , ModelProvider , ProgramMemory , Workflow } ;
6
3
use eyre:: { eyre, Context , Result } ;
7
4
use libsecp256k1:: PublicKey ;
8
5
use serde:: Deserialize ;
6
+ use std:: time:: Instant ;
9
7
10
8
use crate :: payloads:: { TaskErrorPayload , TaskRequestPayload , TaskResponsePayload , TaskStats } ;
11
9
use crate :: utils:: { get_current_time_nanos, DKNMessage } ;
12
10
use crate :: DriaComputeNode ;
13
11
14
- use super :: ComputeHandler ;
15
-
16
12
pub struct WorkflowHandler ;
17
13
18
14
#[ derive( Debug , Deserialize ) ]
19
15
struct WorkflowPayload {
20
- /// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed.
16
+ /// [Workflow](https://github.com/andthattoo/ollama-workflows/blob/main/src/program/workflow.rs ) object to be parsed.
21
17
pub ( crate ) workflow : Workflow ,
22
18
/// A lıst of model (that can be parsed into `Model`) or model provider names.
23
19
/// If model provider is given, the first matching model in the node config is used for that.
@@ -28,12 +24,11 @@ struct WorkflowPayload {
28
24
pub ( crate ) prompt : Option < String > ,
29
25
}
30
26
31
- #[ async_trait]
32
- impl ComputeHandler for WorkflowHandler {
33
- const LISTEN_TOPIC : & ' static str = "task" ;
34
- const RESPONSE_TOPIC : & ' static str = "results" ;
27
+ impl WorkflowHandler {
28
+ pub ( crate ) const LISTEN_TOPIC : & ' static str = "task" ;
29
+ pub ( crate ) const RESPONSE_TOPIC : & ' static str = "results" ;
35
30
36
- async fn handle_compute (
31
+ pub ( crate ) async fn handle_compute (
37
32
node : & mut DriaComputeNode ,
38
33
message : DKNMessage ,
39
34
) -> Result < MessageAcceptance > {
@@ -85,26 +80,29 @@ impl ComputeHandler for WorkflowHandler {
85
80
} else {
86
81
Executor :: new ( model)
87
82
} ;
88
- let mut memory = ProgramMemory :: new ( ) ;
89
83
let entry: Option < Entry > = task
90
84
. input
91
85
. prompt
92
86
. map ( |prompt| Entry :: try_value_or_str ( & prompt) ) ;
93
87
94
88
// execute workflow with cancellation
95
- let exec_result: Result < String > ;
89
+ let mut memory = ProgramMemory :: new ( ) ;
90
+
96
91
let exec_started_at = Instant :: now ( ) ;
97
- tokio:: select! {
98
- _ = node. cancellation. cancelled( ) => {
99
- log:: info!( "Received cancellation, quitting all tasks." ) ;
100
- return Ok ( MessageAcceptance :: Accept ) ;
101
- } ,
102
- exec_result_inner = executor. execute( entry. as_ref( ) , & task. input. workflow, & mut memory) => {
103
- exec_result = exec_result_inner. map_err( |e| eyre!( "Execution error: {}" , e. to_string( ) ) ) ;
104
- }
105
- }
92
+ let exec_result = executor
93
+ . execute ( entry. as_ref ( ) , & task. input . workflow , & mut memory)
94
+ . await
95
+ . map_err ( |e| eyre ! ( "Execution error: {}" , e. to_string( ) ) ) ;
106
96
task_stats = task_stats. record_execution_time ( exec_started_at) ;
107
97
98
+ Ok ( MessageAcceptance :: Accept )
99
+ }
100
+
101
+ async fn handle_publish (
102
+ node : & mut DriaComputeNode ,
103
+ result : String ,
104
+ task_id : String ,
105
+ ) -> Result < ( ) > {
108
106
let ( message, acceptance) = match exec_result {
109
107
Ok ( result) => {
110
108
// obtain public key from the payload
@@ -115,7 +113,7 @@ impl ComputeHandler for WorkflowHandler {
115
113
// prepare signed and encrypted payload
116
114
let payload = TaskResponsePayload :: new (
117
115
result,
118
- & task . task_id ,
116
+ & task_id,
119
117
& task_public_key,
120
118
& node. config . secret_key ,
121
119
model_name,
@@ -125,23 +123,19 @@ impl ComputeHandler for WorkflowHandler {
125
123
. wrap_err ( "Could not serialize response payload" ) ?;
126
124
127
125
// prepare signed message
128
- log:: debug!(
129
- "Publishing result for task {}\n {}" ,
130
- task. task_id,
131
- payload_str
132
- ) ;
126
+ log:: debug!( "Publishing result for task {}\n {}" , task_id, payload_str) ;
133
127
let message = DKNMessage :: new ( payload_str, Self :: RESPONSE_TOPIC ) ;
134
128
// accept so that if there are others included in filter they can do the task
135
129
( message, MessageAcceptance :: Accept )
136
130
}
137
131
Err ( err) => {
138
132
// use pretty display string for error logging with causes
139
133
let err_string = format ! ( "{:#}" , err) ;
140
- log:: error!( "Task {} failed: {}" , task . task_id, err_string) ;
134
+ log:: error!( "Task {} failed: {}" , task_id, err_string) ;
141
135
142
136
// prepare error payload
143
137
let error_payload = TaskErrorPayload {
144
- task_id : task . task_id . clone ( ) ,
138
+ task_id,
145
139
error : err_string,
146
140
model : model_name,
147
141
stats : task_stats. record_published_at ( ) ,
@@ -166,7 +160,7 @@ impl ComputeHandler for WorkflowHandler {
166
160
log:: error!( "{}" , err_msg) ;
167
161
168
162
let payload = serde_json:: json!( {
169
- "taskId" : task . task_id,
163
+ "taskId" : task_id,
170
164
"error" : err_msg,
171
165
} ) ;
172
166
let message = DKNMessage :: new_signed (
@@ -175,8 +169,8 @@ impl ComputeHandler for WorkflowHandler {
175
169
& node. config . secret_key ,
176
170
) ;
177
171
node. publish ( message) . await ?;
178
- }
172
+ } ;
179
173
180
- Ok ( acceptance )
174
+ Ok ( ( ) )
181
175
}
182
176
}
0 commit comments