Skip to content

Commit a17b54e

Browse files
committed
added separate ollama worker, small refactors
1 parent 304edc6 commit a17b54e

File tree

7 files changed

+153
-65
lines changed

7 files changed

+153
-65
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,6 @@ jobs:
3232

3333
- name: Run tests
3434
run: cargo test --workspace
35+
36+
- name: Run linter
37+
run: cargo clippy --workspace

compute/src/handlers/workflow.rs

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl WorkflowHandler {
3232
pub(crate) async fn handle_compute(
3333
node: &mut DriaComputeNode,
3434
compute_message: &DKNMessage,
35-
) -> Result<Either<MessageAcceptance, WorkflowsWorkerInput>> {
35+
) -> Result<Either<MessageAcceptance, (WorkflowsWorkerInput, bool)>> {
3636
let stats = TaskStats::new().record_received_at();
3737
let task = compute_message
3838
.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)
@@ -78,14 +78,17 @@ impl WorkflowHandler {
7878
log::info!("Using model {} for task {}", model_name, task.task_id);
7979

8080
// prepare workflow executor
81-
let executor = if model_provider == ModelProvider::Ollama {
82-
Executor::new_at(
83-
model,
84-
&node.config.workflows.ollama.host,
85-
node.config.workflows.ollama.port,
81+
let (executor, batchable) = if model_provider == ModelProvider::Ollama {
82+
(
83+
Executor::new_at(
84+
model,
85+
&node.config.workflows.ollama.host,
86+
node.config.workflows.ollama.port,
87+
),
88+
false,
8689
)
8790
} else {
88-
Executor::new(model)
91+
(Executor::new(model), true)
8992
};
9093

9194
// prepare entry from prompt
@@ -97,15 +100,18 @@ impl WorkflowHandler {
97100
// get workflow as well
98101
let workflow = task.input.workflow;
99102

100-
Ok(Either::Right(WorkflowsWorkerInput {
101-
entry,
102-
executor,
103-
workflow,
104-
model_name,
105-
task_id: task.task_id,
106-
public_key: task_public_key,
107-
stats,
108-
}))
103+
Ok(Either::Right((
104+
WorkflowsWorkerInput {
105+
entry,
106+
executor,
107+
workflow,
108+
model_name,
109+
task_id: task.task_id,
110+
public_key: task_public_key,
111+
stats,
112+
},
113+
batchable,
114+
)))
109115
}
110116

111117
pub(crate) async fn handle_publish(
@@ -123,16 +129,15 @@ impl WorkflowHandler {
123129
task.model_name,
124130
task.stats.record_published_at(),
125131
)?;
132+
133+
// convert payload to message
126134
let payload_str = serde_json::to_string(&payload)
127135
.wrap_err("could not serialize response payload")?;
128-
129-
// prepare signed message
130136
log::debug!(
131137
"Publishing result for task {}\n{}",
132138
task.task_id,
133139
payload_str
134140
);
135-
136141
DKNMessage::new(payload_str, Self::RESPONSE_TOPIC)
137142
}
138143
Err(err) => {

compute/src/main.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ async fn main() -> Result<()> {
3131
let token = CancellationToken::new();
3232
let cancellation_token = token.clone();
3333
tokio::spawn(async move {
34+
// the timeout is done for profiling only, and should not be used in production
3435
if let Ok(Ok(duration_secs)) =
3536
env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::<u64>())
3637
{
@@ -75,15 +76,17 @@ async fn main() -> Result<()> {
7576
}
7677

7778
let node_token = token.clone();
78-
let (mut node, p2p, mut workflows) = DriaComputeNode::new(config, node_token).await?;
79+
let (mut node, p2p, mut worker_batch, mut worker_single) =
80+
DriaComputeNode::new(config, node_token).await?;
7981

80-
// launch the p2p in a separate thread
8182
log::info!("Spawning peer-to-peer client thread.");
8283
let p2p_handle = tokio::spawn(async move { p2p.run().await });
8384

84-
// launch the workflows in a separate thread
85-
log::info!("Spawning workflows worker thread.");
86-
let workflows_handle = tokio::spawn(async move { workflows.run().await });
85+
log::info!("Spawning workflows batch worker thread.");
86+
let worker_batch_handle = tokio::spawn(async move { worker_batch.run_batch().await });
87+
88+
log::info!("Spawning workflows single worker thread.");
89+
let worker_single_handle = tokio::spawn(async move { worker_single.run().await });
8790

8891
// launch the node in a separate thread
8992
log::info!("Spawning compute node thread.");
@@ -98,8 +101,11 @@ async fn main() -> Result<()> {
98101
if let Err(err) = node_handle.await {
99102
log::error!("Node handle error: {}", err);
100103
};
101-
if let Err(err) = workflows_handle.await {
102-
log::error!("Workflows handle error: {}", err);
104+
if let Err(err) = worker_single_handle.await {
105+
log::error!("Workflows single worker handle error: {}", err);
106+
};
107+
if let Err(err) = worker_batch_handle.await {
108+
log::error!("Workflows batch worker handle error: {}", err);
103109
};
104110
if let Err(err) = p2p_handle.await {
105111
log::error!("P2P handle error: {}", err);

compute/src/node.rs

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ pub struct DriaComputeNode {
2626
pub p2p: DriaP2PCommander,
2727
pub available_nodes: AvailableNodes,
2828
pub cancellation: CancellationToken,
29-
// channels
29+
/// Gossipsub message receiver.
3030
message_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
31-
worklow_tx: mpsc::Sender<WorkflowsWorkerInput>,
32-
publish_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
31+
/// Workflow transmitter to send batchable tasks.
32+
workflow_batch_tx: mpsc::Sender<WorkflowsWorkerInput>,
33+
/// Publish receiver to receive messages to be published.
34+
publish_batch_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
35+
/// Workflow transmitter to send single tasks.
36+
workflow_single_tx: mpsc::Sender<WorkflowsWorkerInput>,
37+
/// Publish receiver to receive messages to be published.
38+
publish_single_rx: mpsc::Receiver<WorkflowsWorkerOutput>,
3339
}
3440

3541
impl DriaComputeNode {
@@ -39,7 +45,12 @@ impl DriaComputeNode {
3945
pub async fn new(
4046
config: DriaComputeNodeConfig,
4147
cancellation: CancellationToken,
42-
) -> Result<(DriaComputeNode, DriaP2PClient, WorkflowsWorker)> {
48+
) -> Result<(
49+
DriaComputeNode,
50+
DriaP2PClient,
51+
WorkflowsWorker,
52+
WorkflowsWorker,
53+
)> {
4354
// create the keypair from secret key
4455
let keypair = secret_to_keypair(&config.secret_key);
4556

@@ -66,10 +77,10 @@ impl DriaComputeNode {
6677
protocol,
6778
)?;
6879

69-
// create workflow worker
70-
let (worklow_tx, workflow_rx) = mpsc::channel(256);
71-
let (publish_tx, publish_rx) = mpsc::channel(256);
72-
let workflows_worker = WorkflowsWorker::new(workflow_rx, publish_tx);
80+
// create workflow workers
81+
let (workflows_batch_worker, workflow_batch_tx, publish_batch_rx) = WorkflowsWorker::new();
82+
let (workflows_single_worker, workflow_single_tx, publish_single_rx) =
83+
WorkflowsWorker::new();
7384

7485
Ok((
7586
DriaComputeNode {
@@ -78,11 +89,14 @@ impl DriaComputeNode {
7889
cancellation,
7990
available_nodes,
8091
message_rx,
81-
worklow_tx,
82-
publish_rx,
92+
workflow_batch_tx,
93+
publish_batch_rx,
94+
workflow_single_tx,
95+
publish_single_rx,
8396
},
8497
p2p_client,
85-
workflows_worker,
98+
workflows_batch_worker,
99+
workflows_single_worker,
86100
))
87101
}
88102

@@ -164,7 +178,7 @@ impl DriaComputeNode {
164178
return MessageAcceptance::Ignore;
165179
}
166180

167-
// first, parse the raw gossipsub message to a prepared message
181+
// parse the raw gossipsub message to a prepared DKN message
168182
let message = match DKNMessage::try_from_gossipsub_message(
169183
&message,
170184
&self.config.admin_public_key,
@@ -177,19 +191,25 @@ impl DriaComputeNode {
177191
}
178192
};
179193

180-
// then handle the prepared message
194+
// handle the DKN message with respect to the topic
181195
let handler_result = match message.topic.as_str() {
182196
WorkflowHandler::LISTEN_TOPIC => {
183197
match WorkflowHandler::handle_compute(self, &message).await {
198+
// we got acceptance, so something was not right about the workflow and we can ignore it
184199
Ok(Either::Left(acceptance)) => Ok(acceptance),
185-
Ok(Either::Right(workflow_message)) => {
186-
if let Err(e) = self.worklow_tx.send(workflow_message).await {
200+
// we got the parsed workflow itself, send it to the batch or single worker thread depending on `batchable`
201+
Ok(Either::Right((workflow_message, batchable))) => {
202+
if let Err(e) = match batchable {
203+
true => self.workflow_batch_tx.send(workflow_message).await,
204+
false => self.workflow_single_tx.send(workflow_message).await,
205+
} {
187206
log::error!("Error sending workflow message: {:?}", e);
188207
};
189208

190209
// accept the message in case others may be included in the filter as well
191210
Ok(MessageAcceptance::Accept)
192211
}
212+
// something went wrong, handle this outside
193213
Err(err) => Err(err),
194214
}
195215
}
@@ -241,7 +261,16 @@ impl DriaComputeNode {
241261
_ = tokio::time::sleep(available_node_refresh_duration) => self.handle_available_nodes_refresh().await,
242262
// a Workflow message to be published is received from the channel
243263
// this is expected to be sent by the workflow worker
244-
publish_msg = self.publish_rx.recv() => {
264+
publish_msg = self.publish_batch_rx.recv() => {
265+
if let Some(result) = publish_msg {
266+
WorkflowHandler::handle_publish(self, result).await?;
267+
} else {
268+
log::error!("Publish channel closed unexpectedly.");
269+
break;
270+
};
271+
},
272+
// TODO: handle both receivers together somehow
273+
publish_msg = self.publish_single_rx.recv() => {
245274
if let Some(result) = publish_msg {
246275
WorkflowHandler::handle_publish(self, result).await?;
247276
} else {
@@ -293,7 +322,7 @@ impl DriaComputeNode {
293322
self.message_rx.close();
294323

295324
log::debug!("Closing publish channel.");
296-
self.publish_rx.close();
325+
self.publish_batch_rx.close();
297326

298327
Ok(())
299328
}
@@ -339,7 +368,7 @@ mod tests {
339368

340369
// create node
341370
let cancellation = CancellationToken::new();
342-
let (mut node, p2p, _) =
371+
let (mut node, p2p, _, _) =
343372
DriaComputeNode::new(DriaComputeNodeConfig::default(), cancellation.clone())
344373
.await
345374
.expect("should create node");

compute/src/utils/message.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ impl DKNMessage {
133133

134134
// check dria signature
135135
// NOTE: when we have many public keys, we should check the signature against all of them
136-
if !message.is_signed(&public_key)? {
136+
if !message.is_signed(public_key)? {
137137
return Err(eyre!("Invalid signature."));
138138
}
139139

0 commit comments

Comments
 (0)