Skip to content

Commit cedb534

Browse files
committed
added network specific Jina & Serper checks
1 parent 160c802 commit cedb534

File tree

6 files changed

+34
-2
lines changed

6 files changed

+34
-2
lines changed

compute/src/config.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ impl DriaComputeNodeConfig {
127127

128128
/// Asserts that the configured listen address is free.
129129
/// Throws an error if the address is already in use.
130+
#[inline]
130131
pub fn assert_address_not_in_use(&self) -> Result<()> {
131132
if address_in_use(&self.p2p_listen_addr) {
132133
return Err(eyre!(
@@ -137,6 +138,21 @@ impl DriaComputeNodeConfig {
137138

138139
Ok(())
139140
}
141+
142+
/// Checks the network specific configurations.
143+
pub fn check_network_specific(&self) -> Result<()> {
144+
// if network is `pro`, we require Jina and Serper to be present.
145+
if self.network_type == DriaNetworkType::Pro {
146+
if !self.workflows.jina.has_api_key() {
147+
return Err(eyre!("Jina is required for the Pro network."));
148+
}
149+
if !self.workflows.serper.has_api_key() {
150+
return Err(eyre!("Serper is required for the Pro network."));
151+
}
152+
}
153+
154+
Ok(())
155+
}
140156
}
141157

142158
#[cfg(test)]

compute/src/main.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ async fn main() -> Result<()> {
5252
tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
5353

5454
log::warn!("Exiting due to DKN_EXIT_TIMEOUT.");
55-
5655
cancellation_token.cancel();
5756
} else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
5857
// if there is no timeout, we wait for termination signals here
@@ -86,6 +85,9 @@ async fn main() -> Result<()> {
8685
}?;
8786
log::warn!("Using models: {:#?}", config.workflows.models);
8887

88+
// check network-specific configurations
89+
config.check_network_specific()?;
90+
8991
// create the node
9092
let batch_size = config.batch_size;
9193
let (mut node, p2p, worker_batch, worker_single) = DriaComputeNode::new(config).await?;

p2p/src/network.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use libp2p::{Multiaddr, PeerId};
22

33
/// Network type.
4-
#[derive(Default, Debug, Clone, Copy)]
4+
#[derive(Default, Debug, Clone, Copy, PartialEq)]
55
pub enum DriaNetworkType {
66
#[default]
77
Community,

utils/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub fn split_csv_line(input: &str) -> Vec<String> {
2121

2222
/// Reads an environment variable and trims whitespace and `"` from both ends.
2323
/// If the trimmed value is empty, returns `None`.
24+
#[inline]
2425
pub fn safe_read_env(var: Result<String, std::env::VarError>) -> Option<String> {
2526
var.map(|s| s.trim_matches('"').trim().to_string())
2627
.ok()
@@ -43,6 +44,7 @@ where
4344
/// Returns the current time in nanoseconds since the Unix epoch.
4445
///
4546
/// If a `SystemTimeError` occurs, will return 0 just to keep things running.
47+
#[inline]
4648
pub fn get_current_time_nanos() -> u128 {
4749
SystemTime::now()
4850
.duration_since(SystemTime::UNIX_EPOCH)

workflows/src/apis/jina.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ impl JinaConfig {
2020
}
2121
}
2222

23+
/// Checks if the API key is present.
24+
#[inline]
25+
pub fn has_api_key(&self) -> bool {
26+
self.api_key.is_some()
27+
}
28+
2329
/// Sets the API key for Jina.
2430
pub fn with_api_key(mut self, api_key: String) -> Self {
2531
self.api_key = Some(api_key);

workflows/src/apis/serper.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ impl SerperConfig {
2020
}
2121
}
2222

23+
/// Checks if the API key is present.
24+
#[inline]
25+
pub fn has_api_key(&self) -> bool {
26+
self.api_key.is_some()
27+
}
28+
2329
/// Sets the API key for Serper.
2430
pub fn with_api_key(mut self, api_key: String) -> Self {
2531
self.api_key = Some(api_key);

0 commit comments

Comments
 (0)