Skip to content

Commit a4d779e

Browse files
committed
slight log and naming refactors, update executor impls
1 parent 77eb012 commit a4d779e

File tree

21 files changed

+143
-149
lines changed

21 files changed

+143
-149
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default-members = ["compute"]
77

88
[workspace.package]
99
edition = "2021"
10-
version = "0.5.5"
10+
version = "0.5.6"
1111
license = "Apache-2.0"
1212
readme = "README.md"
1313

compute/src/main.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,8 @@ async fn main() -> Result<()> {
6969
});
7070

7171
// create configurations
72-
let model_names = dkn_utils::split_csv_line(&env::var("DKN_MODELS").unwrap_or_default())
73-
.into_iter()
74-
.filter_map(|s| Model::try_from(s.as_str()).ok())
75-
.collect::<Vec<_>>();
76-
let executors_config = DriaExecutorsManager::new_from_env_for_models(model_names)?;
72+
let models = Model::from_csv(env::var("DKN_MODELS").unwrap_or_default());
73+
let executors_config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
7774
if executors_config.models.is_empty() {
7875
return Err(eyre::eyre!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."));
7976
}

compute/src/node/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl DriaComputeNode {
100100
protocol,
101101
)?;
102102

103-
// create workflow workers, all workers use the same publish channel
103+
// create channel for task executors, all workers use the same publish channel
104104
let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_BUFSIZE);
105105

106106
// check if we should create a worker for batch executor

compute/src/node/reqres.rs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ use dkn_p2p::libp2p::{
44
PeerId,
55
};
66
use dkn_p2p::DriaReqResMessage;
7-
use dkn_utils::payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC, TASK_REQUEST_TOPIC};
8-
use eyre::{eyre, Result};
7+
use dkn_utils::{
8+
payloads::{HEARTBEAT_TOPIC, SPECS_TOPIC, TASK_REQUEST_TOPIC},
9+
DriaMessage,
10+
};
11+
use eyre::Result;
912

1013
use crate::{reqres::*, workers::task::TaskWorkerOutput};
1114

