Skip to content

Commit 0365ded

Browse files
committed
optional threads w.r.t model
1 parent aa1d995 commit 0365ded

File tree

6 files changed

+141
-83
lines changed

6 files changed

+141
-83
lines changed

compute/src/handlers/pingpong.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct PingpongResponse {
2424
/// Models available in the node.
2525
pub(crate) models: Vec<(ModelProvider, Model)>,
2626
/// Number of tasks in the channel currently, `single` and `batch`.
27-
pub(crate) active_task_count: [usize; 2],
27+
pub(crate) pending_tasks: [usize; 2],
2828
}
2929

3030
impl PingpongHandler {
@@ -66,7 +66,7 @@ impl PingpongHandler {
6666
let response_body = PingpongResponse {
6767
uuid: pingpong.uuid.clone(),
6868
models: node.config.workflows.models.clone(),
69-
active_task_count: node.get_active_task_count(),
69+
pending_tasks: node.get_pending_task_count(),
7070
};
7171

7272
// publish message

compute/src/handlers/workflow.rs

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl WorkflowHandler {
3232
pub(crate) async fn handle_compute(
3333
node: &mut DriaComputeNode,
3434
compute_message: &DKNMessage,
35-
) -> Result<Either<MessageAcceptance, (WorkflowsWorkerInput, bool)>> {
35+
) -> Result<Either<MessageAcceptance, WorkflowsWorkerInput>> {
3636
let stats = TaskStats::new().record_received_at();
3737
let task = compute_message
3838
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
@@ -100,18 +100,16 @@ impl WorkflowHandler {
100100
// get workflow as well
101101
let workflow = task.input.workflow;
102102

103-
Ok(Either::Right((
104-
WorkflowsWorkerInput {
105-
entry,
106-
executor,
107-
workflow,
108-
model_name,
109-
task_id: task.task_id,
110-
public_key: task_public_key,
111-
stats,
112-
},
103+
Ok(Either::Right(WorkflowsWorkerInput {
104+
entry,
105+
executor,
106+
workflow,
107+
model_name,
108+
task_id: task.task_id,
109+
public_key: task_public_key,
110+
stats,
113111
batchable,
114-
)))
112+
}))
115113
}
116114

117115
/// Handles the result of a workflow task.

compute/src/main.rs

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use dkn_compute::*;
22
use eyre::Result;
33
use std::env;
4-
use tokio_util::sync::CancellationToken;
4+
use tokio_util::{sync::CancellationToken, task::TaskTracker};
55

66
#[tokio::main]
77
async fn main() -> Result<()> {
88
let dotenv_result = dotenvy::dotenv();
9+
910
// TODO: remove me later when the launcher is fixed
1011
amend_log_levels();
1112

@@ -28,88 +29,80 @@ async fn main() -> Result<()> {
2829
"#
2930
);
3031

31-
let token = CancellationToken::new();
32-
let cancellation_token = token.clone();
32+
// task tracker for multiple threads
33+
let task_tracker = TaskTracker::new();
34+
let cancellation = CancellationToken::new();
35+
36+
// spawn the background task to wait for termination signals
37+
let task_tracker_to_close = task_tracker.clone();
38+
let cancellation_token = cancellation.clone();
3339
tokio::spawn(async move {
34-
// the timeout is done for profiling only, and should not be used in production
3540
if let Ok(Ok(duration_secs)) =
3641
env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::<u64>())
3742
{
43+
// the timeout is done for profiling only, and should not be used in production
3844
log::warn!("Waiting for {} seconds before exiting.", duration_secs);
3945
tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
4046

4147
log::warn!("Exiting due to DKN_EXIT_TIMEOUT.");
48+
4249
cancellation_token.cancel();
4350
} else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
51+
// if there is no timeout, we wait for termination signals here
4452
log::error!("Error waiting for termination: {:?}", err);
4553
log::error!("Cancelling due to unexpected error.");
4654
cancellation_token.cancel();
4755
};
56+
57+
// close tracker in any case
58+
task_tracker_to_close.close();
4859
});
4960

5061
// create configurations & check required services & address in use
5162
let mut config = DriaComputeNodeConfig::new();
5263
config.assert_address_not_in_use()?;
53-
let service_check_token = token.clone();
54-
let config = tokio::spawn(async move {
55-
tokio::select! {
56-
result = config.workflows.check_services() => {
57-
if let Err(err) = result {
58-
log::error!("Error checking services: {:?}", err);
59-
panic!("Service check failed.")
60-
}
61-
log::warn!("Using models: {:#?}", config.workflows.models);
62-
config
63-
}
64-
_ = service_check_token.cancelled() => {
65-
log::info!("Service check cancelled.");
66-
config
67-
}
64+
// check services & models, will exit if there is an error
65+
// since service check can take time, we allow early-exit here as well
66+
tokio::select! {
67+
result = config.workflows.check_services() => result,
68+
_ = cancellation.cancelled() => {
69+
log::info!("Service check cancelled, exiting.");
70+
return Ok(());
6871
}
69-
})
70-
.await?;
71-
72-
// check early exit due to failed service check
73-
if token.is_cancelled() {
74-
log::warn!("Not launching node due to early exit, bye!");
75-
return Ok(());
76-
}
72+
}?;
73+
log::warn!("Using models: {:#?}", config.workflows.models);
7774

7875
// create the node
79-
let (mut node, p2p, mut worker_batch, mut worker_single) = DriaComputeNode::new(config).await?;
76+
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
8077

78+
// spawn threads
8179
log::info!("Spawning peer-to-peer client thread.");
82-
let p2p_handle = tokio::spawn(async move { p2p.run().await });
80+
task_tracker.spawn(async move { p2p.run().await });
8381

84-
log::info!("Spawning workflows batch worker thread.");
85-
let worker_batch_handle = tokio::spawn(async move { worker_batch.run_batch().await });
82+
if let Some(mut worker_batch) = worker_batch {
83+
log::info!("Spawning workflows batch worker thread.");
84+
task_tracker.spawn(async move { worker_batch.run_batch().await });
85+
}
8686

87-
log::info!("Spawning workflows single worker thread.");
88-
let worker_single_handle = tokio::spawn(async move { worker_single.run().await });
87+
if let Some(mut worker_single) = worker_single {
88+
log::info!("Spawning workflows single worker thread.");
89+
task_tracker.spawn(async move { worker_single.run().await });
90+
}
8991

9092
// launch the node in a separate thread
9193
log::info!("Spawning compute node thread.");
92-
let node_token = token.clone();
93-
let node_handle = tokio::spawn(async move {
94+
let node_token = cancellation.clone();
95+
task_tracker.spawn(async move {
9496
if let Err(err) = node.run(node_token).await {
9597
log::error!("Node launch error: {}", err);
9698
panic!("Node failed.")
9799
};
100+
log::info!("Closing node.")
98101
});
99102

100-
// wait for tasks to complete
101-
if let Err(err) = node_handle.await {
102-
log::error!("Node handle error: {}", err);
103-
};
104-
if let Err(err) = worker_single_handle.await {
105-
log::error!("Workflows single worker handle error: {}", err);
106-
};
107-
if let Err(err) = worker_batch_handle.await {
108-
log::error!("Workflows batch worker handle error: {}", err);
109-
};
110-
if let Err(err) = p2p_handle.await {
111-
log::error!("P2P handle error: {}", err);
112-
};
103+
// wait for all tasks to finish
104+
task_tracker.wait().await;
105+
log::info!("All tasks have exited succesfully.");
113106

114107
log::info!("Bye!");
115108
Ok(())
@@ -168,7 +161,7 @@ async fn wait_for_termination(cancellation: CancellationToken) -> Result<()> {
168161
cancellation.cancel();
169162
}
170163

171-
log::info!("Terminating the node...");
164+
log::info!("Terminating the application...");
172165

173166
Ok(())
174167
}

compute/src/node.rs

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use dkn_p2p::{
66
DriaP2PClient, DriaP2PCommander, DriaP2PProtocol,
77
};
88
use eyre::Result;
9+
use std::collections::HashSet;
910
use tokio::{sync::mpsc, time::Duration};
1011
use tokio_util::{either::Either, sync::CancellationToken};
1112

@@ -32,9 +33,15 @@ pub struct DriaComputeNode {
3233
/// Publish receiver to receive messages to be published.
3334
publish_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3435
/// Workflow transmitter to send batchable tasks.
35-
workflow_batch_tx: mpsc::Sender<WorkflowsWorkerInput>,
36+
workflow_batch_tx: Option<mpsc::Sender<WorkflowsWorkerInput>>,
3637
/// Workflow transmitter to send single tasks.
37-
workflow_single_tx: mpsc::Sender<WorkflowsWorkerInput>,
38+
workflow_single_tx: Option<mpsc::Sender<WorkflowsWorkerInput>>,
39+
// TODO: instead of piggybacking task metadata within channels, we can store them here
40+
// in a hashmap alone, and then use the task_id to get the metadata when needed
41+
// Single tasks hash-map
42+
pending_tasks_single: HashSet<String>,
43+
// Batch tasks hash-map
44+
pending_tasks_batch: HashSet<String>,
3845
}
3946

4047
impl DriaComputeNode {
@@ -46,8 +53,8 @@ impl DriaComputeNode {
4653
) -> Result<(
4754
DriaComputeNode,
4855
DriaP2PClient,
49-
WorkflowsWorker,
50-
WorkflowsWorker,
56+
Option<WorkflowsWorker>,
57+
Option<WorkflowsWorker>,
5158
)> {
5259
// create the keypair from secret key
5360
let keypair = secret_to_keypair(&config.secret_key);
@@ -77,8 +84,24 @@ impl DriaComputeNode {
7784

7885
// create workflow workers, all workers use the same publish channel
7986
let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE);
80-
let (workflows_batch_worker, workflow_batch_tx) = WorkflowsWorker::new(publish_tx.clone());
81-
let (workflows_single_worker, workflow_single_tx) = WorkflowsWorker::new(publish_tx);
87+
88+
// check if we should create a worker for batchable workflows
89+
let (workflows_batch_worker, workflow_batch_tx) = if config.workflows.has_batchable_models()
90+
{
91+
let worker = WorkflowsWorker::new(publish_tx.clone());
92+
(Some(worker.0), Some(worker.1))
93+
} else {
94+
(None, None)
95+
};
96+
97+
// check if we should create a worker for single workflows
98+
let (workflows_single_worker, workflow_single_tx) =
99+
if config.workflows.has_non_batchable_models() {
100+
let worker = WorkflowsWorker::new(publish_tx);
101+
(Some(worker.0), Some(worker.1))
102+
} else {
103+
(None, None)
104+
};
82105

83106
Ok((
84107
DriaComputeNode {
@@ -89,6 +112,8 @@ impl DriaComputeNode {
89112
publish_rx,
90113
workflow_batch_tx,
91114
workflow_single_tx,
115+
pending_tasks_single: HashSet::new(),
116+
pending_tasks_batch: HashSet::new(),
92117
},
93118
p2p_client,
94119
workflows_batch_worker,
@@ -119,10 +144,10 @@ impl DriaComputeNode {
119144
}
120145

121146
/// Returns the task count within the channels, `single` and `batch`.
122-
pub fn get_active_task_count(&self) -> [usize; 2] {
147+
pub fn get_pending_task_count(&self) -> [usize; 2] {
123148
[
124-
self.workflow_single_tx.max_capacity() - self.workflow_single_tx.capacity(),
125-
self.workflow_batch_tx.max_capacity() - self.workflow_batch_tx.capacity(),
149+
self.pending_tasks_single.len(),
150+
self.pending_tasks_batch.len(),
126151
]
127152
}
128153

@@ -202,10 +227,32 @@ impl DriaComputeNode {
202227
// we got acceptance, so something was not right about the workflow and we can ignore it
203228
Ok(Either::Left(acceptance)) => Ok(acceptance),
204229
// we got the parsed workflow itself, send to a worker thread w.r.t batchable
205-
Ok(Either::Right((workflow_message, batchable))) => {
206-
if let Err(e) = match batchable {
207-
true => self.workflow_batch_tx.send(workflow_message).await,
208-
false => self.workflow_single_tx.send(workflow_message).await,
230+
Ok(Either::Right(workflow_message)) => {
231+
if let Err(e) = match workflow_message.batchable {
232+
// this is a batchable task, send it to batch worker
233+
// and keep track of the task id in pending tasks
234+
true => match self.workflow_batch_tx {
235+
Some(ref mut tx) => {
236+
self.pending_tasks_batch
237+
.insert(workflow_message.task_id.clone());
238+
tx.send(workflow_message).await
239+
}
240+
None => unreachable!(
241+
"Batchable workflow received but no worker available."
242+
),
243+
},
244+
// this is a single task, send it to single worker
245+
// and keep track of the task id in pending tasks
246+
false => match self.workflow_single_tx {
247+
Some(ref mut tx) => {
248+
self.pending_tasks_single
249+
.insert(workflow_message.task_id.clone());
250+
tx.send(workflow_message).await
251+
}
252+
None => unreachable!(
253+
"Single workflow received but no worker available."
254+
),
255+
},
209256
} {
210257
log::error!("Error sending workflow message: {:?}", e);
211258
};
@@ -266,18 +313,25 @@ impl DriaComputeNode {
266313
_ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await,
267314
// a Workflow message to be published is received from the channel
268315
// this is expected to be sent by the workflow worker
269-
publish_msg = self.publish_rx.recv() => {
270-
if let Some(result) = publish_msg {
271-
WorkflowHandler::handle_publish(self, result).await?;
316+
publish_msg_opt = self.publish_rx.recv() => {
317+
if let Some(publish_msg) = publish_msg_opt {
318+
// remove the task from pending tasks based on its batchability
319+
match publish_msg.batchable {
320+
true => self.pending_tasks_batch.remove(&publish_msg.task_id),
321+
false => self.pending_tasks_single.remove(&publish_msg.task_id),
322+
};
323+
324+
// publish the message
325+
WorkflowHandler::handle_publish(self, publish_msg).await?;
272326
} else {
273327
log::error!("Publish channel closed unexpectedly.");
274328
break;
275329
};
276330
},
277331
// a GossipSub message is received from the channel
278332
// this is expected to be sent by the p2p client
279-
gossipsub_msg = self.message_rx.recv() => {
280-
if let Some((peer_id, message_id, message)) = gossipsub_msg {
333+
gossipsub_msg_opt = self.message_rx.recv() => {
334+
if let Some((peer_id, message_id, message)) = gossipsub_msg_opt {
281335
// handle the message, returning a message acceptance for the received one
282336
let acceptance = self.handle_message((peer_id, &message_id, message)).await;
283337

@@ -332,8 +386,8 @@ impl DriaComputeNode {
332386
}
333387

334388
// print task counts
335-
// let [single, batch] = self.get_active_task_count();
336-
// log::info!("Active Task Count (single/batch): {} / {}", single, batch);
389+
let [single, batch] = self.get_pending_task_count();
390+
log::info!("Pending Task Count (single/batch): {} / {}", single, batch);
337391
}
338392

339393
/// Updates the local list of available nodes by refreshing it.

compute/src/workers/workflow.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub struct WorkflowsWorkerInput {
1313
pub task_id: String,
1414
pub model_name: String,
1515
pub stats: TaskStats,
16+
pub batchable: bool,
1617
}
1718

1819
pub struct WorkflowsWorkerOutput {
@@ -22,6 +23,7 @@ pub struct WorkflowsWorkerOutput {
2223
pub task_id: String,
2324
pub model_name: String,
2425
pub stats: TaskStats,
26+
pub batchable: bool,
2527
}
2628

2729
pub struct WorkflowsWorker {
@@ -217,6 +219,7 @@ impl WorkflowsWorker {
217219
public_key: input.public_key,
218220
task_id: input.task_id,
219221
model_name: input.model_name,
222+
batchable: input.batchable,
220223
stats: input.stats.record_execution_time(started_at),
221224
}
222225
}

0 commit comments

Comments
 (0)