Skip to content

Commit bb28960

Browse files
committed
available node logic added via api call, update workflows, trait for compute
1 parent 4c71e9a commit bb28960

File tree

14 files changed

+318
-161
lines changed

14 files changed

+318
-161
lines changed

Cargo.lock

Lines changed: 16 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "dkn-compute"
3-
version = "0.1.3"
3+
version = "0.1.4"
44
edition = "2021"
55
license = "Apache-2.0"
66
readme = "README.md"
@@ -12,6 +12,7 @@ parking_lot = "0.12.2"
1212
serde = { version = "1.0", features = ["derive"] }
1313
serde_json = "1.0"
1414
async-trait = "0.1.81"
15+
reqwest = "0.12.5"
1516

1617
# utilities
1718
base64 = "0.22.0"
@@ -34,8 +35,7 @@ sha3 = "0.10.8"
3435
fastbloom-rs = "0.5.9"
3536

3637
# workflows
37-
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows", rev = "274b26e" }
38-
ollama-rs = "0.2.0"
38+
ollama-workflows = { git = "https://github.com/andthattoo/ollama-workflows", rev = "25467d2" }
3939

4040
# peer-to-peer
4141
libp2p = { version = "0.53", features = [

examples/common/ollama.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::time::SystemTime;
22

3-
use ollama_rs::{
3+
use ollama_workflows::ollama_rs::{
44
generation::completion::{request::GenerationRequest, GenerationResponse},
55
Ollama,
66
};

src/config/ollama.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ollama_rs::Ollama;
1+
use ollama_workflows::ollama_rs::Ollama;
22

33
const DEFAULT_OLLAMA_HOST: &str = "http://127.0.0.1";
44
const DEFAULT_OLLAMA_PORT: u16 = 11434;

src/errors/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use ollama_rs::error::OllamaError;
1+
use ollama_workflows::ollama_rs::error::OllamaError;
22

33
/// Alias for `Result<T, NodeError>`.
44
pub type NodeResult<T> = std::result::Result<T, NodeError>;
@@ -96,6 +96,15 @@ impl From<libp2p::gossipsub::SubscriptionError> for NodeError {
9696
}
9797
}
9898

99+
impl From<reqwest::Error> for NodeError {
100+
fn from(value: reqwest::Error) -> Self {
101+
Self {
102+
message: value.to_string(),
103+
source: "reqwest".to_string(),
104+
}
105+
}
106+
}
107+
99108
impl From<libp2p::gossipsub::PublishError> for NodeError {
100109
fn from(value: libp2p::gossipsub::PublishError) -> Self {
101110
Self {

src/handlers/mod.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
1+
use async_trait::async_trait;
2+
use libp2p::gossipsub::MessageAcceptance;
3+
14
mod pingpong;
2-
pub use pingpong::HandlesPingpong;
5+
pub use pingpong::PingpongHandler;
36

47
mod workflow;
5-
pub use workflow::HandlesWorkflow;
8+
pub use workflow::WorkflowHandler;
9+
10+
use crate::{errors::NodeResult, p2p::P2PMessage, DriaComputeNode};
11+
12+
#[async_trait]
13+
pub trait ComputeHandler {
14+
async fn handle_compute(
15+
node: &mut DriaComputeNode,
16+
message: P2PMessage,
17+
result_topic: &str,
18+
) -> NodeResult<MessageAcceptance>;
19+
}

src/handlers/pingpong.rs

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
use crate::{
22
errors::NodeResult, node::DriaComputeNode, p2p::P2PMessage, utils::get_current_time_nanos,
33
};
4+
use async_trait::async_trait;
45
use libp2p::gossipsub::MessageAcceptance;
56
use ollama_workflows::{Model, ModelProvider};
67
use serde::{Deserialize, Serialize};
78

9+
use super::ComputeHandler;
10+
11+
pub struct PingpongHandler;
12+
813
#[derive(Serialize, Deserialize, Debug, Clone)]
914
struct PingpongPayload {
1015
uuid: String,
@@ -18,19 +23,10 @@ struct PingpongResponse {
1823
pub(crate) timestamp: u128,
1924
}
2025

21-
/// A ping-pong is a message sent by a node to indicate that it is alive.
22-
/// Compute nodes listen to `pong` topic, and respond to `ping` topic.
23-
pub trait HandlesPingpong {
24-
fn handle_heartbeat(
25-
&mut self,
26-
message: P2PMessage,
27-
result_topic: &str,
28-
) -> NodeResult<MessageAcceptance>;
29-
}
30-
31-
impl HandlesPingpong for DriaComputeNode {
32-
fn handle_heartbeat(
33-
&mut self,
26+
#[async_trait]
27+
impl ComputeHandler for PingpongHandler {
28+
async fn handle_compute(
29+
node: &mut DriaComputeNode,
3430
message: P2PMessage,
3531
result_topic: &str,
3632
) -> NodeResult<MessageAcceptance> {
@@ -53,15 +49,15 @@ impl HandlesPingpong for DriaComputeNode {
5349
// respond
5450
let response_body = PingpongResponse {
5551
uuid: pingpong.uuid.clone(),
56-
models: self.config.model_config.models.clone(),
52+
models: node.config.model_config.models.clone(),
5753
timestamp: get_current_time_nanos(),
5854
};
5955
let response = P2PMessage::new_signed(
6056
serde_json::json!(response_body).to_string(),
6157
result_topic,
62-
&self.config.secret_key,
58+
&node.config.secret_key,
6359
);
64-
self.publish(response)?;
60+
node.publish(response)?;
6561

6662
// accept message, someone else may be included in the filter
6763
Ok(MessageAcceptance::Accept)

src/handlers/workflow.rs

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ use crate::p2p::P2PMessage;
99
use crate::utils::get_current_time_nanos;
1010
use crate::utils::payload::{TaskRequest, TaskRequestPayload};
1111

12+
use super::ComputeHandler;
13+
14+
pub struct WorkflowHandler;
15+
1216
#[derive(Debug, Deserialize)]
1317
struct WorkflowPayload {
1418
/// Workflow object to be parsed.
@@ -23,18 +27,9 @@ struct WorkflowPayload {
2327
}
2428

2529
#[async_trait]
26-
pub trait HandlesWorkflow {
27-
async fn handle_workflow(
28-
&mut self,
29-
message: P2PMessage,
30-
result_topic: &str,
31-
) -> NodeResult<MessageAcceptance>;
32-
}
33-
34-
#[async_trait]
35-
impl HandlesWorkflow for DriaComputeNode {
36-
async fn handle_workflow(
37-
&mut self,
30+
impl ComputeHandler for WorkflowHandler {
31+
async fn handle_compute(
32+
node: &mut DriaComputeNode,
3833
message: P2PMessage,
3934
result_topic: &str,
4035
) -> NodeResult<MessageAcceptance> {
@@ -55,7 +50,7 @@ impl HandlesWorkflow for DriaComputeNode {
5550
}
5651

5752
// check task inclusion via the bloom filter
58-
if !task.filter.contains(&self.config.address)? {
53+
if !task.filter.contains(&node.config.address)? {
5954
log::info!(
6055
"Task {} does not include this node within the filter.",
6156
task.task_id
@@ -75,7 +70,7 @@ impl HandlesWorkflow for DriaComputeNode {
7570
};
7671

7772
// read model / provider from the task
78-
let (model_provider, model) = self
73+
let (model_provider, model) = node
7974
.config
8075
.model_config
8176
.get_any_matching_model(task.input.model)?;
@@ -85,8 +80,8 @@ impl HandlesWorkflow for DriaComputeNode {
8580
let executor = if model_provider == ModelProvider::Ollama {
8681
Executor::new_at(
8782
model,
88-
&self.config.ollama_config.host,
89-
self.config.ollama_config.port,
83+
&node.config.ollama_config.host,
84+
node.config.ollama_config.port,
9085
)
9186
} else {
9287
Executor::new(model)
@@ -98,7 +93,7 @@ impl HandlesWorkflow for DriaComputeNode {
9893
.map(|prompt| Entry::try_value_or_str(&prompt));
9994
let result: Option<String>;
10095
tokio::select! {
101-
_ = self.cancellation.cancelled() => {
96+
_ = node.cancellation.cancelled() => {
10297
log::info!("Received cancellation, quitting all tasks.");
10398
return Ok(MessageAcceptance::Accept)
10499
},
@@ -113,7 +108,7 @@ impl HandlesWorkflow for DriaComputeNode {
113108
let result = result.ok_or::<String>(format!("No result for task {}", task.task_id))?;
114109

115110
// publish the result
116-
self.send_result(result_topic, &task.public_key, &task.task_id, result)?;
111+
node.send_result(result_topic, &task.public_key, &task.task_id, result)?;
117112

118113
// accept message, someone else may be included in the filter
119114
Ok(MessageAcceptance::Accept)

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1818
config.check_services().await?;
1919

2020
// launch the node
21-
let mut node = DriaComputeNode::new(config, CancellationToken::new())?;
21+
let mut node = DriaComputeNode::new(config, CancellationToken::new()).await?;
2222
node.launch().await?;
2323

2424
Ok(())

0 commit comments

Comments
 (0)