Skip to content

Commit 2dc292e

Browse files
committed
added error reporting
1 parent 5eed72c commit 2dc292e

File tree

12 files changed

+125
-95
lines changed

12 files changed

+125
-95
lines changed

src/handlers/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ use async_trait::async_trait;
22
use eyre::Result;
33
use libp2p::gossipsub::MessageAcceptance;
44

5+
mod topics;
6+
pub use topics::*;
7+
58
mod pingpong;
69
pub use pingpong::PingpongHandler;
710

@@ -10,8 +13,10 @@ pub use workflow::WorkflowHandler;
1013

1114
use crate::{utils::DKNMessage, DriaComputeNode};
1215

16+
/// A DKN task is to be handled by the compute node, respecting this trait.
1317
#[async_trait]
1418
pub trait ComputeHandler {
19+
/// A generic handler for DKN tasks.
1520
async fn handle_compute(
1621
node: &mut DriaComputeNode,
1722
message: DKNMessage,

src/handlers/pingpong.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
2-
node::DriaComputeNode,
32
utils::{get_current_time_nanos, DKNMessage},
3+
DriaComputeNode,
44
};
55
use async_trait::async_trait;
66
use eyre::Result;
@@ -55,14 +55,14 @@ impl ComputeHandler for PingpongHandler {
5555
timestamp: get_current_time_nanos(),
5656
};
5757

58+
// publish message
5859
let message = DKNMessage::new_signed(
5960
serde_json::json!(response_body).to_string(),
6061
result_topic,
6162
&node.config.secret_key,
6263
);
6364
node.publish(message)?;
6465

65-
// accept message, someone else may be included in the filter
6666
Ok(MessageAcceptance::Accept)
6767
}
6868
}

src/handlers/topics.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
pub const PINGPONG_LISTEN_TOPIC: &str = "ping";
2+
pub const PINGPONG_RESPONSE_TOPIC: &str = "pong";
3+
pub const WORKFLOW_LISTEN_TOPIC: &str = "task";
4+
pub const WORKFLOW_RESPONSE_TOPIC: &str = "results";

src/handlers/workflow.rs

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use libp2p::gossipsub::MessageAcceptance;
44
use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
55
use serde::Deserialize;
66

7-
use crate::node::DriaComputeNode;
8-
use crate::utils::payload::{TaskRequestPayload, TaskResponsePayload};
7+
use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload};
98
use crate::utils::{get_current_time_nanos, DKNMessage};
9+
use crate::DriaComputeNode;
1010

1111
use super::ComputeHandler;
1212

