Skip to content

Commit 907d0d2

Browse files
committed
added task monitor
1 parent 014c5c8 commit 907d0d2

File tree

14 files changed

+262
-70
lines changed

14 files changed

+262
-70
lines changed

compute/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ base64 = "0.22.0"
2828
hex = "0.4.3"
2929
hex-literal = "0.4.1"
3030
uuid = { version = "1.8.0", features = ["v4"] }
31+
rand.workspace = true
3132

3233
# logging & errors
33-
rand.workspace = true
3434
env_logger.workspace = true
3535
log.workspace = true
3636
eyre.workspace = true

compute/src/bin/monitor.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use dkn_compute::{refresh_dria_nodes, DriaMonitorNode};
2+
use dkn_p2p::{
3+
libp2p_identity::Keypair, DriaNetworkType, DriaNodes, DriaP2PClient, DriaP2PProtocol,
4+
};
5+
use tokio_util::sync::CancellationToken;
6+
7+
#[tokio::main]
8+
async fn main() -> eyre::Result<()> {
9+
dotenvy::dotenv().expect("could not load .env");
10+
11+
env_logger::builder()
12+
.filter(None, log::LevelFilter::Off)
13+
.filter_module("dkn_p2p", log::LevelFilter::Info)
14+
.filter_module("dkn_compute", log::LevelFilter::Info)
15+
.filter_module("monitor", log::LevelFilter::Info)
16+
.parse_default_env() // reads RUST_LOG variable
17+
.init();
18+
19+
log::info!("Starting Dria Task Monitor");
20+
21+
let network = DriaNetworkType::Pro;
22+
let mut nodes = DriaNodes::new(network);
23+
refresh_dria_nodes(&mut nodes).await?;
24+
25+
// setup p2p client
26+
let keypair = Keypair::generate_secp256k1();
27+
log::info!("PeerID: {}", keypair.public().to_peer_id());
28+
let (client, commander, msg_rx) = DriaP2PClient::new(
29+
Keypair::generate_secp256k1(),
30+
"/ip4/0.0.0.0/tcp/4069".parse()?,
31+
nodes.bootstrap_nodes.into_iter(),
32+
nodes.relay_nodes.into_iter(),
33+
nodes.rpc_nodes.into_iter(),
34+
DriaP2PProtocol::new_major_minor(network.protocol_name()),
35+
)?;
36+
37+
// spawn p2p task
38+
let token = CancellationToken::new();
39+
let p2p_handle = tokio::spawn(async move { client.run().await });
40+
41+
// wait for SIGTERM & SIGINT signal in another thread
42+
let sig_token = token.clone();
43+
let sig_handle = tokio::spawn(async move {
44+
use tokio::signal::unix::{signal, SignalKind};
45+
46+
let mut sigterm = signal(SignalKind::terminate()).unwrap(); // Docker sends SIGTERM
47+
let mut sigint = signal(SignalKind::interrupt()).unwrap(); // Ctrl+C sends SIGINT
48+
tokio::select! {
49+
_ = sigterm.recv() => log::warn!("Recieved SIGTERM"),
50+
_ = sigint.recv() => log::warn!("Recieved SIGINT"),
51+
_ = sig_token.cancelled() => return,
52+
};
53+
sig_token.cancel();
54+
});
55+
56+
// create monitor node
57+
let mut monitor = DriaMonitorNode::new(commander, msg_rx);
58+
59+
// setup monitor
60+
monitor.setup().await?;
61+
monitor.run(token).await;
62+
monitor.shutdown().await?;
63+
64+
log::info!("Waiting for task handles...");
65+
p2p_handle.await?;
66+
sig_handle.await?;
67+
68+
log::info!("Done!");
69+
Ok(())
70+
}