@@ -33,7 +36,7 @@ impl DriaComputeNode {
3336
if self.dria_rpc.peer_id != peer_id {
3437
log::warn!("Received request from unauthorized source: {}", peer_id);
3538
log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
36-
} else if let Err(e) = self.handle_request(peer_id, request, channel).await {
39+
} else if let Err(e) = self.handle_request(peer_id, &request, channel).await {
3740
log::error!("Error handling request: {:?}", e);
3841
}
3942
}
@@ -61,6 +64,11 @@ impl DriaComputeNode {
6164
request_id: OutboundRequestId,
6265
data: Vec<u8>,
6366
) -> Result<()> {
67+
if peer_id != self.dria_rpc.peer_id {
68+
log::warn!("Received response from unauthorized source: {}", peer_id);
69+
log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
70+
}
71+
6472
if let Ok(heartbeat_response) = HeartbeatRequester::try_parse_response(&data) {
6573
log::info!(
6674
"Received a {} response ({request_id}) from {peer_id}",
@@ -85,14 +93,18 @@ impl DriaComputeNode {
8593
async fn handle_request(
8694
&mut self,
8795
peer_id: PeerId,
88-
data: Vec<u8>,
96+
message_data: &[u8],
8997
channel: ResponseChannel<Vec<u8>>,
9098
) -> Result<()> {
91-
if let Ok(task_request) = TaskResponder::try_parse_request(&data) {
92-
self.handle_task_request(peer_id, task_request, channel)
93-
.await
94-
} else {
95-
Err(eyre::eyre!("Received unhandled request from {peer_id}"))
99+
let message = DriaMessage::from_slice_checked(
100+
message_data,
101+
self.p2p.protocol().name.clone(),
102+
self.config.version,
103+
)?;
104+
105+
match message.topic.as_str() {
106+
TASK_REQUEST_TOPIC => self.handle_task_request(peer_id, message, channel).await,
107+
_ => Err(eyre::eyre!("Received unhandled request from {peer_id}")),
96108
}
97109
}
98110

@@ -113,7 +125,7 @@ impl DriaComputeNode {
113125
);
114126

115127
let (task_input, task_metadata) =
116-
TaskResponder::prepare_worker_input(self, &task_request, channel).await?;
128+
TaskResponder::parse_task_request(self, &task_request, channel).await?;
117129
if let Err(e) = match task_input.task.is_batchable() {
118130
// this is a batchable task, send it to batch worker
119131
// and keep track of the task id in pending tasks
@@ -123,11 +135,7 @@ impl DriaComputeNode {
123135
.insert(task_input.row_id, task_metadata);
124136
tx.send(task_input).await
125137
}
126-
None => {
127-
return Err(eyre!(
128-
"Batchable workflow received but no worker available."
129-
));
130-
}
138+
None => eyre::bail!("Batchable task received but no worker available."),
131139
},
132140

133141
// this is a single task, send it to single worker
@@ -138,12 +146,10 @@ impl DriaComputeNode {
138146
.insert(task_input.row_id, task_metadata);
139147
tx.send(task_input).await
140148
}
141-
None => {
142-
return Err(eyre!("Single workflow received but no worker available."));
143-
}
149+
None => eyre::bail!("Single task received but no worker available."),
144150
},
145151
} {
146-
log::error!("Error sending workflow message: {:?}", e);
152+
log::error!("Could not send task to worker: {:?}", e);
147153
};
148154

149155
Ok(())
@@ -165,7 +171,7 @@ impl DriaComputeNode {
165171
// respond to the response channel with the result
166172
match task_metadata {
167173
Some(task_metadata) => {
168-
TaskResponder::send_output(self, task_response, task_metadata).await?;
174+
TaskResponder::send_task_output(self, task_response, task_metadata).await?;
169175
}
170176
None => {
171177
// totally unexpected case, won't happen at all

compute/src/reqres/task.rs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,60 +11,54 @@ use crate::DriaComputeNode;
1111
pub struct TaskResponder;
1212

1313
impl super::IsResponder for TaskResponder {
14-
type Request = DriaMessage; // TODO: TaskRequestPayload<TaskWorkflow>;
15-
type Response = DriaMessage; // TODO: TaskResponsePayload;
14+
type Request = DriaMessage; // TODO: can we do this typed?
15+
type Response = DriaMessage; // TODO: can we do this typed?
1616
}
1717

1818
impl TaskResponder {
19-
pub(crate) async fn prepare_worker_input(
19+
pub(crate) async fn parse_task_request(
2020
node: &mut DriaComputeNode,
2121
compute_message: &DriaMessage,
2222
channel: ResponseChannel<Vec<u8>>,
2323
) -> Result<(TaskWorkerInput, TaskWorkerMetadata)> {
24-
// parse payload
2524
let task = compute_message
2625
.parse_payload::<TaskRequestPayload<TaskBody>>()
27-
.wrap_err("could not parse workflow task")?;
28-
log::info!("Handling task {}", task.row_id);
29-
30-
// record received time
26+
.wrap_err("could not parse task payload")?;
3127
let stats = TaskStats::new().record_received_at();
32-
3328
log::info!(
34-
"Using model {} for {} {}",
35-
task.input.model.to_string().yellow(),
29+
"Handling {} {} with model {}",
3630
"task".yellow(),
37-
task.row_id
31+
task.row_id,
32+
task.input.model.to_string().yellow()
3833
);
39-
let task_body = task.input;
4034

4135
// check if the model is available in this node, if so
4236
// it will return an executor that can run this model
4337
let executor = node
4438
.config
4539
.executors
46-
.get_executor(&task_body.model)
40+
.get_executor(&task.input.model)
4741
.await
4842
.wrap_err("could not get an executor")?;
4943

5044
let task_metadata = TaskWorkerMetadata {
5145
task_id: task.task_id,
5246
file_id: task.file_id,
53-
model_name: task_body.model.to_string(),
47+
model_name: task.input.model.to_string(),
5448
channel,
5549
};
5650
let task_input = TaskWorkerInput {
5751
executor,
58-
task: task_body,
52+
task: task.input,
5953
row_id: task.row_id,
6054
stats,
6155
};
6256

6357
Ok((task_input, task_metadata))
6458
}
6559

66-
/// Handles the result of a workflow task.
67-
pub(crate) async fn send_output(
60+
/// Handles the result of a task.
61+
pub(crate) async fn send_task_output(
6862
node: &mut DriaComputeNode,
6963
task_output: TaskWorkerOutput,
7064
task_metadata: TaskWorkerMetadata,

compute/src/workers/task.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ pub struct TaskWorkerOutput {
4040

4141
/// It is expected to be spawned in another thread, with [`Self::run_batch`] for batch processing and [`Self::run_series`] for single processing.
4242
pub struct TaskWorker {
43-
/// Workflow message channel receiver, the sender is most likely the compute node itself.
43+
/// Task channel receiver, the sender is most likely the compute node itself.
4444
task_rx: mpsc::Receiver<TaskWorkerInput>,
4545
/// Publish message channel sender, the receiver is most likely the compute node itself.
4646
publish_tx: mpsc::Sender<TaskWorkerOutput>,
4747
// TODO: batch size must be defined here
4848
}
4949

50-
/// Buffer size for workflow tasks (per worker).
50+
/// Buffer size for task channels (per worker).
5151
const TASK_RX_CHANNEL_BUFSIZE: usize = 1024;
5252

5353
impl TaskWorker {
@@ -225,11 +225,7 @@ impl TaskWorker {
225225
) {
226226
let batchable = input.task.is_batchable();
227227
input.stats = input.stats.record_execution_started_at();
228-
let result = input
229-
.executor
230-
// takes no explicit prompt input, everything is in the workflow
231-
.execute(input.task)
232-
.await;
228+
let result = input.executor.execute(input.task).await;
233229
input.stats = input.stats.record_execution_ended_at();
234230

235231
let output = TaskWorkerOutput {
@@ -240,7 +236,7 @@ impl TaskWorker {
240236
};
241237

242238
if let Err(e) = publish_tx.send(output).await {
243-
log::error!("Error sending workflow result: {}", e);
239+
log::error!("Error sending task result: {}", e);
244240
}
245241
}
246242
}
@@ -255,7 +251,7 @@ mod tests {
255251
/// ## Run command
256252
///
257253
/// ```sh
258-
/// cargo test --package dkn-compute --lib --all-features -- workers::workflow::tests::test_executor_worker --exact --show-output --nocapture --ignored
254+
/// cargo test --package dkn-compute --lib --all-features -- workers::task::tests::test_executor_worker --exact --show-output --nocapture --ignored
259255
/// ```
260256
#[tokio::test]
261257
#[ignore = "run manually"]
@@ -269,7 +265,7 @@ mod tests {
269265
let (publish_tx, mut publish_rx) = mpsc::channel(1024);
270266
let (mut worker, task_tx) = TaskWorker::new(publish_tx);
271267

272-
// create batch workflow worker
268+
// create batch worker
273269
let worker_handle = tokio::spawn(async move {
274270
worker.run_batch(4).await;
275271
});
@@ -290,7 +286,7 @@ mod tests {
290286
stats: TaskStats::default(),
291287
};
292288

293-
// send workflow to worker
289+
// send task to worker
294290
task_tx.send(task_input).await.unwrap();
295291
}
296292

executor/examples/gemini.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ async fn main() -> eyre::Result<()> {
66

77
let model = Model::Gemini2_0Flash;
88
let models = vec![model];
9-
let mut config = DriaExecutorsManager::new_from_env_for_models(models)?;
9+
let mut config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
1010
config.check_services().await?;
1111

1212
assert!(config.models.contains(&model));

executor/examples/ollama.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ async fn main() -> eyre::Result<()> {
66

77
let model = Model::Llama3_2_1bInstructQ4Km;
88
let models = vec![model];
9-
let mut config = DriaExecutorsManager::new_from_env_for_models(models)?;
9+
let mut config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
1010
config.check_services().await?;
1111

1212
assert!(config.models.contains(&model));

executor/examples/openai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ async fn main() -> eyre::Result<()> {
66

77
let model = Model::GPT4o;
88
let models = vec![model];
9-
let mut config = DriaExecutorsManager::new_from_env_for_models(models)?;
9+
let mut config = DriaExecutorsManager::new_from_env_for_models(models.into_iter())?;
1010
config.check_services().await?;
1111

1212
assert!(config.models.contains(&model));

0 commit comments

Comments
 (0)