Skip to content

Commit 601e0e8

Browse files
committed
parallel workflows first version works
1 parent 7383c05 commit 601e0e8

File tree

6 files changed

+241
-52
lines changed

6 files changed

+241
-52
lines changed

compute/src/handlers/workflow.rs

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
2-
use dkn_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
3-
use eyre::{eyre, Context, Result};
2+
use dkn_workflows::{Entry, Executor, ModelProvider, Workflow};
3+
use eyre::{Context, Result};
44
use libsecp256k1::PublicKey;
55
use serde::Deserialize;
6-
use std::time::Instant;
6+
use tokio_util::either::Either;
77

88
use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats};
99
use crate::utils::{get_current_time_nanos, DKNMessage};
10+
use crate::workers::workflow::*;
1011
use crate::DriaComputeNode;
1112

1213
pub struct WorkflowHandler;
@@ -31,11 +32,13 @@ impl WorkflowHandler {
3132
pub(crate) async fn handle_compute(
3233
node: &mut DriaComputeNode,
3334
message: DKNMessage,
34-
) -> Result<MessageAcceptance> {
35+
) -> Result<Either<MessageAcceptance, WorkflowsWorkerInput>> {
3536
let task = message
3637
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
3738
.wrap_err("Could not parse workflow task")?;
38-
let mut task_stats = TaskStats::default().record_received_at();
39+
40+
// TODO: !!!
41+
let task_stats = TaskStats::default().record_received_at();
3942

4043
// check if deadline is past or not
4144
let current_time = get_current_time_nanos();
@@ -48,7 +51,7 @@ impl WorkflowHandler {
4851
);
4952

5053
// ignore the message
51-
return Ok(MessageAcceptance::Ignore);
54+
return Ok(Either::Left(MessageAcceptance::Ignore));
5255
}
5356

5457
// check task inclusion via the bloom filter
@@ -59,9 +62,15 @@ impl WorkflowHandler {
5962
);
6063

6164
// accept the message, someone else may be included in filter
62-
return Ok(MessageAcceptance::Accept);
65+
return Ok(Either::Left(MessageAcceptance::Accept));
6366
}
6467

68+
// obtain public key from the payload
69+
// do this early to avoid unnecessary processing
70+
let task_public_key_bytes =
71+
hex::decode(&task.public_key).wrap_err("could not decode public key")?;
72+
let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?;
73+
6574
// read model / provider from the task
6675
let (model_provider, model) = node
6776
.config
@@ -80,65 +89,66 @@ impl WorkflowHandler {
8089
} else {
8190
Executor::new(model)
8291
};
92+
93+
// prepare entry from prompt
8394
let entry: Option<Entry> = task
8495
.input
8596
.prompt
8697
.map(|prompt| Entry::try_value_or_str(&prompt));
8798

88-
// execute workflow with cancellation
89-
let mut memory = ProgramMemory::new();
90-
91-
let exec_started_at = Instant::now();
92-
let exec_result = executor
93-
.execute(entry.as_ref(), &task.input.workflow, &mut memory)
94-
.await
95-
.map_err(|e| eyre!("Execution error: {}", e.to_string()));
96-
task_stats = task_stats.record_execution_time(exec_started_at);
97-
98-
Ok(MessageAcceptance::Accept)
99+
// get workflow as well
100+
let workflow = task.input.workflow;
101+
102+
Ok(Either::Right(WorkflowsWorkerInput {
103+
entry,
104+
executor,
105+
workflow,
106+
model_name,
107+
task_id: task.task_id,
108+
public_key: task_public_key,
109+
stats: task_stats,
110+
}))
99111
}
100112

