Skip to content

Commit 304edc6

Browse files
committed
small rfks
1 parent b74f09c commit 304edc6

File tree

3 files changed

+65
-60
lines changed

3 files changed

+65
-60
lines changed

compute/src/node.rs

Lines changed: 34 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use dkn_p2p::{
55
},
66
DriaP2PClient, DriaP2PCommander, DriaP2PProtocol,
77
};
8-
use eyre::{eyre, Result};
8+
use eyre::Result;
99
use tokio::{sync::mpsc, time::Duration};
1010
use tokio_util::{either::Either, sync::CancellationToken};
1111

@@ -165,7 +165,10 @@ impl DriaComputeNode {
165165
}
166166

167167
// first, parse the raw gossipsub message to a prepared message
168-
let message = match self.parse_message_to_prepared_message(&message) {
168+
let message = match DKNMessage::try_from_gossipsub_message(
169+
&message,
170+
&self.config.admin_public_key,
171+
) {
169172
Ok(message) => message,
170173
Err(e) => {
171174
log::error!("Error parsing message: {:?}", e);
@@ -193,7 +196,7 @@ impl DriaComputeNode {
193196
PingpongHandler::LISTEN_TOPIC => {
194197
PingpongHandler::handle_ping(self, &message).await
195198
}
196-
_ => unreachable!(), // unreachable because of the `match` above
199+
_ => unreachable!("unreachable due to match expression"),
197200
};
198201

199202
// validate the message based on the result
@@ -216,46 +219,20 @@ impl DriaComputeNode {
216219
}
217220
}
218221

219-
/// Peer refresh simply reports the peer count to the user.
220-
async fn handle_peer_refresh(&self) {
221-
match self.p2p.peer_counts().await {
222-
Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all),
223-
Err(e) => log::error!("Error getting peer counts: {:?}", e),
224-
}
225-
}
226-
227-
/// Updates the local list of available nodes by refreshing it.
228-
/// Dials the RPC nodes again for better connectivity.
229-
async fn handle_available_nodes_refresh(&mut self) {
230-
log::info!("Refreshing available nodes.");
231-
232-
// refresh available nodes
233-
if let Err(e) = self.available_nodes.populate_with_api().await {
234-
log::error!("Error refreshing available nodes: {:?}", e);
235-
};
236-
237-
// dial all rpc nodes
238-
for rpc_addr in self.available_nodes.rpc_addrs.iter() {
239-
log::debug!("Dialling RPC node: {}", rpc_addr);
240-
if let Err(e) = self.p2p.dial(rpc_addr.clone()).await {
241-
log::warn!("Error dialling RPC node: {:?}", e);
242-
};
243-
}
244-
}
245-
246222
/// Runs the main loop of the compute node.
247223
/// This method is not expected to return until cancellation occurs.
248224
pub async fn run(&mut self) -> Result<()> {
225+
// prepare durations for sleeps
226+
let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS);
227+
let available_node_refresh_duration =
228+
Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS);
229+
249230
// subscribe to topics
250231
self.subscribe(PingpongHandler::LISTEN_TOPIC).await?;
251232
self.subscribe(PingpongHandler::RESPONSE_TOPIC).await?;
252233
self.subscribe(WorkflowHandler::LISTEN_TOPIC).await?;
253234
self.subscribe(WorkflowHandler::RESPONSE_TOPIC).await?;
254235

