Skip to content

Commit 832df64

Browse files
committed
fix piggybacking issues
1 parent 0d5b69c commit 832df64

File tree

5 files changed

+46
-46
lines changed

5 files changed

+46
-46
lines changed

compute/src/node/core.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ impl DriaComputeNode {
3232
publish_msg_opt = self.publish_rx.recv() => {
3333
if let Some(publish_msg) = publish_msg_opt {
3434
// remove the task from pending tasks based on its batchability
35-
let channel = match publish_msg.batchable {
35+
let task_metadata = match publish_msg.batchable {
3636
true => {
3737
self.completed_tasks_batch += 1;
3838
self.pending_tasks_batch.remove(&publish_msg.task_id)
@@ -44,7 +44,7 @@ impl DriaComputeNode {
4444
};
4545

4646
// respond to the request
47-
match channel {
47+
match task_metadata {
4848
Some(channel) => {
4949
TaskResponder::handle_respond(self, publish_msg, channel).await?;
5050
}

compute/src/node/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::{
1414
config::*,
1515
gossipsub::*,
1616
utils::{crypto::secret_to_keypair, refresh_dria_nodes, SpecCollector},
17-
workers::task::{TaskWorker, TaskWorkerInput, TaskWorkerOutput},
17+
workers::task::{TaskWorker, TaskWorkerInput, TaskWorkerMetadata, TaskWorkerOutput},
1818
};
1919

2020
mod core;
@@ -45,9 +45,9 @@ pub struct DriaComputeNode {
4545
/// Task worker transmitter to send single tasks.
4646
task_single_tx: Option<mpsc::Sender<TaskWorkerInput>>,
4747
// Single tasks
48-
pending_tasks_single: HashMap<String, ResponseChannel<Vec<u8>>>,
48+
pending_tasks_single: HashMap<String, TaskWorkerMetadata>,
4949
// Batchable tasks
50-
pending_tasks_batch: HashMap<String, ResponseChannel<Vec<u8>>>,
50+
pending_tasks_batch: HashMap<String, TaskWorkerMetadata>,
5151
/// Completed single tasks count
5252
completed_tasks_single: usize,
5353
/// Completed batch tasks count

compute/src/node/reqres.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,16 @@ impl DriaComputeNode {
2727
} else if let Ok(task_request) = TaskResponder::try_parse_request(&data) {
2828
log::info!("Received a task request from {}", peer_id);
2929

30-
let task_tx_message = TaskResponder::handle_compute(self, &task_request).await?;
31-
if let Err(e) = match task_tx_message.batchable {
30+
let (task_input, task_metadata) =
31+
TaskResponder::prepare_worker_input(self, &task_request, channel).await?;
32+
if let Err(e) = match task_input.batchable {
3233
// this is a batchable task, send it to batch worker
3334
// and keep track of the task id in pending tasks
3435
true => match self.task_batch_tx {
3536
Some(ref mut tx) => {
3637
self.pending_tasks_batch
37-
.insert(task_tx_message.task_id.clone(), channel);
38-
tx.send(task_tx_message).await
38+
.insert(task_input.task_id.clone(), task_metadata);
39+
tx.send(task_input).await
3940
}
4041
None => {
4142
return Err(eyre!(
@@ -49,8 +50,8 @@ impl DriaComputeNode {
4950
false => match self.task_single_tx {
5051
Some(ref mut tx) => {
5152
self.pending_tasks_single
52-
.insert(task_tx_message.task_id.clone(), channel);
53-
tx.send(task_tx_message).await
53+
.insert(task_input.task_id.clone(), task_metadata);
54+
tx.send(task_input).await
5455
}
5556
None => {
5657
return Err(eyre!("Single workflow received but no worker available."));

compute/src/reqres/task.rs

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,11 @@ pub struct TaskPayload {
3636

3737
impl TaskResponder {
3838
/// Handles the compute message for workflows.
39-
///
40-
/// - FIXME: DOES NOT CHECK FOR FILTER AS IT IS NO LONGER USED
41-
/// - FIXME: GIVES ERROR ON DEADLINE PAST CASE, BUT WE DONT NEED DEADLINE AS WELL
42-
pub(crate) async fn handle_compute(
39+
pub(crate) async fn prepare_worker_input(
4340
node: &mut DriaComputeNode,
4441
compute_message: &DriaMessage,
45-
) -> Result<TaskWorkerInput> {
42+
channel: ResponseChannel<Vec<u8>>,
43+
) -> Result<(TaskWorkerInput, TaskWorkerMetadata)> {
4644
// parse payload
4745
let task = compute_message
4846
.parse_payload::<TaskRequestPayload<TaskPayload>>()
@@ -52,7 +50,7 @@ impl TaskResponder {
5250
let stats = TaskStats::new().record_received_at();
5351

5452
// check if deadline is past or not
55-
// with request-response, we dont expect this to happen much
53+
// FIXME: with request-response, we dont expect this to happen much
5654
if get_current_time_nanos() >= task.deadline {
5755
return Err(eyre!(
5856
"Task {} is past the deadline, ignoring",
@@ -97,34 +95,40 @@ impl TaskResponder {
9795
// get workflow as well
9896
let workflow = task.input.workflow;
9997

100-
Ok(TaskWorkerInput {
98+
let task_input = TaskWorkerInput {
10199
entry,
102100
executor,
103101
workflow,
104-
model_name,
105102
task_id: task.task_id,
106-
public_key: task_public_key,
107103
stats,
108104
batchable,
109-
})
105+
};
106+
107+
let task_metadata = TaskWorkerMetadata {
108+
model_name,
109+
public_key: task_public_key,
110+
channel,
111+
};
112+
113+
Ok((task_input, task_metadata))
110114
}
111115

112116
/// Handles the result of a workflow task.
113117
pub(crate) async fn handle_respond(
114118
node: &mut DriaComputeNode,
115-
task: TaskWorkerOutput,
116-
channel: ResponseChannel<Vec<u8>>,
119+
task_output: TaskWorkerOutput,
120+
task_metadata: TaskWorkerMetadata,
117121
) -> Result<()> {
118-
let response = match task.result {
122+
let response = match task_output.result {
119123
Ok(result) => {
120124
// prepare signed and encrypted payload
121-
log::info!("Publishing result for task {}", task.task_id);
125+
log::info!("Publishing result for task {}", task_output.task_id);
122126
let payload = TaskResponsePayload::new(
123127
result,
124-
&task.task_id,
125-
&task.public_key,
126-
task.model_name,
127-
task.stats.record_published_at(),
128+
&task_output.task_id,
129+
&task_metadata.public_key,
130+
task_metadata.model_name,
131+
task_output.stats.record_published_at(),
128132
)?;
129133

130134
// convert payload to message
@@ -135,14 +139,14 @@ impl TaskResponder {
135139
Err(err) => {
136140
// use pretty display string for error logging with causes
137141
let err_string = format!("{:#}", err);
138-
log::error!("Task {} failed: {}", task.task_id, err_string);
142+
log::error!("Task {} failed: {}", task_output.task_id, err_string);
139143

140144
// prepare error payload
141145
let error_payload = TaskErrorPayload {
142-
task_id: task.task_id,
146+
task_id: task_output.task_id,
143147
error: err_string,
144-
model: task.model_name,
145-
stats: task.stats.record_published_at(),
148+
model: task_metadata.model_name,
149+
stats: task_output.stats.record_published_at(),
146150
};
147151
let error_payload_str = serde_json::json!(error_payload).to_string();
148152

@@ -152,7 +156,7 @@ impl TaskResponder {
152156

153157
// respond through the channel
154158
let data = response.to_bytes()?;
155-
node.p2p.respond(data, channel).await?;
159+
node.p2p.respond(data, task_metadata.channel).await?;
156160

157161
Ok(())
158162
}

compute/src/workers/task.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
1+
use dkn_p2p::libp2p::request_response::ResponseChannel;
12
use dkn_workflows::{Entry, ExecutionError, Executor, Workflow};
23
use libsecp256k1::PublicKey;
34
use tokio::sync::mpsc;
45

56
use crate::payloads::TaskStats;
67

7-
// TODO: instead of piggybacking stuff here, maybe node can hold it in a hashmap w.r.t taskId
8+
pub struct TaskWorkerMetadata {
9+
pub public_key: PublicKey,
10+
pub model_name: String,
11+
pub channel: ResponseChannel<Vec<u8>>,
12+
}
813

914
pub struct TaskWorkerInput {
1015
pub entry: Option<Entry>,
1116
pub executor: Executor,
1217
pub workflow: Workflow,
1318
pub task_id: String,
14-
// piggybacked
15-
pub public_key: PublicKey,
16-
pub model_name: String,
1719
pub stats: TaskStats,
1820
pub batchable: bool,
1921
}
2022

2123
pub struct TaskWorkerOutput {
2224
pub result: Result<String, ExecutionError>,
2325
pub task_id: String,
24-
// piggybacked
25-
pub public_key: PublicKey,
26-
pub model_name: String,
2726
pub stats: TaskStats,
2827
pub batchable: bool,
2928
}
@@ -129,6 +128,7 @@ impl TaskWorker {
129128
"number of tasks cant be larger than batch size"
130129
);
131130
debug_assert!(num_tasks != 0, "number of tasks cant be zero");
131+
132132
log::info!("Processing {} tasks in batch", num_tasks);
133133
let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx));
134134
match num_tasks {
@@ -226,9 +226,7 @@ impl TaskWorker {
226226

227227
let output = TaskWorkerOutput {
228228
result,
229-
public_key: input.public_key,
230229
task_id: input.task_id,
231-
model_name: input.model_name,
232230
batchable: input.batchable,
233231
stats: input.stats,
234232
};
@@ -242,7 +240,6 @@ impl TaskWorker {
242240
#[cfg(test)]
243241
mod tests {
244242
use dkn_workflows::{Executor, Model};
245-
use libsecp256k1::{PublicKey, SecretKey};
246243

247244
use super::*;
248245
use crate::payloads::TaskStats;
@@ -311,9 +308,7 @@ mod tests {
311308
entry: None,
312309
executor,
313310
workflow,
314-
public_key: PublicKey::from_secret_key(&SecretKey::default()),
315311
task_id: format!("task-{}", i + 1),
316-
model_name: model.to_string(),
317312
stats: TaskStats::default(),
318313
batchable: true,
319314
};

0 commit comments

Comments
 (0)