Skip to content

Commit 650bdf5

Browse files
authored
Merge pull request #157 from firstbatchxyz/erhant/batch-publish
feat: publish-in-task
2 parents 1c6fded + 819e4d6 commit 650bdf5

File tree

12 files changed

+120
-62
lines changed

12 files changed

+120
-62
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ DKN_ADMIN_PUBLIC_KEY=0208ef5e65a9c656a6f92fb2c770d5d5e2ecffe02a6aade19207f75110b
99
# example: phi3:3.8b,gpt-4o-mini
1010
DKN_MODELS=
1111

12+
1213
## DRIA (optional) ##
1314
# P2P address, you don't need to change this unless this port is already in use.
1415
DKN_P2P_LISTEN_ADDR=/ip4/0.0.0.0/tcp/4001
1516
# Comma-separated static relay nodes
1617
DKN_RELAY_NODES=
1718
# Comma-separated static bootstrap nodes
1819
DKN_BOOTSTRAP_NODES=
20+
# Batch size for workflows, you do not need to edit this.
21+
DKN_BATCH_SIZE=
1922

2023
## DRIA (profiling only, do not uncomment) ##
2124
# Set to a number of seconds to wait before exiting, only use in profiling build!

Cargo.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default-members = ["compute"]
88

99
[workspace.package]
1010
edition = "2021"
11-
version = "0.2.26"
11+
version = "0.2.27"
1212
license = "Apache-2.0"
1313
readme = "README.md"
1414

compute/src/config.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ use libsecp256k1::{PublicKey, SecretKey};
99

1010
use std::{env, str::FromStr};
1111

