Skip to content

Commit 7383c05

Browse files
committed
node refactors for parallelization
1 parent 586096e commit 7383c05

File tree

8 files changed

+190
-206
lines changed

8 files changed

+190
-206
lines changed

compute/src/handlers/mod.rs

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,5 @@
1-
use crate::{utils::DKNMessage, DriaComputeNode};
2-
use async_trait::async_trait;
3-
use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
4-
use eyre::Result;
5-
61
mod pingpong;
72
pub use pingpong::PingpongHandler;
83

94
mod workflow;
105
pub use workflow::WorkflowHandler;
11-
12-
/// A DKN task is to be handled by the compute node, respecting this trait.
13-
///
14-
/// It is expected for the implemented handler to handle messages coming from `LISTEN_TOPIC`,
15-
/// and then respond back to the `RESPONSE_TOPIC`.
16-
#[async_trait]
17-
pub trait ComputeHandler {
18-
/// Gossipsub topic name to listen for incoming messages from the network.
19-
const LISTEN_TOPIC: &'static str;
20-
/// Gossipsub topic name to respond with messages to the network.
21-
const RESPONSE_TOPIC: &'static str;
22-
23-
/// A generic handler for DKN tasks.
24-
///
25-
/// Returns a `MessageAcceptance` value that tells the P2P client to accept the incoming message.
26-
///
27-
/// The handler has mutable reference to the compute node, and therefore can respond within the handler itself in any way it would like.
28-
async fn handle_compute(
29-
node: &mut DriaComputeNode,
30-
message: DKNMessage,
31-
) -> Result<MessageAcceptance>;
32-
}

compute/src/handlers/pingpong.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
use super::ComputeHandler;
21
use crate::{
32
utils::{get_current_time_nanos, DKNMessage},
43
DriaComputeNode,
54
};
6-
use async_trait::async_trait;
75
use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
86
use dkn_workflows::{Model, ModelProvider};
97
use eyre::{Context, Result};
@@ -24,12 +22,20 @@ struct PingpongResponse {
2422
pub(crate) timestamp: u128,
2523
}
2624