compute/src/config.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub(crate) const DEFAULT_P2P_LISTEN_ADDR: &str = "/ip4/0.0.0.0/tcp/4001";
3333
#[allow(clippy::new_without_default)]
3434
impl DriaComputeNodeConfig {
3535
/// Creates new config from environment variables.
36-
pub fn new() -> Self {
36+
pub fn new(workflows: DriaWorkflowsConfig) -> Self {
3737
let secret_key = match env::var("DKN_WALLET_SECRET_KEY") {
3838
Ok(secret_env) => {
3939
let secret_dec = hex::decode(secret_env.trim_start_matches("0x"))
@@ -91,15 +91,7 @@ impl DriaComputeNodeConfig {
9191
hex::encode(admin_public_key.serialize_compressed())
9292
);
9393

94-
let workflows =
95-
DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default());
96-
#[cfg(not(test))]
97-
if workflows.models.is_empty() {
98-
log::error!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS.");
99-
panic!("No models provided.");
100-
}
101-
log::info!("Configured models: {:?}", workflows.models);
102-
94+
// parse listen address
10395
let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR")
10496
.map(|addr| addr.trim_matches('"').to_string())
10597
.unwrap_or(DEFAULT_P2P_LISTEN_ADDR.to_string());
@@ -152,7 +144,7 @@ impl Default for DriaComputeNodeConfig {
152144
);
153145
env::set_var("DKN_MODELS", "gpt-3.5-turbo");
154146

155-
Self::new()
147+
Self::new(Default::default())
156148
}
157149
}
158150

compute/src/handlers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ mod pingpong;
22
pub use pingpong::PingpongHandler;
33

44
mod workflow;
5-
pub use workflow::WorkflowHandler;
5+
pub use workflow::{WorkflowHandler, WorkflowPayload};

compute/src/handlers/pingpong.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ use serde::{Deserialize, Serialize};
88
pub struct PingpongHandler;
99

1010
#[derive(Serialize, Deserialize, Debug, Clone)]
11-
struct PingpongPayload {
11+
pub struct PingpongPayload {
1212
/// UUID of the ping request, prevents replay attacks.
1313
uuid: String,
1414
/// Deadline for the ping request.
1515
deadline: u128,
1616
}
1717

1818
#[derive(Serialize, Deserialize, Debug, Clone)]
19-
struct PingpongResponse {
19+
pub struct PingpongResponse {
2020
/// UUID as given in the ping payload.
2121
pub(crate) uuid: String,
2222
/// Models available in the node.
@@ -26,8 +26,8 @@ struct PingpongResponse {
2626
}
2727

2828
impl PingpongHandler {
29-
pub(crate) const LISTEN_TOPIC: &'static str = "ping";
30-
pub(crate) const RESPONSE_TOPIC: &'static str = "pong";
29+
pub const LISTEN_TOPIC: &'static str = "ping";
30+
pub const RESPONSE_TOPIC: &'static str = "pong";
3131

3232
/// Handles the ping message and responds with a pong message.
3333
///

compute/src/handlers/workflow.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ use libsecp256k1::PublicKey;
66
use serde::Deserialize;
77
use tokio_util::either::Either;
88

9-
use crate::payloads::{TaskErrorPayload, TaskRequestPayload, TaskResponsePayload, TaskStats};
9+
use crate::payloads::*;
1010
use crate::utils::DriaMessage;
1111
use crate::workers::workflow::*;
1212
use crate::DriaComputeNode;
1313

1414
pub struct WorkflowHandler;
1515

1616
#[derive(Debug, Deserialize)]
17-
struct WorkflowPayload {
17+
pub struct WorkflowPayload {
1818
/// [Workflow](https://github.com/andthattoo/ollama-workflows/blob/main/src/program/workflow.rs) object to be parsed.
1919
pub(crate) workflow: Workflow,
2020
/// A lıst of model (that can be parsed into `Model`) or model provider names.
@@ -27,8 +27,8 @@ struct WorkflowPayload {
2727
}
2828

2929
impl WorkflowHandler {
30-
pub(crate) const LISTEN_TOPIC: &'static str = "task";
31-
pub(crate) const RESPONSE_TOPIC: &'static str = "results";
30+
pub const LISTEN_TOPIC: &'static str = "task";
31+
pub const RESPONSE_TOPIC: &'static str = "results";
3232

3333
pub(crate) async fn handle_compute(
3434
node: &mut DriaComputeNode,

compute/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub(crate) mod config;
22
pub(crate) mod handlers;
3+
pub(crate) mod monitor;
34
pub(crate) mod node;
45
pub(crate) mod payloads;
56
pub(crate) mod utils;
@@ -9,5 +10,9 @@ pub(crate) mod workers;
910
/// This value is attached within the published messages.
1011
pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION");
1112

13+
pub use utils::refresh_dria_nodes;
14+
1215
pub use config::DriaComputeNodeConfig;
1316
pub use node::DriaComputeNode;
17+
18+
pub use monitor::DriaMonitorNode;

compute/src/main.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use dkn_compute::*;
2+
use dkn_workflows::DriaWorkflowsConfig;
23
use eyre::Result;
34
use std::env;
45
use tokio_util::{sync::CancellationToken, task::TaskTracker};
@@ -61,7 +62,14 @@ async fn main() -> Result<()> {
6162
});
6263

6364
// create configurations & check required services & address in use
64-
let mut config = DriaComputeNodeConfig::new();
65+
let workflows_config =
66+
DriaWorkflowsConfig::new_from_csv(&env::var("DKN_MODELS").unwrap_or_default());
67+
if workflows_config.models.is_empty() {
68+
return Err(eyre::eyre!("No models were provided, make sure to restart with at least one model provided within DKN_MODELS."));
69+
}
70+
71+
log::info!("Configured models: {:?}", workflows_config.models);
72+
let mut config = DriaComputeNodeConfig::new(workflows_config);
6573
config.assert_address_not_in_use()?;
6674
// check services & models, will exit if there is an error
6775
// since service check can take time, we allow early-exit here as well

compute/src/monitor.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use std::collections::HashMap;
2+
3+
use crate::{
4+
handlers::{WorkflowHandler, WorkflowPayload},
5+
payloads::{TaskRequestPayload, TaskResponsePayload},
6+
utils::DriaMessage,
7+
};
8+
use dkn_p2p::{
9+
libp2p::{
10+
gossipsub::{Message, MessageId},
11+
PeerId,
12+
},
13+
DriaP2PCommander,
14+
};
15+
use eyre::Result;
16+
use tokio::sync::mpsc;
17+
use tokio_util::sync::CancellationToken;
18+
19+
const TASK_PRINT_INTERVAL_SECS: u64 = 20;
20+
21+
pub struct DriaMonitorNode {
22+
pub p2p: DriaP2PCommander,
23+
pub msg_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
24+
25+
// task monitoring
26+
pub tasks: HashMap<String, TaskRequestPayload<WorkflowPayload>>,
27+
pub results: HashMap<String, TaskResponsePayload>,
28+
}
29+
30+
impl DriaMonitorNode {
31+
pub fn new(
32+
p2p: DriaP2PCommander,
33+
msg_rx: mpsc::Receiver<(PeerId, MessageId, Message)>,
34+
) -> Self {
35+
Self {
36+
p2p,
37+
msg_rx,
38+
tasks: HashMap::new(),
39+
results: HashMap::new(),
40+
}
41+
}
42+
pub async fn setup(&self) -> Result<()> {
43+
self.p2p.subscribe(WorkflowHandler::LISTEN_TOPIC).await?;
44+
self.p2p.subscribe(WorkflowHandler::RESPONSE_TOPIC).await?;
45+
46+
Ok(())
47+
}
48+
49+
pub async fn shutdown(&mut self) -> Result<()> {
50+
log::info!("Shutting down monitor");
51+
self.p2p.unsubscribe(WorkflowHandler::LISTEN_TOPIC).await?;
52+
self.p2p
53+
.unsubscribe(WorkflowHandler::RESPONSE_TOPIC)
54+
.await?;
55+
56+
self.p2p.shutdown().await?;
57+
self.msg_rx.close();
58+
59+
Ok(())
60+
}
61+
62+
pub async fn run(&mut self, token: CancellationToken) {
63+
let mut task_print_interval =
64+
tokio::time::interval(tokio::time::Duration::from_secs(TASK_PRINT_INTERVAL_SECS));
65+
66+
loop {
67+
tokio::select! {
68+
message = self.msg_rx.recv() => match message {
69+
Some(message) => match self.handle_message(message).await {
70+
Ok(_) => {}
71+
Err(e) => log::error!("Error handling message: {:?}", e),
72+
}
73+
None => break, // channel closed
74+
},
75+
_ = task_print_interval.tick() => {
76+
log::info!("Current seen tasks: {:#?}", self.tasks.keys().collect::<Vec<_>>());
77+
}
78+
_ = token.cancelled() => break,
79+
}
80+
}
81+
}
82+
83+
async fn handle_message(
84+
&mut self,
85+
(peer_id, message_id, gossipsub_message): (PeerId, MessageId, Message),
86+
) -> Result<()> {
87+
log::info!(
88+
"Received {} message {} from {}",
89+
gossipsub_message.topic,
90+
message_id,
91+
peer_id
92+
);
93+
94+
// accept all message regardless immediately
95+
self.p2p
96+
.validate_message(
97+
&message_id,
98+
&peer_id,
99+
dkn_p2p::libp2p::gossipsub::MessageAcceptance::Accept,
100+
)
101+
.await?;
102+
103+
// parse message, ignore signatures
104+
let message: DriaMessage = serde_json::from_slice(&gossipsub_message.data)?;
105+
106+
match message.topic.as_str() {
107+
WorkflowHandler::LISTEN_TOPIC => {
108+
let payload: TaskRequestPayload<WorkflowPayload> = message.parse_payload(true)?;
109+
self.tasks.insert(payload.task_id.clone(), payload);
110+
}
111+
WorkflowHandler::RESPONSE_TOPIC => {
112+
let payload: TaskResponsePayload = message.parse_payload(false)?;
113+
self.results.insert(payload.task_id.clone(), payload);
114+
}
115+
_ => { /* ignore */ }
116+
}
117+
Ok(())
118+
}
119+
}

0 commit comments

Comments
 (0)