@@ -6,6 +6,7 @@ use dkn_p2p::{
6
6
DriaP2PClient , DriaP2PCommander , DriaP2PProtocol ,
7
7
} ;
8
8
use eyre:: Result ;
9
+ use std:: collections:: HashSet ;
9
10
use tokio:: { sync:: mpsc, time:: Duration } ;
10
11
use tokio_util:: { either:: Either , sync:: CancellationToken } ;
11
12
@@ -32,9 +33,15 @@ pub struct DriaComputeNode {
32
33
/// Publish receiver to receive messages to be published.
33
34
publish_rx : mpsc:: Receiver < WorkflowsWorkerOutput > ,
34
35
/// Workflow transmitter to send batchable tasks.
35
- workflow_batch_tx : mpsc:: Sender < WorkflowsWorkerInput > ,
36
+ workflow_batch_tx : Option < mpsc:: Sender < WorkflowsWorkerInput > > ,
36
37
/// Workflow transmitter to send single tasks.
37
- workflow_single_tx : mpsc:: Sender < WorkflowsWorkerInput > ,
38
+ workflow_single_tx : Option < mpsc:: Sender < WorkflowsWorkerInput > > ,
39
+ // TODO: instead of piggybacking task metadata within channels, we can store them here
40
+ // in a hashmap alone, and then use the task_id to get the metadata when needed
41
+ // Set of task ids for single (non-batchable) tasks sent to the worker whose results are still pending
42
+ pending_tasks_single : HashSet < String > ,
43
+ // Set of task ids for batchable tasks sent to the worker whose results are still pending
44
+ pending_tasks_batch : HashSet < String > ,
38
45
}
39
46
40
47
impl DriaComputeNode {
@@ -46,8 +53,8 @@ impl DriaComputeNode {
46
53
) -> Result < (
47
54
DriaComputeNode ,
48
55
DriaP2PClient ,
49
- WorkflowsWorker ,
50
- WorkflowsWorker ,
56
+ Option < WorkflowsWorker > ,
57
+ Option < WorkflowsWorker > ,
51
58
) > {
52
59
// create the keypair from secret key
53
60
let keypair = secret_to_keypair ( & config. secret_key ) ;
@@ -77,8 +84,24 @@ impl DriaComputeNode {
77
84
78
85
// create workflow workers, all workers use the same publish channel
79
86
let ( publish_tx, publish_rx) = mpsc:: channel ( PUBLISH_CHANNEL_BUFSIZE ) ;
80
- let ( workflows_batch_worker, workflow_batch_tx) = WorkflowsWorker :: new ( publish_tx. clone ( ) ) ;
81
- let ( workflows_single_worker, workflow_single_tx) = WorkflowsWorker :: new ( publish_tx) ;
87
+
88
+ // check if we should create a worker for batchable workflows
89
+ let ( workflows_batch_worker, workflow_batch_tx) = if config. workflows . has_batchable_models ( )
90
+ {
91
+ let worker = WorkflowsWorker :: new ( publish_tx. clone ( ) ) ;
92
+ ( Some ( worker. 0 ) , Some ( worker. 1 ) )
93
+ } else {
94
+ ( None , None )
95
+ } ;
96
+
97
+ // check if we should create a worker for single workflows
98
+ let ( workflows_single_worker, workflow_single_tx) =
99
+ if config. workflows . has_non_batchable_models ( ) {
100
+ let worker = WorkflowsWorker :: new ( publish_tx) ;
101
+ ( Some ( worker. 0 ) , Some ( worker. 1 ) )
102
+ } else {
103
+ ( None , None )
104
+ } ;
82
105
83
106
Ok ( (
84
107
DriaComputeNode {
@@ -89,6 +112,8 @@ impl DriaComputeNode {
89
112
publish_rx,
90
113
workflow_batch_tx,
91
114
workflow_single_tx,
115
+ pending_tasks_single : HashSet :: new ( ) ,
116
+ pending_tasks_batch : HashSet :: new ( ) ,
92
117
} ,
93
118
p2p_client,
94
119
workflows_batch_worker,
@@ -119,10 +144,10 @@ impl DriaComputeNode {
119
144
}
120
145
121
146
/// Returns the number of pending tasks as `[single, batch]`.
122
- pub fn get_active_task_count ( & self ) -> [ usize ; 2 ] {
147
+ pub fn get_pending_task_count ( & self ) -> [ usize ; 2 ] {
123
148
[
124
- self . workflow_single_tx . max_capacity ( ) - self . workflow_single_tx . capacity ( ) ,
125
- self . workflow_batch_tx . max_capacity ( ) - self . workflow_batch_tx . capacity ( ) ,
149
+ self . pending_tasks_single . len ( ) ,
150
+ self . pending_tasks_batch . len ( ) ,
126
151
]
127
152
}
128
153
@@ -202,10 +227,32 @@ impl DriaComputeNode {
202
227
// we got acceptance, so something was not right about the workflow and we can ignore it
203
228
Ok ( Either :: Left ( acceptance) ) => Ok ( acceptance) ,
204
229
// we got the parsed workflow itself, send to a worker thread w.r.t batchable
205
- Ok ( Either :: Right ( ( workflow_message, batchable) ) ) => {
206
- if let Err ( e) = match batchable {
207
- true => self . workflow_batch_tx . send ( workflow_message) . await ,
208
- false => self . workflow_single_tx . send ( workflow_message) . await ,
230
+ Ok ( Either :: Right ( workflow_message) ) => {
231
+ if let Err ( e) = match workflow_message. batchable {
232
+ // this is a batchable task, send it to batch worker
233
+ // and keep track of the task id in pending tasks
234
+ true => match self . workflow_batch_tx {
235
+ Some ( ref mut tx) => {
236
+ self . pending_tasks_batch
237
+ . insert ( workflow_message. task_id . clone ( ) ) ;
238
+ tx. send ( workflow_message) . await
239
+ }
240
+ None => unreachable ! (
241
+ "Batchable workflow received but no worker available."
242
+ ) ,
243
+ } ,
244
+ // this is a single task, send it to single worker
245
+ // and keep track of the task id in pending tasks
246
+ false => match self . workflow_single_tx {
247
+ Some ( ref mut tx) => {
248
+ self . pending_tasks_single
249
+ . insert ( workflow_message. task_id . clone ( ) ) ;
250
+ tx. send ( workflow_message) . await
251
+ }
252
+ None => unreachable ! (
253
+ "Single workflow received but no worker available."
254
+ ) ,
255
+ } ,
209
256
} {
210
257
log:: error!( "Error sending workflow message: {:?}" , e) ;
211
258
} ;
@@ -266,18 +313,25 @@ impl DriaComputeNode {
266
313
_ = available_node_refresh_interval. tick( ) => self . handle_available_nodes_refresh( ) . await ,
267
314
// a Workflow message to be published is received from the channel
268
315
// this is expected to be sent by the workflow worker
269
- publish_msg = self . publish_rx. recv( ) => {
270
- if let Some ( result) = publish_msg {
271
- WorkflowHandler :: handle_publish( self , result) . await ?;
316
+ publish_msg_opt = self . publish_rx. recv( ) => {
317
+ if let Some ( publish_msg) = publish_msg_opt {
318
+ // remove the task from pending tasks based on its batchability
319
+ match publish_msg. batchable {
320
+ true => self . pending_tasks_batch. remove( & publish_msg. task_id) ,
321
+ false => self . pending_tasks_single. remove( & publish_msg. task_id) ,
322
+ } ;
323
+
324
+ // publish the message
325
+ WorkflowHandler :: handle_publish( self , publish_msg) . await ?;
272
326
} else {
273
327
log:: error!( "Publish channel closed unexpectedly." ) ;
274
328
break ;
275
329
} ;
276
330
} ,
277
331
// a GossipSub message is received from the channel
278
332
// this is expected to be sent by the p2p client
279
- gossipsub_msg = self . message_rx. recv( ) => {
280
- if let Some ( ( peer_id, message_id, message) ) = gossipsub_msg {
333
+ gossipsub_msg_opt = self . message_rx. recv( ) => {
334
+ if let Some ( ( peer_id, message_id, message) ) = gossipsub_msg_opt {
281
335
// handle the message, returning a message acceptance for the received one
282
336
let acceptance = self . handle_message( ( peer_id, & message_id, message) ) . await ;
283
337
@@ -332,8 +386,8 @@ impl DriaComputeNode {
332
386
}
333
387
334
388
// print task counts
335
- // let [single, batch] = self.get_active_task_count ();
336
- // log::info!("Active Task Count (single/batch): {} / {}", single, batch);
389
+ let [ single, batch] = self . get_pending_task_count ( ) ;
390
+ log:: info!( "Pending Task Count (single/batch): {} / {}" , single, batch) ;
337
391
}
338
392
339
393
/// Updates the local list of available nodes by refreshing it.
0 commit comments