Skip to content

Commit aa1d995

Browse files
committed
combine publish channels
1 parent 5629e1a commit aa1d995

File tree

5 files changed

+38
-54
lines changed

5 files changed

+38
-54
lines changed

compute/src/handlers/pingpong.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ struct PingpongPayload {
1919

2020
#[derive(Serialize, Deserialize, Debug, Clone)]
2121
struct PingpongResponse {
22+
/// UUID as given in the ping payload.
2223
pub(crate) uuid: String,
24+
/// Models available in the node.
2325
pub(crate) models: Vec<(ModelProvider, Model)>,
24-
pub(crate) timestamp: u128,
2526
/// Number of tasks in the channel currently, `single` and `batch`.
26-
pub(crate) tasks: [usize; 2],
27+
pub(crate) active_task_count: [usize; 2],
2728
}
2829

2930
impl PingpongHandler {
@@ -65,8 +66,7 @@ impl PingpongHandler {
6566
let response_body = PingpongResponse {
6667
uuid: pingpong.uuid.clone(),
6768
models: node.config.workflows.models.clone(),
68-
timestamp: get_current_time_nanos(),
69-
tasks: node.get_active_task_count(),
69+
active_task_count: node.get_active_task_count(),
7070
};
7171

7272
// publish message

compute/src/handlers/workflow.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ impl WorkflowHandler {
114114
)))
115115
}
116116