@@ -32,6 +32,7 @@ impl ComputeHandler for WorkflowHandler {
3232
message: DKNMessage,
3333
result_topic: &str,
3434
) -> Result<MessageAcceptance> {
35+
let config = &node.config;
3536
let task = message.parse_payload::<TaskRequestPayload<WorkflowPayload>>(true)?;
3637

3738
// check if deadline is past or not
@@ -49,7 +50,7 @@ impl ComputeHandler for WorkflowHandler {
4950
}
5051

5152
// check task inclusion via the bloom filter
52-
if !task.filter.contains(&node.config.address)? {
53+
if !task.filter.contains(&config.address)? {
5354
log::info!(
5455
"Task {} does not include this node within the filter.",
5556
task.task_id
@@ -59,23 +60,15 @@ impl ComputeHandler for WorkflowHandler {
5960
return Ok(MessageAcceptance::Accept);
6061
}
6162

62-
// obtain public key from the payload
63-
let task_public_key = hex::decode(&task.public_key)?;
64-
6563
// read model / provider from the task
66-
let (model_provider, model) = node
67-
.config
64+
let (model_provider, model) = config
6865
.model_config
6966
.get_any_matching_model(task.input.model)?;
7067
log::info!("Using model {} for task {}", model, task.task_id);
7168

7269
// prepare workflow executor
7370
let executor = if model_provider == ModelProvider::Ollama {
74-
Executor::new_at(
75-
model,
76-
&node.config.ollama_config.host,
77-
node.config.ollama_config.port,
78-
)
71+
Executor::new_at(model, &config.ollama_config.host, config.ollama_config.port)
7972
} else {
8073
Executor::new(model)
8174
};
@@ -86,38 +79,52 @@ impl ComputeHandler for WorkflowHandler {
8679
.map(|prompt| Entry::try_value_or_str(&prompt));
8780

8881
// execute workflow with cancellation
89-
// TODO: is there a better way to handle this?
90-
let result: String;
82+
let exec_result: Result<String>;
9183
tokio::select! {
9284
_ = node.cancellation.cancelled() => {
9385
log::info!("Received cancellation, quitting all tasks.");
94-
return Ok(MessageAcceptance::Accept)
86+
return Ok(MessageAcceptance::Accept);
9587
},
96-
exec_result = executor.execute(entry.as_ref(), task.input.workflow, &mut memory) => {
97-
match exec_result {
98-
Ok(exec_result) => {
99-
result = exec_result;
100-
}
101-
Err(e) => {
102-
return Err(eyre!("Workflow failed with error {}", e));
103-
}
104-
}
88+
exec_result_inner = executor.execute(entry.as_ref(), task.input.workflow, &mut memory) => {
89+
exec_result = exec_result_inner.map_err(|e| eyre!("{}", e.to_string()));
10590
}
10691
}
10792

108-
// prepare signed and encrypted payload
109-
let payload = TaskResponsePayload::new(
110-
result,
111-
&task.task_id,
112-
&task_public_key,
113-
&node.config.secret_key,
114-
)?;
115-
let payload_str = payload.to_string()?;
93+
match exec_result {
94+
Ok(result) => {
95+
// obtain public key from the payload
96+
let task_public_key = hex::decode(&task.public_key)?;
97+
98+
// prepare signed and encrypted payload
99+
let payload = TaskResponsePayload::new(
100+
result,
101+
&task.task_id,
102+
&task_public_key,
103+
&config.secret_key,
104+
)?;
105+
let payload_str = serde_json::to_string(&payload)?;
106+
107+
// publish the result
108+
let message = DKNMessage::new(payload_str, result_topic);
109+
node.publish(message)?;
110+
111+
// accept so that if there are others included in filter they can do the task
112+
Ok(MessageAcceptance::Accept)
113+
}
114+
Err(err) => {
115+
log::error!("Task {} failed: {}", task.task_id, err);
116+
117+
// prepare error payload
118+
let error_payload = TaskErrorPayload::new(task.task_id, err.to_string());
119+
let error_payload_str = serde_json::to_string(&error_payload)?;
116120

117-
// publish the result
118-
let message = DKNMessage::new(payload_str, result_topic);
119-
node.publish(message)?;
121+
// publish the error result for diagnostics
122+
let message = DKNMessage::new(error_payload_str, result_topic);
123+
node.publish(message)?;
120124

121-
Ok(MessageAcceptance::Accept)
125+
// ignore just in case, workflow may be bugged
126+
Ok(MessageAcceptance::Ignore)
127+
}
128+
}
122129
}
123130
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub(crate) mod config;
55
pub(crate) mod handlers;
66
pub(crate) mod node;
77
pub(crate) mod p2p;
8+
pub(crate) mod payloads;
89
pub(crate) mod utils;
910

1011
/// Crate version of the compute node.

src/node.rs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use tokio_util::sync::CancellationToken;
55

66
use crate::{
77
config::DriaComputeNodeConfig,
8-
handlers::{ComputeHandler, PingpongHandler, WorkflowHandler},
8+
handlers::*,
99
p2p::P2PClient,
1010
utils::{crypto::secret_to_keypair, AvailableNodes, DKNMessage},
1111
};
@@ -105,11 +105,6 @@ impl DriaComputeNode {
105105
/// Launches the main loop of the compute node.
106106
/// This method is not expected to return until cancellation occurs.
107107
pub async fn launch(&mut self) -> Result<()> {
108-
const PINGPONG_LISTEN_TOPIC: &str = "ping";
109-
const PINGPONG_RESPONSE_TOPIC: &str = "pong";
110-
const WORKFLOW_LISTEN_TOPIC: &str = "task";
111-
const WORKFLOW_RESPONSE_TOPIC: &str = "results";
112-
113108
// subscribe to topics
114109
self.subscribe(PINGPONG_LISTEN_TOPIC)?;
115110
self.subscribe(PINGPONG_RESPONSE_TOPIC)?;
@@ -172,7 +167,7 @@ impl DriaComputeNode {
172167
};
173168

174169
// then handle the prepared message
175-
let handle_result = match topic_str {
170+
let handler_result = match topic_str {
176171
WORKFLOW_LISTEN_TOPIC => {
177172
WorkflowHandler::handle_compute(self, message, WORKFLOW_RESPONSE_TOPIC).await
178173
}
@@ -185,26 +180,23 @@ impl DriaComputeNode {
185180
};
186181

187182
// validate the message based on the result
188-
match handle_result {
183+
match handler_result {
189184
Ok(acceptance) => {
190185
self.p2p.validate_message(&message_id, &peer_id, acceptance)?;
191186
},
192187
Err(err) => {
193188
log::error!("Error handling {} message: {}", topic_str, err);
194-
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Reject)?;
189+
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Ignore)?;
195190
}
196191
}
197192
} else if std::matches!(topic_str, PINGPONG_RESPONSE_TOPIC | WORKFLOW_RESPONSE_TOPIC) {
198193
// since we are responding to these topics, we might receive messages from other compute nodes
199-
// we can gracefully ignore them
194+
// we can gracefully ignore them and propagate it to to others
200195
log::debug!("Ignoring message for topic: {}", topic_str);
201-
202-
// accept this message for propagation
203196
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Accept)?;
204197
} else {
205-
log::warn!("Received message from unexpected topic: {}", topic_str);
206-
207198
// reject this message as its from a foreign topic
199+
log::warn!("Received message from unexpected topic: {}", topic_str);
208200
self.p2p.validate_message(&message_id, &peer_id, gossipsub::MessageAcceptance::Reject)?;
209201
}
210202
},

src/payloads/error.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
/// A generic task request, given by Dria.
4+
#[derive(Debug, Clone, Serialize, Deserialize)]
5+
#[serde(rename_all = "camelCase")]
6+
pub struct TaskErrorPayload {
7+
/// The unique identifier of the task.
8+
pub task_id: String,
9+
/// The stringified error object
10+
pub(crate) error: String,
11+
}
12+
13+
impl TaskErrorPayload {
14+
pub fn new(task_id: String, error: String) -> Self {
15+
Self { task_id, error }
16+
}
17+
}

src/payloads/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
mod error;
2+
pub use error::TaskErrorPayload;
3+
4+
mod request;
5+
pub use request::TaskRequestPayload;
6+
7+
mod response;
8+
pub use response::TaskResponsePayload;

src/payloads/request.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use crate::utils::{filter::FilterPayload, get_current_time_nanos};
2+
use fastbloom_rs::BloomFilter;
3+
use serde::{Deserialize, Serialize};
4+
use uuid::Uuid;
5+
6+
/// A generic task request, given by Dria.
7+
#[derive(Debug, Clone, Serialize, Deserialize)]
8+
#[serde(rename_all = "camelCase")]
9+
pub struct TaskRequestPayload<T> {
10+
/// The unique identifier of the task.
11+
pub task_id: String,
12+
/// The deadline of the task in nanoseconds.
13+
pub(crate) deadline: u128,
14+
/// The input to the compute function.
15+
pub(crate) input: T,
16+
/// The Bloom filter of the task.
17+
pub(crate) filter: FilterPayload,
18+
/// The public key of the requester, in hexadecimals.
19+
pub(crate) public_key: String,
20+
}
21+
22+
impl<T> TaskRequestPayload<T> {
23+
#[allow(unused)]
24+
pub fn new(input: T, filter: BloomFilter, time_ns: u128, public_key: Option<String>) -> Self {
25+
Self {
26+
task_id: Uuid::new_v4().into(),
27+
deadline: get_current_time_nanos() + time_ns,
28+
input,
29+
filter: filter.into(),
30+
public_key: public_key.unwrap_or_default(),
31+
}
32+
}
33+
}
Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
use super::crypto::sha256hash;
2-
use crate::utils::{filter::FilterPayload, get_current_time_nanos};
1+
use crate::utils::crypto::sha256hash;
32
use eyre::Result;
4-
use fastbloom_rs::BloomFilter;
53
use libsecp256k1::SecretKey;
64
use serde::{Deserialize, Serialize};
7-
use uuid::Uuid;
85

96
/// A computation task is the task of computing a result from a given input. The result is encrypted with the public key of the requester.
107
/// Plain result is signed by the compute node's private key, and a commitment is computed from the signature and plain result.
@@ -23,24 +20,20 @@ pub struct TaskResponsePayload {
2320
}
2421

2522
impl TaskResponsePayload {
26-
pub fn to_string(&self) -> Result<String> {
27-
serde_json::to_string(&serde_json::json!(self)).map_err(Into::into)
28-
}
29-
3023
/// Creates the payload of a computation result.
3124
///
3225
/// - Sign `task_id || payload` with node `self.secret_key`
3326
/// - Encrypt `result` with `task_public_key`
3427
pub fn new(
35-
payload: impl AsRef<[u8]>,
28+
result: impl AsRef<[u8]>,
3629
task_id: &str,
3730
encrypting_public_key: &[u8],
3831
signing_secret_key: &SecretKey,
3932
) -> Result<Self> {
4033
// create the message `task_id || payload`
4134
let mut preimage = Vec::new();
4235
preimage.extend_from_slice(task_id.as_ref());
43-
preimage.extend_from_slice(payload.as_ref());
36+
preimage.extend_from_slice(result.as_ref());
4437

4538
// sign the message
4639
// TODO: use `sign_recoverable` here instead?
@@ -50,7 +43,7 @@ impl TaskResponsePayload {
5043
let recid: [u8; 1] = [recid.serialize()];
5144

5245
// encrypt payload itself
53-
let ciphertext = ecies::encrypt(encrypting_public_key, payload.as_ref())?;
46+
let ciphertext = ecies::encrypt(encrypting_public_key, result.as_ref())?;
5447

5548
Ok(TaskResponsePayload {
5649
ciphertext: hex::encode(ciphertext),
@@ -59,32 +52,3 @@ impl TaskResponsePayload {
5952
})
6053
}
6154
}
62-
63-
/// A generic task request, given by Dria.
64-
#[derive(Debug, Clone, Serialize, Deserialize)]
65-
#[serde(rename_all = "camelCase")]
66-
pub struct TaskRequestPayload<T> {
67-
/// The unique identifier of the task.
68-
pub task_id: String,
69-
/// The deadline of the task in nanoseconds.
70-
pub(crate) deadline: u128,
71-
/// The input to the compute function.
72-
pub(crate) input: T,
73-
/// The Bloom filter of the task.
74-
pub(crate) filter: FilterPayload,
75-
/// The public key of the requester.
76-
pub(crate) public_key: String,
77-
}
78-
79-
impl<T> TaskRequestPayload<T> {
80-
#[allow(unused)]
81-
pub fn new(input: T, filter: BloomFilter, time_ns: u128, public_key: Option<String>) -> Self {
82-
Self {
83-
task_id: Uuid::new_v4().into(),
84-
deadline: get_current_time_nanos() + time_ns,
85-
input,
86-
filter: filter.into(),
87-
public_key: public_key.unwrap_or_default(),
88-
}
89-
}
90-
}

0 commit comments

Comments
 (0)