12+
// TODO: make this configurable later
13+
const DEFAULT_WORKFLOW_BATCH_SIZE: usize = 5;
14+
1215
#[derive(Debug, Clone)]
1316
pub struct DriaComputeNodeConfig {
1417
/// Wallet secret/private key.
@@ -25,6 +28,11 @@ pub struct DriaComputeNodeConfig {
2528
pub workflows: DriaWorkflowsConfig,
2629
/// Network type of the node.
2730
pub network_type: DriaNetworkType,
31+
/// Batch size for batchable workflows.
32+
///
33+
/// A higher value will help execute more tasks concurrently,
34+
/// at the risk of hitting rate-limits.
35+
pub batch_size: usize,
2836
}
2937

3038
/// The default P2P network listen address.
@@ -103,6 +111,11 @@ impl DriaComputeNodeConfig {
103111
.map(|s| DriaNetworkType::from(s.as_str()))
104112
.unwrap_or_default();
105113

114+
// parse batch size
115+
let batch_size = env::var("DKN_BATCH_SIZE")
116+
.map(|s| s.parse::<usize>().unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE))
117+
.unwrap_or(DEFAULT_WORKFLOW_BATCH_SIZE);
118+
106119
Self {
107120
admin_public_key,
108121
secret_key,
@@ -111,6 +124,7 @@ impl DriaComputeNodeConfig {
111124
workflows,
112125
p2p_listen_addr,
113126
network_type,
127+
batch_size,
114128
}
115129
}
116130

compute/src/handlers/workflow.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,7 @@ impl WorkflowHandler {
129129

130130
// convert payload to message
131131
let payload_str = serde_json::json!(payload).to_string();
132-
log::debug!(
133-
"Publishing result for task {}\n{}",
134-
task.task_id,
135-
payload_str
136-
);
132+
log::info!("Publishing result for task {}", task.task_id);
137133
DriaMessage::new(payload_str, Self::RESPONSE_TOPIC)
138134
}
139135
Err(err) => {
@@ -161,7 +157,7 @@ impl WorkflowHandler {
161157

162158
// try publishing the result
163159
if let Err(publish_err) = node.publish(message).await {
164-
let err_msg = format!("could not publish result: {:?}", publish_err);
160+
let err_msg = format!("Could not publish task result: {:?}", publish_err);
165161
log::error!("{}", err_msg);
166162

167163
let payload = serde_json::json!({

compute/src/main.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use dkn_workflows::DriaWorkflowsConfig;
33
use eyre::Result;
44
use std::env;
55
use tokio_util::{sync::CancellationToken, task::TaskTracker};
6+
use workers::workflow::WorkflowsWorker;
67

78
#[tokio::main]
89
async fn main() -> Result<()> {
@@ -86,6 +87,7 @@ async fn main() -> Result<()> {
8687
log::warn!("Using models: {:#?}", config.workflows.models);
8788

8889
// create the node
90+
let batch_size = config.batch_size;
8991
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;
9092

9193
// spawn p2p client first
@@ -94,14 +96,21 @@ async fn main() -> Result<()> {
9496

9597
// spawn batch worker thread if we are using such models (e.g. OpenAI, Gemini, OpenRouter)
9698
if let Some(mut worker_batch) = worker_batch {
97-
log::info!("Spawning workflows batch worker thread.");
98-
task_tracker.spawn(async move { worker_batch.run_batch().await });
99+
assert!(
100+
batch_size <= WorkflowsWorker::MAX_BATCH_SIZE,
101+
"batch size too large"
102+
);
103+
log::info!(
104+
"Spawning workflows batch worker thread. (batch size {})",
105+
batch_size
106+
);
107+
task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await });
99108
}
100109

101110
// spawn single worker thread if we are using such models (e.g. Ollama)
102111
if let Some(mut worker_single) = worker_single {
103112
log::info!("Spawning workflows single worker thread.");
104-
task_tracker.spawn(async move { worker_single.run().await });
113+
task_tracker.spawn(async move { worker_single.run_series().await });
105114
}
106115

107116
// spawn compute node thread

compute/src/node.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ impl DriaComputeNode {
8181
let (p2p_client, p2p_commander, message_rx) = DriaP2PClient::new(
8282
keypair,
8383
config.p2p_listen_addr.clone(),
84-
available_nodes.bootstrap_nodes.clone().into_iter(),
85-
available_nodes.relay_nodes.clone().into_iter(),
86-
available_nodes.rpc_nodes.clone().into_iter(),
84+
&available_nodes,
8785
protocol,
8886
)?;
8987

compute/src/workers/workflow.rs

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,21 @@ pub struct WorkflowsWorkerOutput {
3232
///
3333
/// It is expected to be spawned in another thread, with `run_batch` for batch processing and `run` for single processing.
3434
pub struct WorkflowsWorker {
35+
/// Workflow message channel receiver, the sender is most likely the compute node itself.
3536
workflow_rx: mpsc::Receiver<WorkflowsWorkerInput>,
37+
/// Publish message channel sender, the receiver is most likely the compute node itself.
3638
publish_tx: mpsc::Sender<WorkflowsWorkerOutput>,
3739
}
3840

3941
/// Buffer size for workflow tasks (per worker).
4042
const WORKFLOW_CHANNEL_BUFSIZE: usize = 1024;
4143

4244
impl WorkflowsWorker {
43-
/// Batch size that defines how many tasks can be executed in parallel at once.
44-
/// IMPORTANT NOTE: `run` function is designed to handle the batch size here specifically,
45+
/// Batch size that defines how many tasks can be executed concurrently at once.
46+
///
47+
/// The `run_batch` function is designed to handle the batch size here specifically,
4548
/// if there are more tasks than the batch size, the function will panic.
46-
const BATCH_SIZE: usize = 8;
49+
pub const MAX_BATCH_SIZE: usize = 8;
4750

4851
/// Creates a worker and returns the sender and receiver for the worker.
4952
pub fn new(
@@ -65,24 +68,20 @@ impl WorkflowsWorker {
6568
self.workflow_rx.close();
6669
}
6770

68-
/// Launches the thread that can process tasks one by one.
71+
/// Launches the thread that can process tasks one by one (in series).
6972
/// This function will block until the channel is closed.
7073
///
7174
/// It is suitable for task streams that consume local resources, unlike API calls.
72-
pub async fn run(&mut self) {
75+
pub async fn run_series(&mut self) {
7376
loop {
7477
let task = self.workflow_rx.recv().await;
7578

76-
let result = if let Some(task) = task {
79+
if let Some(task) = task {
7780
log::info!("Processing single workflow for task {}", task.task_id);
78-
WorkflowsWorker::execute(task).await
81+
WorkflowsWorker::execute((task, self.publish_tx.clone())).await
7982
} else {
8083
return self.shutdown();
8184
};
82-
83-
if let Err(e) = self.publish_tx.send(result).await {
84-
log::error!("Error sending workflow result: {}", e);
85-
}
8685
}
8786
}
8887

@@ -91,13 +90,16 @@ impl WorkflowsWorker {
9190
///
9291
/// It is suitable for task streams that make use of API calls, unlike Ollama-like
9392
/// tasks that consume local resources and would not make sense to run in parallel.
94-
pub async fn run_batch(&mut self) {
93+
///
94+
/// Batch size must NOT be larger than `MAX_BATCH_SIZE`, otherwise will panic.
95+
pub async fn run_batch(&mut self, batch_size: usize) {
96+
// TODO: need some better batch_size error handling here
9597
loop {
9698
// get tasks in batch from the channel
9799
let mut task_buffer = Vec::new();
98100
let num_tasks = self
99101
.workflow_rx
100-
.recv_many(&mut task_buffer, Self::BATCH_SIZE)
102+
.recv_many(&mut task_buffer, batch_size)
101103
.await;
102104

103105
if num_tasks == 0 {
@@ -106,8 +108,10 @@ impl WorkflowsWorker {
106108

107109
// process the batch
108110
log::info!("Processing {} workflows in batch", num_tasks);
109-
let mut batch = task_buffer.into_iter();
110-
let results = match num_tasks {
111+
let mut batch = task_buffer
112+
.into_iter()
113+
.map(|b| (b, self.publish_tx.clone()));
114+
match num_tasks {
111115
1 => {
112116
let r0 = WorkflowsWorker::execute(batch.next().unwrap()).await;
113117
vec![r0]
@@ -186,23 +190,17 @@ impl WorkflowsWorker {
186190
unreachable!(
187191
"number of tasks cant be larger than batch size ({} > {})",
188192
num_tasks,
189-
Self::BATCH_SIZE
193+
Self::MAX_BATCH_SIZE
190194
);
191195
}
192196
};
193-
194-
// publish all results
195-
log::info!("Publishing {} workflow results", results.len());
196-
for result in results {
197-
if let Err(e) = self.publish_tx.send(result).await {
198-
log::error!("Error sending workflow result: {}", e);
199-
}
200-
}
201197
}
202198
}
203199

204-
/// A single task execution.
205-
pub async fn execute(input: WorkflowsWorkerInput) -> WorkflowsWorkerOutput {
200+
/// Executes a single task, and publishes the output.
201+
pub async fn execute(
202+
(input, publish_tx): (WorkflowsWorkerInput, mpsc::Sender<WorkflowsWorkerOutput>),
203+
) {
206204
let mut memory = ProgramMemory::new();
207205

208206
let started_at = std::time::Instant::now();
@@ -211,13 +209,17 @@ impl WorkflowsWorker {
211209
.execute(input.entry.as_ref(), &input.workflow, &mut memory)
212210
.await;
213211

214-
WorkflowsWorkerOutput {
212+
let output = WorkflowsWorkerOutput {
215213
result,
216214
public_key: input.public_key,
217215
task_id: input.task_id,
218216
model_name: input.model_name,
219217
batchable: input.batchable,
220218
stats: input.stats.record_execution_time(started_at),
219+
};
220+
221+
if let Err(e) = publish_tx.send(output).await {
222+
log::error!("Error sending workflow result: {}", e);
221223
}
222224
}
223225
}

monitor/src/main.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@ async fn main() -> eyre::Result<()> {
3333
let (client, commander, msg_rx) = DriaP2PClient::new(
3434
keypair,
3535
listen_addr,
36-
nodes.bootstrap_nodes.into_iter(),
37-
nodes.relay_nodes.into_iter(),
38-
nodes.rpc_nodes.into_iter(),
36+
&nodes,
3937
DriaP2PProtocol::new_major_minor(network.protocol_name()),
4038
)?;
4139

0 commit comments

Comments
 (0)