Skip to content

Commit a62f978

Browse files
committed
some tests and envlogger fixes
1 parent c888180 commit a62f978

File tree

4 files changed

+117
-5
lines changed

4 files changed

+117
-5
lines changed

compute/src/node.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -471,13 +471,16 @@ impl DriaComputeNode {
471471
#[cfg(test)]
472472
mod tests {
473473
use super::*;
474-
use std::env;
475474

476475
#[tokio::test]
477476
#[ignore = "run this manually"]
478477
async fn test_publish_message() -> eyre::Result<()> {
479-
env::set_var("RUST_LOG", "none,dkn_compute=debug,dkn_p2p=debug");
480-
let _ = env_logger::builder().is_test(true).try_init();
478+
let _ = env_logger::builder()
479+
.filter_level(log::LevelFilter::Off)
480+
.filter_module("dkn_compute", log::LevelFilter::Debug)
481+
.filter_module("dkn_p2p", log::LevelFilter::Debug)
482+
.is_test(true)
483+
.try_init();
481484

482485
// create node
483486
let cancellation = CancellationToken::new();

compute/src/workers/workflow.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ impl WorkflowsWorker {
106106
// (1) there are no tasks, or,
107107
// (2) there are tasks less than the batch size and the channel is not empty
108108
while tasks.is_empty() || (tasks.len() < batch_size && !self.workflow_rx.is_empty()) {
109+
log::info!("Waiting for more workflows to process ({})", tasks.len());
109110
let limit = batch_size - tasks.len();
110111
match self.workflow_rx.recv_many(&mut tasks, limit).await {
111112
// 0 tasks returned means that the channel is closed
@@ -235,3 +236,107 @@ impl WorkflowsWorker {
235236
}
236237
}
237238
}
239+
240+
#[cfg(test)]
241+
mod tests {
242+
use super::*;
243+
use crate::payloads::TaskStats;
244+
245+
use dkn_workflows::{Executor, Model};
246+
use libsecp256k1::{PublicKey, SecretKey};
247+
use tokio::sync::mpsc;
248+
249+
// cargo test --package dkn-compute --lib --all-features -- workers::workflow::tests::test_workflows_worker --exact --show-output --nocapture --ignored
250+
#[tokio::test]
251+
#[ignore = "run manually"]
252+
async fn test_workflows_worker() {
253+
let _ = env_logger::builder()
254+
.filter_level(log::LevelFilter::Off)
255+
.filter_module("dkn_compute", log::LevelFilter::Debug)
256+
.is_test(true)
257+
.try_init();
258+
259+
let (publish_tx, mut publish_rx) = mpsc::channel(1024);
260+
let (mut worker, workflow_tx) = WorkflowsWorker::new(publish_tx);
261+
262+
// create batch workflow worker
263+
let worker_handle = tokio::spawn(async move {
264+
worker.run_batch(4).await;
265+
});
266+
267+
let num_tasks = 4;
268+
let model = Model::O1Preview;
269+
let workflow = serde_json::json!({
270+
"config": {
271+
"max_steps": 10,
272+
"max_time": 250,
273+
"tools": [""]
274+
},
275+
"tasks": [
276+
{
277+
"id": "A",
278+
"name": "",
279+
"description": "",
280+
"operator": "generation",
281+
"messages": [{ "role": "user", "content": "Write a 4 paragraph poem about Julius Caesar." }],
282+
"inputs": [],
283+
"outputs": [ { "type": "write", "key": "result", "value": "__result" } ]
284+
},
285+
{
286+
"id": "__end",
287+
"name": "end",
288+
"description": "End of the task",
289+
"operator": "end",
290+
"messages": [{ "role": "user", "content": "End of the task" }],
291+
"inputs": [],
292+
"outputs": []
293+
}
294+
],
295+
"steps": [ { "source": "A", "target": "__end" } ],
296+
"return_value": { "input": { "type": "read", "key": "result" }
297+
}
298+
});
299+
300+
for i in 0..num_tasks {
301+
log::info!("Sending task {}", i + 1);
302+
303+
let workflow = serde_json::from_value(workflow.clone()).unwrap();
304+
305+
let executor = Executor::new(model.clone());
306+
let input = WorkflowsWorkerInput {
307+
entry: None,
308+
executor,
309+
workflow,
310+
public_key: PublicKey::from_secret_key(&SecretKey::default()),
311+
task_id: "task_id".to_string(),
312+
model_name: model.to_string(),
313+
stats: TaskStats::default(),
314+
batchable: true,
315+
};
316+
317+
// send workflow to worker
318+
workflow_tx.send(input).await.unwrap();
319+
}
320+
321+
// now wait for all results
322+
let mut results = Vec::new();
323+
for i in 0..num_tasks {
324+
log::info!("Waiting for result {}", i + 1);
325+
let result = publish_rx.recv().await.unwrap();
326+
log::info!(
327+
"Got result {} (exeuction time: {})",
328+
i + 1,
329+
(result.stats.execution_time as f64) / 1_000_000_000f64
330+
);
331+
if result.result.is_err() {
332+
println!("Error: {:?}", result.result);
333+
}
334+
results.push(result);
335+
}
336+
337+
log::info!("Got all results, closing channel.");
338+
publish_rx.close();
339+
workflow_tx.worker_handle.await.unwrap();
340+
log::info!("Done.");
341+
}
342+
}

p2p/src/client.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ impl DriaP2PClient {
126126
// dial rpc nodes
127127
for rpc_addr in &nodes.rpc_nodes {
128128
log::info!("Dialing RPC node: {}", rpc_addr);
129-
swarm.dial(rpc_addr.clone())?;
129+
if let Err(e) = swarm.dial(rpc_addr.clone()) {
130+
log::error!("Error dialing RPC node: {:?}", e);
131+
};
130132
}
131133

132134
// create commander

p2p/tests/listen_test.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ async fn test_listen_topic_once() -> Result<()> {
88
const TOPIC: &str = "pong";
99

1010
let _ = env_logger::builder()
11-
.parse_filters("none,listen_test=debug,dkn_p2p=debug")
11+
.filter_level(log::LevelFilter::Off)
12+
.filter_module("listen_test", log::LevelFilter::Debug)
13+
.filter_module("dkn_p2p", log::LevelFilter::Debug)
1214
.is_test(true)
1315
.try_init();
1416

0 commit comments

Comments
 (0)