117+
/// Handles the result of a workflow task.
117118
pub(crate) async fn handle_publish(
118119
node: &mut DriaComputeNode,
119120
task: WorkflowsWorkerOutput,
@@ -131,8 +132,7 @@ impl WorkflowHandler {
131132
)?;
132133

133134
// convert payload to message
134-
let payload_str = serde_json::to_string(&payload)
135-
.wrap_err("could not serialize response payload")?;
135+
let payload_str = serde_json::json!(payload).to_string();
136136
log::debug!(
137137
"Publishing result for task {}\n{}",
138138
task.task_id,
@@ -152,8 +152,7 @@ impl WorkflowHandler {
152152
model: task.model_name,
153153
stats: task.stats.record_published_at(),
154154
};
155-
let error_payload_str = serde_json::to_string(&error_payload)
156-
.wrap_err("could not serialize error payload")?;
155+
let error_payload_str = serde_json::json!(error_payload).to_string();
157156

158157
// prepare signed message
159158
DKNMessage::new_signed(
@@ -178,6 +177,7 @@ impl WorkflowHandler {
178177
Self::RESPONSE_TOPIC,
179178
&node.config.secret_key,
180179
);
180+
181181
node.publish(message).await?;
182182
};
183183

compute/src/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ async fn main() -> Result<()> {
7575
return Ok(());
7676
}
7777

78-
let node_token = token.clone();
79-
let (mut node, p2p, mut worker_batch, mut worker_single) =
80-
DriaComputeNode::new(config, node_token).await?;
78+
// create the node
79+
let (mut node, p2p, mut worker_batch, mut worker_single) = DriaComputeNode::new(config).await?;
8180

8281
log::info!("Spawning peer-to-peer client thread.");
8382
let p2p_handle = tokio::spawn(async move { p2p.run().await });
@@ -90,8 +89,9 @@ async fn main() -> Result<()> {
9089

9190
// launch the node in a separate thread
9291
log::info!("Spawning compute node thread.");
92+
let node_token = token.clone();
9393
let node_handle = tokio::spawn(async move {
94-
if let Err(err) = node.run().await {
94+
if let Err(err) = node.run(node_token).await {
9595
log::error!("Node launch error: {}", err);
9696
panic!("Node failed.")
9797
};

compute/src/node.rs

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@ use crate::{
2020
const DIAGNOSTIC_REFRESH_INTERVAL_SECS: u64 = 30;
2121
/// Number of seconds between refreshing the available nodes.
2222
const AVAILABLE_NODES_REFRESH_INTERVAL_SECS: u64 = 30 * 60; // 30 minutes
23+
/// Buffer size for message publishes.
24+
const PUBLISH_CHANNEL_BUFSIZE: usize = 1024;
2325

2426
pub struct DriaComputeNode {
2527
pub config: DriaComputeNodeConfig,
2628
pub p2p: DriaP2PCommander,
2729
pub available_nodes: AvailableNodes,
28-
pub cancellation: CancellationToken,
2930
/// Gossipsub message receiver.
3031
message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
32+
/// Publish receiver to receive messages to be published.
33+
publish_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3134
/// Workflow transmitter to send batchable tasks.
3235
workflow_batch_tx: mpsc::Sender<WorkflowsWorkerInput>,
33-
/// Publish receiver to receive messages to be published.
34-
publish_batch_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3536
/// Workflow transmitter to send single tasks.
3637
workflow_single_tx: mpsc::Sender<WorkflowsWorkerInput>,
37-
/// Publish receiver to receive messages to be published.
38-
publish_single_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3938
}
4039

4140
impl DriaComputeNode {
@@ -44,7 +43,6 @@ impl DriaComputeNode {
4443
/// Returns the node instance and p2p client together. P2p MUST be run in a separate task before this node is used at all.
4544
pub async fn new(
4645
config: DriaComputeNodeConfig,
47-
cancellation: CancellationToken,
4846
) -> Result<(
4947
DriaComputeNode,
5048
DriaP2PClient,
@@ -77,22 +75,20 @@ impl DriaComputeNode {
7775
protocol,
7876
)?;
7977

80-
// create workflow workers
81-
let (workflows_batch_worker, workflow_batch_tx, publish_batch_rx) = WorkflowsWorker::new();
82-
let (workflows_single_worker, workflow_single_tx, publish_single_rx) =
83-
WorkflowsWorker::new();
78+
// create workflow workers, all workers use the same publish channel
79+
let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE);
80+
let (workflows_batch_worker, workflow_batch_tx) = WorkflowsWorker::new(publish_tx.clone());
81+
let (workflows_single_worker, workflow_single_tx) = WorkflowsWorker::new(publish_tx);
8482

8583
Ok((
8684
DriaComputeNode {
8785
config,
8886
p2p: p2p_commander,
89-
cancellation,
9087
available_nodes,
9188
message_rx,
89+
publish_rx,
9290
workflow_batch_tx,
93-
publish_batch_rx,
9491
workflow_single_tx,
95-
publish_single_rx,
9692
},
9793
p2p_client,
9894
workflows_batch_worker,
@@ -248,8 +244,8 @@ impl DriaComputeNode {
248244
}
249245

250246
/// Runs the main loop of the compute node.
251-
/// This method is not expected to return until cancellation occurs.
252-
pub async fn run(&mut self) -> Result<()> {
247+
/// This method is not expected to return until cancellation occurs for the given token.
248+
pub async fn run(&mut self, cancellation: CancellationToken) -> Result<()> {
253249
// prepare durations for sleeps
254250
let mut peer_refresh_interval =
255251
tokio::time::interval(Duration::from_secs(DIAGNOSTIC_REFRESH_INTERVAL_SECS));
@@ -270,16 +266,7 @@ impl DriaComputeNode {
270266
_ = available_node_refresh_interval.tick() => self.handle_available_nodes_refresh().await,
271267
// a Workflow message to be published is received from the channel
272268
// this is expected to be sent by the workflow worker
273-
publish_msg = self.publish_batch_rx.recv() => {
274-
if let Some(result) = publish_msg {
275-
WorkflowHandler::handle_publish(self, result).await?;
276-
} else {
277-
log::error!("Publish channel closed unexpectedly.");
278-
break;
279-
};
280-
},
281-
// TODO: make the both receivers handled together somehow
282-
publish_msg = self.publish_single_rx.recv() => {
269+
publish_msg = self.publish_rx.recv() => {
283270
if let Some(result) = publish_msg {
284271
WorkflowHandler::handle_publish(self, result).await?;
285272
} else {
@@ -306,7 +293,7 @@ impl DriaComputeNode {
306293
},
307294
// check if the cancellation token is cancelled
308295
// this is expected to be cancelled by the main thread with signal handling
309-
_ = self.cancellation.cancelled() => break,
296+
_ = cancellation.cancelled() => break,
310297
}
311298
}
312299

@@ -331,7 +318,7 @@ impl DriaComputeNode {
331318
self.message_rx.close();
332319

333320
log::debug!("Closing publish channel.");
334-
self.publish_batch_rx.close();
321+
self.publish_rx.close();
335322

336323
Ok(())
337324
}
@@ -345,8 +332,8 @@ impl DriaComputeNode {
345332
}
346333

347334
// print task counts
348-
let [single, batch] = self.get_active_task_count();
349-
log::info!("Active Task Count (single/batch): {} / {}", single, batch);
335+
// let [single, batch] = self.get_active_task_count();
336+
// log::info!("Active Task Count (single/batch): {} / {}", single, batch);
350337
}
351338

352339
/// Updates the local list of available nodes by refreshing it.
@@ -382,18 +369,18 @@ mod tests {
382369

383370
// create node
384371
let cancellation = CancellationToken::new();
385-
let (mut node, p2p, _, _) =
386-
DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone())
387-
.await
388-
.expect("should create node");
372+
let (mut node, p2p, _, _) = DriaComputeNode::new(DriaComputeNodeConfig::default())
373+
.await
374+
.expect("should create node");
389375

390376
// spawn p2p task
391377
let p2p_task = tokio::spawn(async move { p2p.run().await });
392378

393379
// launch & wait for a while for connections
394380
log::info!("Waiting a bit for peer setup.");
381+
let run_cancellation = cancellation.clone();
395382
tokio::select! {
396-
_ = node.run() => (),
383+
_ = node.run(run_cancellation) => (),
397384
_ = tokio::time::sleep(tokio::time::Duration::from_secs(20)) => cancellation.cancel(),
398385
}
399386
log::info!("Connected Peers:\n{:#?}", node.peers().await?);

compute/src/workers/workflow.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ pub struct WorkflowsWorker {
2929
publish_tx: mpsc::Sender<WorkflowsWorkerOutput>,
3030
}
3131

32+
/// Buffer size for workflow tasks (per worker).
3233
const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024;
33-
const PUBLISH_CHANNEL_BUFSIZE: usize = 1024;
3434

3535
impl WorkflowsWorker {
3636
/// Batch size that defines how many tasks can be executed in parallel at once.
@@ -39,24 +39,21 @@ impl WorkflowsWorker {
3939
const BATCH_SIZE: usize = 8;
4040

4141
/// Creates a worker and returns the sender and receiver for the worker.
42-
pub fn new() -> (
43-
WorkflowsWorker,
44-
mpsc::Sender<WorkflowsWorkerInput>,
45-
mpsc::Receiver<WorkflowsWorkerOutput>,
46-
) {
42+
pub fn new(
43+
publish_tx: mpsc::Sender<WorkflowsWorkerOutput>,
44+
) -> (WorkflowsWorker, mpsc::Sender<WorkflowsWorkerInput>) {
4745
let (workflow_tx, workflow_rx) = mpsc::channel(WORKFLOW_CHANNEL_BUFSIZE);
48-
let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE);
4946

5047
(
5148
Self {
5249
workflow_rx,
5350
publish_tx,
5451
},
5552
workflow_tx,
56-
publish_rx,
5753
)
5854
}
5955

56+
/// Closes the workflow receiver channel.
6057
fn shutdown(&mut self) {
6158
log::warn!("Closing workflows worker.");
6259
self.workflow_rx.close();

0 commit comments

Comments
 (0)