Skip to content

Commit a938996

Browse files
committed
more eyre'ing, slight folder restructures
1 parent 48dc40e commit a938996

File tree

15 files changed

+65
-70
lines changed

15 files changed

+65
-70
lines changed

src/config/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ mod ollama;
33
mod openai;
44

55
use crate::utils::crypto::to_address;
6+
use eyre::{eyre, Result};
67
use libsecp256k1::{PublicKey, SecretKey};
78
use models::ModelConfig;
89
use ollama::OllamaConfig;
@@ -129,7 +130,7 @@ impl DriaComputeNodeConfig {
129130
/// If both type of models are used, both services are checked.
130131
/// In the end, bad models are filtered out and we simply check if we are left if any valid models at all.
131132
/// If not, an error is returned.
132-
pub async fn check_services(&mut self) -> Result<(), String> {
133+
pub async fn check_services(&mut self) -> Result<()> {
133134
log::info!("Checking configured services.");
134135

135136
// TODO: can refactor (provider, model) logic here
@@ -171,7 +172,7 @@ impl DriaComputeNodeConfig {
171172

172173
// update good models
173174
if good_models.is_empty() {
174-
Err("No good models found, please check logs for errors.".into())
175+
Err(eyre!("No good models found, please check logs for errors."))
175176
} else {
176177
self.model_config.models = good_models;
177178
Ok(())

src/config/ollama.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::time::Duration;
22

3+
use eyre::{eyre, Result};
34
use ollama_workflows::{
45
ollama_rs::{
56
generation::{
@@ -81,7 +82,7 @@ impl OllamaConfig {
8182
external_models: Vec<Model>,
8283
timeout: Duration,
8384
min_tps: f64,
84-
) -> Result<Vec<Model>, String> {
85+
) -> Result<Vec<Model>> {
8586
log::info!(
8687
"Checking Ollama requirements (auto-pull {}, workflow timeout: {}s)",
8788
if self.auto_pull { "on" } else { "off" },
@@ -96,7 +97,7 @@ impl OllamaConfig {
9697
Err(e) => {
9798
return {
9899
log::error!("Could not fetch local models from Ollama, is it online?");
99-
Err(e.to_string())
100+
Err(e.into())
100101
}
101102
}
102103
};
@@ -137,25 +138,22 @@ impl OllamaConfig {
137138
}
138139

139140
/// Pulls a model if `auto_pull` exists, otherwise returns an error.
140-
async fn try_pull(&self, ollama: &Ollama, model: String) -> Result<(), String> {
141+
async fn try_pull(&self, ollama: &Ollama, model: String) -> Result<()> {
141142
log::warn!("Model {} not found in Ollama", model);
142143
if self.auto_pull {
143144
// if auto-pull is enabled, pull the model
144145
log::info!(
145146
"Downloading missing model {} (this may take a while)",
146147
model
147148
);
148-
let status = ollama
149-
.pull_model(model, false)
150-
.await
151-
.map_err(|e| format!("Error pulling model with Ollama: {}", e))?;
149+
let status = ollama.pull_model(model, false).await?;
152150
log::debug!("Pulled model with Ollama, final status: {:#?}", status);
153151
Ok(())
154152
} else {
155153
// otherwise, give error
156154
log::error!("Please download missing model with: ollama pull {}", model);
157155
log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically.");
158-
Err("Required model not pulled in Ollama.".into())
156+
Err(eyre!("Required model not pulled in Ollama."))
159157
}
160158
}
161159

src/config/openai.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(unused)]
22

3+
use eyre::{eyre, Context, Result};
34
use ollama_workflows::Model;
45
use serde::Deserialize;
56

@@ -39,15 +40,13 @@ impl OpenAIConfig {
3940
Self { api_key }
4041
}
4142

42-
/// Check if requested models exist &
43-
///
44-
///
45-
pub async fn check(&self, models: Vec<Model>) -> Result<Vec<Model>, String> {
43+
/// Check if requested models exist & are available in the OpenAI account.
44+
pub async fn check(&self, models: Vec<Model>) -> Result<Vec<Model>> {
4645
log::info!("Checking OpenAI requirements");
4746

4847
// check API key
4948
let Some(api_key) = &self.api_key else {
50-
return Err("OpenAI API key not found".into());
49+
return Err(eyre!("OpenAI API key not found"));
5150
};
5251

5352
// fetch models
@@ -56,24 +55,21 @@ impl OpenAIConfig {
5655
.get(OPENAI_MODELS_API)
5756
.header("Authorization", format!("Bearer {}", api_key))
5857
.build()
59-
.map_err(|e| format!("Failed to build request: {}", e))?;
58+
.wrap_err("Failed to build request")?;
6059

6160
let response = client
6261
.execute(request)
6362
.await
64-
.map_err(|e| format!("Failed to send request: {}", e))?;
63+
.wrap_err("Failed to send request")?;
6564

6665
// parse response
6766
if response.status().is_client_error() {
68-
return Err(format!(
67+
return Err(eyre!(
6968
"Failed to fetch OpenAI models:\n{}",
7069
response.text().await.unwrap_or_default()
7170
));
7271
}
73-
let openai_models = response
74-
.json::<OpenAIModelsResponse>()
75-
.await
76-
.map_err(|e| e.to_string())?;
72+
let openai_models = response.json::<OpenAIModelsResponse>().await?;
7773

7874
// check if models exist and select those that are available
7975
let mut available_models = Vec::new();

src/handlers/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use async_trait::async_trait;
2+
use eyre::Result;
23
use libp2p::gossipsub::MessageAcceptance;
34

45
mod pingpong;
@@ -7,13 +8,13 @@ pub use pingpong::PingpongHandler;
78
mod workflow;
89
pub use workflow::WorkflowHandler;
910

10-
use crate::{p2p::P2PMessage, DriaComputeNode};
11+
use crate::{utils::P2PMessage, DriaComputeNode};
1112

1213
#[async_trait]
1314
pub trait ComputeHandler {
1415
async fn handle_compute(
1516
node: &mut DriaComputeNode,
1617
message: P2PMessage,
1718
result_topic: &str,
18-
) -> eyre::Result<MessageAcceptance>;
19+
) -> Result<MessageAcceptance>;
1920
}

src/handlers/pingpong.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::{node::DriaComputeNode, p2p::P2PMessage, utils::get_current_time_nanos};
1+
use crate::{
2+
node::DriaComputeNode,
3+
utils::{get_current_time_nanos, P2PMessage},
4+
};
25
use async_trait::async_trait;
36
use eyre::Result;
47
use libp2p::gossipsub::MessageAcceptance;
@@ -66,10 +69,10 @@ impl ComputeHandler for PingpongHandler {
6669
#[cfg(test)]
6770
mod tests {
6871
use crate::{
69-
p2p::P2PMessage,
7072
utils::{
7173
crypto::{sha256hash, to_address},
7274
filter::FilterPayload,
75+
P2PMessage,
7376
},
7477
DriaComputeNodeConfig,
7578
};

src/handlers/workflow.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ use ollama_workflows::{Entry, Executor, ModelProvider, ProgramMemory, Workflow};
55
use serde::Deserialize;
66

77
use crate::node::DriaComputeNode;
8-
use crate::p2p::P2PMessage;
9-
use crate::utils::get_current_time_nanos;
108
use crate::utils::payload::{TaskRequest, TaskRequestPayload};
9+
use crate::utils::{get_current_time_nanos, P2PMessage};
1110

1211
use super::ComputeHandler;
1312

src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use dkn_compute::{DriaComputeNode, DriaComputeNodeConfig};
2+
use eyre::Result;
23
use tokio_util::sync::CancellationToken;
34

45
#[tokio::main]
5-
async fn main() -> Result<(), Box<dyn std::error::Error>> {
6+
async fn main() -> Result<()> {
67
if let Err(e) = dotenvy::dotenv() {
78
log::warn!("Could not load .env file: {}", e);
89
}

src/node.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use tokio_util::sync::CancellationToken;
66
use crate::{
77
config::DriaComputeNodeConfig,
88
handlers::{ComputeHandler, PingpongHandler, WorkflowHandler},
9-
p2p::{P2PClient, P2PMessage},
10-
utils::{crypto::secret_to_keypair, AvailableNodes},
9+
p2p::P2PClient,
10+
utils::{crypto::secret_to_keypair, AvailableNodes, P2PMessage},
1111
};
1212

1313
/// Number of seconds between refreshing the Admin RPC PeerIDs from Dria server.
@@ -38,10 +38,9 @@ impl DriaComputeNode {
3838
pub async fn new(
3939
config: DriaComputeNodeConfig,
4040
cancellation: CancellationToken,
41-
) -> Result<Self, String> {
41+
) -> Result<Self> {
4242
let keypair = secret_to_keypair(&config.secret_key);
43-
let listen_addr =
44-
Multiaddr::from_str(config.p2p_listen_addr.as_str()).map_err(|e| e.to_string())?;
43+
let listen_addr = Multiaddr::from_str(config.p2p_listen_addr.as_str())?;
4544

4645
// get available nodes (bootstrap, relay, rpc) for p2p
4746
let available_nodes = AvailableNodes::default()
@@ -266,9 +265,8 @@ impl DriaComputeNode {
266265

267266
#[cfg(test)]
268267
mod tests {
269-
use crate::{p2p::P2PMessage, DriaComputeNode, DriaComputeNodeConfig};
268+
use super::*;
270269
use std::env;
271-
use tokio_util::sync::CancellationToken;
272270

273271
#[tokio::test]
274272
#[ignore = "run this manually"]

src/p2p/client.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use eyre::Result;
12
use libp2p::futures::StreamExt;
23
use libp2p::gossipsub::{
34
Message, MessageAcceptance, MessageId, PublishError, SubscriptionError, TopicHash,
@@ -17,6 +18,7 @@ use super::{DriaBehaviour, DriaBehaviourEvent, P2P_KADEMLIA_PROTOCOL, P2P_PROTOC
1718

1819
/// Underlying libp2p client.
1920
pub struct P2PClient {
21+
/// `Swarm` instance, everything is accesses through this one.
2022
swarm: Swarm<DriaBehaviour>,
2123
/// Peer count for (All, Mesh).
2224
peer_count: (usize, usize),
@@ -36,7 +38,7 @@ impl P2PClient {
3638
keypair: Keypair,
3739
listen_addr: Multiaddr,
3840
available_nodes: &AvailableNodes,
39-
) -> Result<Self, String> {
41+
) -> Result<Self> {
4042
// this is our peerId
4143
let node_peerid = keypair.public().to_peer_id();
4244
log::info!("Compute node peer address: {}", node_peerid);
@@ -47,13 +49,10 @@ impl P2PClient {
4749
tcp::Config::default(),
4850
noise::Config::new,
4951
yamux::Config::default,
50-
)
51-
.map_err(|e| e.to_string())?
52+
)?
5253
.with_quic()
53-
.with_relay_client(noise::Config::new, yamux::Config::default)
54-
.map_err(|e| e.to_string())?
55-
.with_behaviour(|key, relay_behavior| Ok(DriaBehaviour::new(key, relay_behavior)))
56-
.map_err(|e| e.to_string())?
54+
.with_relay_client(noise::Config::new, yamux::Config::default)?
55+
.with_behaviour(|key, relay_behavior| Ok(DriaBehaviour::new(key, relay_behavior)))?
5756
.with_swarm_config(|c| {
5857
c.with_idle_connection_timeout(Duration::from_secs(IDLE_CONNECTION_TIMEOUT_SECS))
5958
})
@@ -76,7 +75,7 @@ impl P2PClient {
7675
_ => None,
7776
}) {
7877
log::info!("Dialling peer: {}", addr);
79-
swarm.dial(addr.clone()).map_err(|e| e.to_string())?;
78+
swarm.dial(addr.clone())?;
8079
log::info!("Adding {} to Kademlia routing table", addr);
8180
swarm
8281
.behaviour_mut()
@@ -94,24 +93,18 @@ impl P2PClient {
9493
.behaviour_mut()
9594
.kademlia
9695
.get_closest_peers(random_peer);
97-
swarm
98-
.behaviour_mut()
99-
.kademlia
100-
.bootstrap()
101-
.map_err(|e| e.to_string())?;
96+
swarm.behaviour_mut().kademlia.bootstrap()?;
10297

10398
// listen on all interfaces for incoming connections
10499
log::info!("Listening p2p network on: {}", listen_addr);
105-
swarm.listen_on(listen_addr).map_err(|e| e.to_string())?;
100+
swarm.listen_on(listen_addr)?;
106101

107102
log::info!(
108103
"Listening to relay nodes: {:#?}",
109104
available_nodes.relay_nodes
110105
);
111106
for addr in &available_nodes.relay_nodes {
112-
swarm
113-
.listen_on(addr.clone().with(Protocol::P2pCircuit))
114-
.map_err(|e| e.to_string())?;
107+
swarm.listen_on(addr.clone().with(Protocol::P2pCircuit))?;
115108
}
116109

117110
Ok(Self {
@@ -138,6 +131,8 @@ impl P2PClient {
138131
}
139132

140133
/// Publish a message to a topic.
134+
///
135+
/// Returns the message ID.
141136
pub fn publish(
142137
&mut self,
143138
topic_name: &str,
@@ -168,7 +163,7 @@ impl P2PClient {
168163
msg_id: &MessageId,
169164
propagation_source: &PeerId,
170165
acceptance: MessageAcceptance,
171-
) -> Result<(), PublishError> {
166+
) -> Result<()> {
172167
log::trace!("Validating message ({}): {:?}", msg_id, acceptance);
173168

174169
let msg_was_in_cache = self
@@ -240,6 +235,7 @@ impl P2PClient {
240235
/// - For Kademlia, we check the kademlia protocol and then add the address to the Kademlia routing table.
241236
fn handle_identify_event(&mut self, peer_id: PeerId, info: identify::Info) {
242237
// we only care about the observed address, although there may be other addresses at `info.listen_addrs`
238+
// TODO: this may be wrong
243239
let addr = info.observed_addr;
244240

245241
// check protocol string

src/p2p/mod.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,4 @@ pub use behaviour::{DriaBehaviour, DriaBehaviourEvent};
3737
mod client;
3838
pub use client::P2PClient;
3939

40-
mod message;
41-
42-
pub use message::P2PMessage;
43-
4440
mod data_transform;

0 commit comments

Comments
 (0)