255-
let peer_refresh_duration = Duration::from_secs(PEER_REFRESH_INTERVAL_SECS);
256-
let available_node_refresh_duration =
257-
Duration::from_secs(AVAILABLE_NODES_REFRESH_INTERVAL_SECS);
258-
259236
loop {
260237
tokio::select! {
261238
// check peer count every now and then
@@ -321,25 +298,31 @@ impl DriaComputeNode {
321298
Ok(())
322299
}
323300

324-
/// Parses a given raw Gossipsub message to a prepared P2PMessage object.
325-
/// This prepared message includes the topic, payload, version and timestamp.
326-
///
327-
/// This also checks the signature of the message, expecting a valid signature from admin node.
328-
// TODO: move this somewhere?
329-
pub fn parse_message_to_prepared_message(&self, message: &Message) -> Result<DKNMessage> {
330-
// the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately.
331-
log::debug!("Parsing {} message.", message.topic.as_str());
332-
let message = DKNMessage::try_from(message)?;
333-
log::debug!("Parsed: {}", message);
334-
335-
// check dria signature
336-
// NOTE: when we have many public keys, we should check the signature against all of them
337-
// TODO: public key here will be given dynamically
338-
if !message.is_signed(&self.config.admin_public_key)? {
339-
return Err(eyre!("Invalid signature."));
301+
/// Peer refresh simply reports the peer count to the user.
302+
async fn handle_peer_refresh(&self) {
303+
match self.p2p.peer_counts().await {
304+
Ok((mesh, all)) => log::info!("Peer Count (mesh/all): {} / {}", mesh, all),
305+
Err(e) => log::error!("Error getting peer counts: {:?}", e),
340306
}
307+
}
341308

342-
Ok(message)
309+
/// Updates the local list of available nodes by refreshing it.
310+
/// Dials the RPC nodes again for better connectivity.
311+
async fn handle_available_nodes_refresh(&mut self) {
312+
log::info!("Refreshing available nodes.");
313+
314+
// refresh available nodes
315+
if let Err(e) = self.available_nodes.populate_with_api().await {
316+
log::error!("Error refreshing available nodes: {:?}", e);
317+
};
318+
319+
// dial all rpc nodes
320+
for rpc_addr in self.available_nodes.rpc_addrs.iter() {
321+
log::debug!("Dialling RPC node: {}", rpc_addr);
322+
if let Err(e) = self.p2p.dial(rpc_addr.clone()).await {
323+
log::warn!("Error dialling RPC node: {:?}", e);
324+
};
325+
}
343326
}
344327
}
345328

compute/src/utils/message.rs

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::DRIA_COMPUTE_NODE_VERSION;
66
use base64::{prelude::BASE64_STANDARD, Engine};
77
use core::fmt;
88
use ecies::PublicKey;
9-
use eyre::{Context, Result};
9+
use eyre::{eyre, Context, Result};
1010
use libsecp256k1::{verify, Message, SecretKey, Signature};
1111
use serde::{Deserialize, Serialize};
1212

@@ -44,7 +44,7 @@ impl DKNMessage {
4444
///
4545
/// - `data` is given as bytes, it is encoded into base64 to make up the `payload` within.
4646
/// - `topic` is the name of the [gossipsub topic](https://docs.libp2p.io/concepts/pubsub/overview/).
47-
pub fn new(data: impl AsRef<[u8]>, topic: &str) -> Self {
47+
pub(crate) fn new(data: impl AsRef<[u8]>, topic: &str) -> Self {
4848
Self {
4949
payload: BASE64_STANDARD.encode(data),
5050
topic: topic.to_string(),
@@ -55,7 +55,7 @@ impl DKNMessage {
5555
}
5656

5757
/// Creates a new Message by signing the SHA256 of the payload, and prepending the signature.
58-
pub fn new_signed(data: impl AsRef<[u8]>, topic: &str, signing_key: &SecretKey) -> Self {
58+
pub(crate) fn new_signed(data: impl AsRef<[u8]>, topic: &str, signing_key: &SecretKey) -> Self {
5959
// sign the SHA256 hash of the data
6060
let signature_bytes = sign_bytes_recoverable(&sha256hash(data.as_ref()), signing_key);
6161

@@ -69,19 +69,19 @@ impl DKNMessage {
6969
}
7070

7171
/// Sets the identity of the message.
72-
pub fn with_identity(mut self, identity: String) -> Self {
72+
pub(crate) fn with_identity(mut self, identity: String) -> Self {
7373
self.identity = identity;
7474
self
7575
}
7676

7777
/// Decodes the base64 payload into bytes.
7878
#[inline(always)]
79-
pub fn decode_payload(&self) -> Result<Vec<u8>, base64::DecodeError> {
79+
pub(crate) fn decode_payload(&self) -> Result<Vec<u8>, base64::DecodeError> {
8080
BASE64_STANDARD.decode(&self.payload)
8181
}
8282

8383
/// Decodes and parses the base64 payload into JSON for the provided type `T`.
84-
pub fn parse_payload<T: for<'a> Deserialize<'a>>(&self, signed: bool) -> Result<T> {
84+
pub(crate) fn parse_payload<T: for<'a> Deserialize<'a>>(&self, signed: bool) -> Result<T> {
8585
let payload = self.decode_payload()?;
8686

8787
let body = if signed {
@@ -96,7 +96,7 @@ impl DKNMessage {
9696
}
9797

9898
/// Checks if the payload is signed by the given public key.
99-
pub fn is_signed(&self, public_key: &PublicKey) -> Result<bool> {
99+
pub(crate) fn is_signed(&self, public_key: &PublicKey) -> Result<bool> {
100100
// decode base64 payload
101101
let data = self.decode_payload()?;
102102

@@ -116,6 +116,29 @@ impl DKNMessage {
116116
let digest = Message::parse(&sha256hash(body));
117117
Ok(verify(&digest, &signature, public_key))
118118
}
119+
120+
/// Tries to parse the given gossipsub message into a DKNMessage.
121+
///
122+
/// This prepared message includes the topic, payload, version and timestamp.
123+
/// It also checks the signature of the message, expecting a valid signature from admin node.
124+
pub(crate) fn try_from_gossipsub_message(
125+
gossipsub_message: &dkn_p2p::libp2p::gossipsub::Message,
126+
public_key: &libsecp256k1::PublicKey,
127+
) -> Result<Self> {
128+
// the received message is expected to use IdentHash for the topic, so we can see the name of the topic immediately.
129+
log::debug!("Parsing {} message.", gossipsub_message.topic.as_str());
130+
let message = serde_json::from_slice::<DKNMessage>(&gossipsub_message.data)
131+
.wrap_err("could not parse message")?;
132+
log::debug!("Parsed: {}", message);
133+
134+
// check dria signature
135+
// NOTE: when we have many public keys, we should check the signature against all of them
136+
if !message.is_signed(&public_key)? {
137+
return Err(eyre!("Invalid signature."));
138+
}
139+
140+
Ok(message)
141+
}
119142
}
120143

121144
impl fmt::Display for DKNMessage {

workflows/src/providers/ollama.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use ollama_workflows::{
44
generation::{
55
completion::request::GenerationRequest,
66
embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest},
7-
options::GenerationOptions,
87
},
98
Ollama,
109
},
@@ -171,7 +170,7 @@ impl OllamaConfig {
171170
// otherwise, give error
172171
log::error!("Please download missing model with: ollama pull {}", model);
173172
log::error!("Or, set OLLAMA_AUTO_PULL=true to pull automatically.");
174-
Err(eyre!("Required model not pulled in Ollama."))
173+
Err(eyre!("required model not pulled in Ollama"))
175174
}
176175
}
177176

0 commit comments

Comments
 (0)