27-
#[async_trait]
28-
impl ComputeHandler for PingpongHandler {
29-
const LISTEN_TOPIC: &'static str = "ping";
30-
const RESPONSE_TOPIC: &'static str = "pong";
25+
impl PingpongHandler {
26+
pub(crate) const LISTEN_TOPIC: &'static str = "ping";
27+
pub(crate) const RESPONSE_TOPIC: &'static str = "pong";
3128

32-
async fn handle_compute(
29+
/// Handles the ping message and responds with a pong message.
30+
///
31+
/// 1. Parses the payload of the incoming message into a `PingpongPayload`.
32+
/// 2. Checks if the current time is past the deadline specified in the ping request.
33+
/// 3. If the current time is past the deadline, logs a debug message and ignores the ping request.
34+
/// 4. If the current time is within the deadline, constructs a `PingpongResponse` with the UUID from the ping request, the models from the node's configuration, and the current timestamp.
35+
/// 5. Creates a new signed `DKNMessage` with the response body and the `RESPONSE_TOPIC`.
36+
/// 6. Publishes the response message.
37+
/// 7. Returns `MessageAcceptance::Accept` so that ping is propagated to others as well.
38+
pub(crate) async fn handle_ping(
3339
node: &mut DriaComputeNode,
3440
message: DKNMessage,
3541
) -> Result<MessageAcceptance> {

compute/src/handlers/workflow.rs

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
1-
use std::time::Instant;
2-
3-
use async_trait::async_trait;
41
use dkn_p2p::libp2p::gossipsub::MessageAcceptance;
52
use dkn_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
63
use eyre::{eyre, Context, Result};
74
use libsecp256k1::PublicKey;
85
use serde::Deserialize;
6+
use std::time::Instant;
97

108
use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats};
119
use crate::utils::{get_current_time_nanos, DKNMessage};
1210
use crate::DriaComputeNode;
1311

14-
use super::ComputeHandler;
15-
1612
pub struct WorkflowHandler;
1713

1814
#[derive(Debug, Deserialize)]
1915
struct WorkflowPayload {
20-
/// [Workflow](https://github.com/andthattoo/ollama-workflows/) object to be parsed.
16+
/// [Workflow](https://github.com/andthattoo/ollama-workflows/blob/main/src/program/workflow.rs) object to be parsed.
2117
pub(crate) workflow: Workflow,
2218
/// A lıst of model (that can be parsed into `Model`) or model provider names.
2319
/// If model provider is given, the first matching model in the node config is used for that.
@@ -28,12 +24,11 @@ struct WorkflowPayload {
2824
pub(crate) prompt: Option<String>,
2925
}
3026

31-
#[async_trait]
32-
impl ComputeHandler for WorkflowHandler {
33-
const LISTEN_TOPIC: &'static str = "task";
34-
const RESPONSE_TOPIC: &'static str = "results";
27+
impl WorkflowHandler {
28+
pub(crate) const LISTEN_TOPIC: &'static str = "task";
29+
pub(crate) const RESPONSE_TOPIC: &'static str = "results";
3530

36-
async fn handle_compute(
31+
pub(crate) async fn handle_compute(
3732
node: &mut DriaComputeNode,
3833
message: DKNMessage,
3934
) -> Result<MessageAcceptance> {
@@ -85,26 +80,29 @@ impl ComputeHandler for WorkflowHandler {
8580
} else {
8681
Executor::new(model)
8782
};
88-
let mut memory = ProgramMemory::new();
8983
let entry: Option<Entry> = task
9084
.input
9185
.prompt
9286
.map(|prompt| Entry::try_value_or_str(&prompt));
9387

9488
// execute workflow with cancellation
95-
let exec_result: Result<String>;
89+
let mut memory = ProgramMemory::new();
90+
9691
let exec_started_at = Instant::now();
97-
tokio::select! {
98-
_ = node.cancellation.cancelled() => {
99-
log::info!("Received cancellation, quitting all tasks.");
100-
return Ok(MessageAcceptance::Accept);
101-
},
102-
exec_result_inner = executor.execute(entry.as_ref(), &task.input.workflow, &mut memory) => {
103-
exec_result = exec_result_inner.map_err(|e| eyre!("Execution error: {}", e.to_string()));
104-
}
105-
}
92+
let exec_result = executor
93+
.execute(entry.as_ref(), &task.input.workflow, &mut memory)
94+
.await
95+
.map_err(|e| eyre!("Execution error: {}", e.to_string()));
10696
task_stats = task_stats.record_execution_time(exec_started_at);
10797

98+
Ok(MessageAcceptance::Accept)
99+
}
100+
101+
async fn handle_publish(
102+
node: &mut DriaComputeNode,
103+
result: String,
104+
task_id: String,
105+
) -> Result<()> {
108106
let (message, acceptance) = match exec_result {
109107
Ok(result) => {
110108
// obtain public key from the payload
@@ -115,7 +113,7 @@ impl ComputeHandler for WorkflowHandler {
115113
// prepare signed and encrypted payload
116114
let payload = TaskResponsePayload::new(
117115
result,
118-
&task.task_id,
116+
&task_id,
119117
&task_public_key,
120118
&node.config.secret_key,
121119
model_name,
@@ -125,23 +123,19 @@ impl ComputeHandler for WorkflowHandler {
125123
.wrap_err("Could not serialize response payload")?;
126124

127125
// prepare signed message
128-
log::debug!(
129-
"Publishing result for task {}\n{}",
130-
task.task_id,
131-
payload_str
132-
);
126+
log::debug!("Publishing result for task {}\n{}", task_id, payload_str);
133127
let message = DKNMessage::new(payload_str, Self::RESPONSE_TOPIC);
134128
// accept so that if there are others included in filter they can do the task
135129
(message, MessageAcceptance::Accept)
136130
}
137131
Err(err) => {
138132
// use pretty display string for error logging with causes
139133
let err_string = format!("{:#}", err);
140-
log::error!("Task {} failed: {}", task.task_id, err_string);
134+
log::error!("Task {} failed: {}", task_id, err_string);
141135

142136
// prepare error payload
143137
let error_payload = TaskErrorPayload {
144-
task_id: task.task_id.clone(),
138+
task_id,
145139
error: err_string,
146140
model: model_name,
147141
stats: task_stats.record_published_at(),
@@ -166,7 +160,7 @@ impl ComputeHandler for WorkflowHandler {
166160
log::error!("{}", err_msg);
167161

168162
let payload = serde_json::json!({
169-
"taskId": task.task_id,
163+
"taskId": task_id,
170164
"error": err_msg,
171165
});
172166
let message = DKNMessage::new_signed(
@@ -175,8 +169,8 @@ impl ComputeHandler for WorkflowHandler {
175169
&node.config.secret_key,
176170
);
177171
node.publish(message).await?;
178-
}
172+
};
179173

180-
Ok(acceptance)
174+
Ok(())
181175
}
182176
}

compute/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async fn main() -> Result<()> {
8484
// launch the node in a separate thread
8585
log::info!("Spawning compute node thread.");
8686
let node_handle = tokio::spawn(async move {
87-
if let Err(err) = node.launch().await {
87+
if let Err(err) = node.run().await {
8888
log::error!("Node launch error: {}", err);
8989
panic!("Node failed.")
9090
};

0 commit comments

Comments
 (0)