101-
async fn handle_publish(
113+
pub(crate) async fn handle_publish(
102114
node: &mut DriaComputeNode,
103-
result: String,
104-
task_id: String,
105-
) -> Result<()> {
106-
let (message, acceptance) = match exec_result {
115+
task: WorkflowsWorkerOutput,
116+
) -> Result<MessageAcceptance> {
117+
let (message, acceptance) = match task.result {
107118
Ok(result) => {
108-
// obtain public key from the payload
109-
let task_public_key_bytes =
110-
hex::decode(&task.public_key).wrap_err("Could not decode public key")?;
111-
let task_public_key = PublicKey::parse_slice(&task_public_key_bytes, None)?;
112-
113119
// prepare signed and encrypted payload
114120
let payload = TaskResponsePayload::new(
115121
result,
116-
&task_id,
117-
&task_public_key,
122+
&task.task_id,
123+
&task.public_key,
118124
&node.config.secret_key,
119-
model_name,
120-
task_stats.record_published_at(),
125+
task.model_name,
126+
task.stats.record_published_at(),
121127
)?;
122128
let payload_str = serde_json::to_string(&payload)
123129
.wrap_err("Could not serialize response payload")?;
124130

125131
// prepare signed message
126-
log::debug!("Publishing result for task {}\n{}", task_id, payload_str);
132+
log::debug!(
133+
"Publishing result for task {}\n{}",
134+
task.task_id,
135+
payload_str
136+
);
127137
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
128138
// accept so that if there are others included in filter they can do the task
129139
(message, MessageAcceptance::Accept)
130140
}
131141
Err(err) => {
132142
// use pretty display string for error logging with causes
133143
let err_string = format!("{:#}", err);
134-
log::error!("Task {} failed: {}", task_id, err_string);
144+
log::error!("Task {} failed: {}", task.task_id, err_string);
135145

136146
// prepare error payload
137147
let error_payload = TaskErrorPayload {
138-
task_id,
148+
task_id: task.task_id.clone(),
139149
error: err_string,
140-
model: model_name,
141-
stats: task_stats.record_published_at(),
150+
model: task.model_name,
151+
stats: task.stats.record_published_at(),
142152
};
143153
let error_payload_str = serde_json::to_string(&error_payload)
144154
.wrap_err("Could not serialize error payload")?;
@@ -160,7 +170,7 @@ impl WorkflowHandler {
160170
log::error!("{}", err_msg);
161171

162172
let payload = serde_json::json!({
163-
"taskId": task_id,
173+
"taskId": task.task_id,
164174
"error": err_msg,
165175
});
166176
let message = DKNMessage::new_signed(
@@ -171,6 +181,6 @@ impl WorkflowHandler {
171181
node.publish(message).await?;
172182
};
173183

174-
Ok(())
184+
Ok(acceptance)
175185
}
176186
}

compute/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub(crate) mod handlers;
66
pub(crate) mod node;
77
pub(crate) mod payloads;
88
pub(crate) mod utils;
9+
pub(crate) mod workers;
910

1011
/// Crate version of the compute node.
1112
/// This value is attached within the published messages.

compute/src/main.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,16 @@ async fn main() -> Result<()> {
7575
}
7676

7777
let node_token = token.clone();
78-
let (mut node, p2p) = DriaComputeNode::new(config, node_token).await?;
78+
let (mut node, p2p, mut workflows) = DriaComputeNode::new(config, node_token).await?;
7979

8080
// launch the p2p in a separate thread
8181
log::info!("Spawning peer-to-peer client thread.");
8282
let p2p_handle = tokio::spawn(async move { p2p.run().await });
8383

84+
// launch the workflows in a separate thread
85+
log::info!("Spawning workflows worker thread.");
86+
let workflows_handle = tokio::spawn(async move { workflows.run().await });
87+
8488
// launch the node in a separate thread
8589
log::info!("Spawning compute node thread.");
8690
let node_handle = tokio::spawn(async move {
@@ -94,6 +98,9 @@ async fn main() -> Result<()> {
9498
if let Err(err) = node_handle.await {
9599
log::error!("Node handle error: {}", err);
96100
};
101+
if let Err(err) = workflows_handle.await {
102+
log::error!("Workflows handle error: {}", err);
103+
};
97104
if let Err(err) = p2p_handle.await {
98105
log::error!("P2P handle error: {}", err);
99106
};

compute/src/node.rs

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ use tokio::{
1010
sync::mpsc,
1111
time::{Duration, Instant},
1212
};
13-
use tokio_util::sync::CancellationToken;
13+
use tokio_util::{either::Either, sync::CancellationToken};
1414

1515
use crate::{
1616
config::*,
1717
handlers::*,
1818
utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage},
19+
workers::workflow::{WorkflowsWorker, WorkflowsWorkerInput, WorkflowsWorkerOutput},
1920
};
2021

2122
/// Number of seconds between refreshing the Kademlia DHT.
@@ -27,7 +28,10 @@ pub struct DriaComputeNode {
2728
pub available_nodes: AvailableNodes,
2829
pub cancellation: CancellationToken,
2930
peer_last_refreshed: Instant,
30-
msg_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
31+
// channels
32+
message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
33+
worklow_tx: mpsc::Sender<WorkflowsWorkerInput>,
34+
publish_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3135
}
3236

3337
impl DriaComputeNode {
@@ -37,7 +41,7 @@ impl DriaComputeNode {
3741
pub async fn new(
3842
config: DriaComputeNodeConfig,
3943
cancellation: CancellationToken,
40-
) -> Result<(DriaComputeNode, DriaP2PClient)> {
44+
) -> Result<(DriaComputeNode, DriaP2PClient, WorkflowsWorker)> {
4145
// create the keypair from secret key
4246
let keypair = secret_to_keypair(&config.secret_key);
4347

@@ -55,7 +59,7 @@ impl DriaComputeNode {
5559
log::info!("Using identity: {}", protocol);
5660

5761
// create p2p client
58-
let (p2p_client, p2p_commander, msg_rx) = DriaP2PClient::new(
62+
let (p2p_client, p2p_commander, message_rx) = DriaP2PClient::new(
5963
keypair,
6064
config.p2p_listen_addr.clone(),
6165
available_nodes.bootstrap_nodes.clone().into_iter(),
@@ -64,16 +68,24 @@ impl DriaComputeNode {
6468
protocol,
6569
)?;
6670

71+
// create workflow worker
72+
let (worklow_tx, workflow_rx) = mpsc::channel(256);
73+
let (publish_tx, publish_rx) = mpsc::channel(256);
74+
let workflows_worker = WorkflowsWorker::new(workflow_rx, publish_tx);
75+
6776
Ok((
6877
DriaComputeNode {
6978
config,
7079
p2p: p2p_commander,
7180
cancellation,
7281
available_nodes,
73-
msg_rx,
82+
message_rx,
83+
worklow_tx,
84+
publish_rx,
7485
peer_last_refreshed: Instant::now(),
7586
},
7687
p2p_client,
88+
workflows_worker,
7789
))
7890
}
7991

@@ -207,7 +219,18 @@ impl DriaComputeNode {
207219
// then handle the prepared message
208220
let handler_result = match topic_str {
209221
WorkflowHandler::LISTEN_TOPIC => {
210-
WorkflowHandler::handle_compute(self, message).await
222+
let compute_result = WorkflowHandler::handle_compute(self, message).await;
223+
match compute_result {
224+
Ok(Either::Left(acceptance)) => Ok(acceptance),
225+
Ok(Either::Right(workflow_message)) => {
226+
if let Err(e) = self.worklow_tx.send(workflow_message).await {
227+
log::error!("Error sending workflow message: {:?}", e);
228+
};
229+
230+
Ok(MessageAcceptance::Accept)
231+
}
232+
Err(err) => Err(err),
233+
}
211234
}
212235
PingpongHandler::LISTEN_TOPIC => PingpongHandler::handle_ping(self, message).await,
213236
_ => unreachable!(), // unreachable because of the if condition
@@ -251,7 +274,12 @@ impl DriaComputeNode {
251274
// the underlying p2p client is expected to handle the rest within its own loop
252275
loop {
253276
tokio::select! {
254-
gossipsub_msg = self.msg_rx.recv() => {
277+
publish_msg = self.publish_rx.recv() => {
278+
if let Some(result) = publish_msg {
279+
WorkflowHandler::handle_publish(self, result).await?;
280+
}
281+
},
282+
gossipsub_msg = self.message_rx.recv() => {
255283
if let Some((peer_id, message_id, message)) = gossipsub_msg {
256284
// handle the message, returning a message acceptance for the received one
257285
let acceptance = self.handle_message((peer_id, &message_id, message)).await;
@@ -282,15 +310,16 @@ impl DriaComputeNode {
282310
Ok(())
283311
}
284312

285-
/// Shutdown channels between p2p and yourself.
313+
/// Shutdown channels between p2p, worker and yourself.
286314
pub async fn shutdown(&mut self) -> Result<()> {
287-
// send shutdown signal
288315
log::debug!("Sending shutdown command to p2p client.");
289316
self.p2p.shutdown().await?;
290317

291-
// close message channel
292318
log::debug!("Closing message channel.");
293-
self.msg_rx.close();
319+
self.message_rx.close();
320+
321+
log::debug!("Closing publish channel.");
322+
self.publish_rx.close();
294323

295324
Ok(())
296325
}
@@ -329,7 +358,7 @@ mod tests {
329358

330359
// create node
331360
let cancellation = CancellationToken::new();
332-
let (mut node, p2p) =
361+
let (mut node, p2p, _) =
333362
DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone())
334363
.await
335364
.expect("should create node");

compute/src/workers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod workflow;

0 commit comments

Comments
 (0)