Skip to content

Commit 585c398

Browse files
committed
different points retrieval logic
1 parent da20fec commit 585c398

File tree

6 files changed

+101
-50
lines changed

6 files changed

+101
-50
lines changed

Cargo.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default-members = ["compute"]
77

88
[workspace.package]
99
edition = "2021"
10-
version = "0.5.2"
10+
version = "0.5.3"
1111
license = "Apache-2.0"
1212
readme = "README.md"
1313

compute/src/node/core.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@ impl DriaComputeNode {
1414
/// Runs the main loop of the compute node.
1515
/// This method is not expected to return until cancellation occurs for the given token.
1616
pub async fn run(&mut self, cancellation: CancellationToken) {
17+
// initialize the points client
18+
self.points_client.initialize().await;
19+
1720
/// Duration between refreshing for diagnostic prints.
1821
const DIAGNOSTIC_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(45);
22+
/// Duration between refreshing for points update.
23+
const POINTS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(180);
1924
/// Duration between refreshing the available nodes.
2025
const RPC_LIVENESS_REFRESH_INTERVAL_SECS: Duration = Duration::from_secs(2 * 60);
2126
/// Duration between each specs update sent to the RPC.
@@ -28,6 +33,11 @@ impl DriaComputeNode {
2833
tokio::time::interval(RPC_LIVENESS_REFRESH_INTERVAL_SECS);
2934
rpc_liveness_refresh_interval.tick().await; // move each one tick
3035

36+
// tick the first time a bit earlier
37+
let mut points_refresh_interval = tokio::time::interval(POINTS_REFRESH_INTERVAL_SECS);
38+
points_refresh_interval.tick().await;
39+
points_refresh_interval.reset_after(POINTS_REFRESH_INTERVAL_SECS / 12);
40+
3141
// move one tick, and wait at least a third of the diagnostics
3242
let mut heartbeat_interval = tokio::time::interval(HeartbeatRequester::HEARTBEAT_DEADLINE);
3343
heartbeat_interval.tick().await;
@@ -68,6 +78,9 @@ impl DriaComputeNode {
6878
// check RPC, and get a new one if we are disconnected
6979
_ = rpc_liveness_refresh_interval.tick() => self.handle_rpc_liveness_check().await,
7080

81+
// log points every now and then
82+
_ = points_refresh_interval.tick() => self.handle_points_refresh().await,
83+
7184
// send a heartbeat request to publish liveness info
7285
_ = heartbeat_interval.tick() => {
7386
if let Err(e) = self.send_heartbeat().await {

compute/src/node/diagnostic.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use colored::Colorize;
22
use std::time::Duration;
33

4-
use crate::{node::rpc::DriaRPC, utils::get_points, DriaComputeNode, DRIA_COMPUTE_NODE_VERSION};
4+
use crate::{node::rpc::DriaRPC, DriaComputeNode, DRIA_COMPUTE_NODE_VERSION};
55

66
/// Number of seconds such that if the last heartbeat ACK is older than this, the node is considered unreachable.
77
/// This must be at least greated than the heartbeat interval duration, and the liveness check duration.
@@ -22,17 +22,6 @@ impl DriaComputeNode {
2222
pub(crate) async fn handle_diagnostic_refresh(&mut self) {
2323
let mut diagnostics = vec![format!("Diagnostics (v{}):", DRIA_COMPUTE_NODE_VERSION)];
2424

25-
// print steps
26-
if let Ok(steps) = get_points(&self.config.address).await {
27-
let earned = steps.score - self.initial_steps;
28-
diagnostics.push(format!(
29-
"$DRIA Points: {} total, {} earned in this run, within top {}%",
30-
steps.score,
31-
earned,
32-
steps.percentile.unwrap_or("100".to_string())
33-
));
34-
}
35-
3625
// completed tasks count is printed as well in debug
3726
if log::log_enabled!(log::Level::Debug) {
3827
diagnostics.push(format!(
@@ -137,4 +126,24 @@ impl DriaComputeNode {
137126
log::debug!("Connection with {} is intact.", self.dria_rpc.peer_id);
138127
}
139128
}
129+
130+
/// Updates the points for the given address.
131+
#[inline]
132+
pub(crate) async fn handle_points_refresh(&mut self) {
133+
// get points from the API
134+
match self.points_client.get_points().await {
135+
Ok(steps) => {
136+
log::info!(
137+
"{}: {} total, {} earned in this run, within top {}%",
138+
"$DRIA Points".purple(),
139+
steps.score,
140+
steps.score - self.points_client.initial,
141+
steps.percentile.unwrap_or("100".to_string())
142+
);
143+
}
144+
Err(err) => {
145+
log::error!("Could not get $DRIA points info: {err:?}");
146+
}
147+
}
148+
}
140149
}

compute/src/node/mod.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use uuid::Uuid;
99

1010
use crate::{
1111
config::*,
12-
utils::{get_points, SpecCollector},
12+
utils::{DriaPointsClient, SpecCollector},
1313
workers::task::{TaskWorker, TaskWorkerInput, TaskWorkerMetadata, TaskWorkerOutput},
1414
};
1515

@@ -58,8 +58,8 @@ pub struct DriaComputeNode {
5858
completed_tasks_batch: usize,
5959
/// Specifications collector.
6060
spec_collector: SpecCollector,
61-
/// Initial steps count.
62-
initial_steps: f64,
61+
/// Points client.
62+
points_client: DriaPointsClient,
6363
}
6464

6565
impl DriaComputeNode {
@@ -78,7 +78,7 @@ impl DriaComputeNode {
7878
let keypair = secret_to_keypair(&config.secret_key);
7979

8080
// dial the RPC node
81-
let dria_nodes = if let Some(addr) = config.initial_rpc_addr.take() {
81+
let dria_rpc = if let Some(addr) = config.initial_rpc_addr.take() {
8282
log::info!("Using initial RPC address: {}", addr);
8383
DriaRPC::new(addr, config.network_type).expect("could not get RPC to connect to")
8484
} else {
@@ -96,7 +96,7 @@ impl DriaComputeNode {
9696
let (p2p_client, p2p_commander, request_rx) = DriaP2PClient::new(
9797
keypair,
9898
config.p2p_listen_addr.clone(),
99-
&dria_nodes.addr,
99+
&dria_rpc.addr,
100100
protocol,
101101
)?;
102102

@@ -120,19 +120,15 @@ impl DriaComputeNode {
120120
};
121121

122122
let model_names = config.workflows.get_model_names();
123-
124-
let initial_steps = get_points(&config.address)
125-
.await
126-
.map(|s| s.score)
127-
.unwrap_or_default();
123+
let points_client = DriaPointsClient::new(&config.address)?;
128124

129125
let spec_collector = SpecCollector::new(model_names.clone(), config.version);
130126
Ok((
131127
DriaComputeNode {
132128
config,
133129
p2p: p2p_commander,
134-
dria_rpc: dria_nodes,
135-
initial_steps,
130+
dria_rpc,
131+
points_client,
136132
// receivers
137133
task_output_rx: publish_rx,
138134
reqres_rx: request_rx,

compute/src/utils/points.rs

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,57 @@ use eyre::Context;
33
/// Points URL, use with an `address` query parameter.
44
const POINTS_API_BASE_URL: &str =
55
"https://mainnet.dkn.dria.co/dashboard/supply/v0/leaderboard/steps";
6+
// TODO: support testnet here?
7+
8+
pub struct DriaPointsClient {
9+
pub url: String,
10+
client: reqwest::Client,
11+
/// The total number of points you have accumulated at the start of the run.
12+
pub initial: f64,
13+
}
14+
15+
impl DriaPointsClient {
16+
/// Creates a new `DriaPointsClient` for the given address.
17+
pub fn new(address: &str) -> eyre::Result<Self> {
18+
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
19+
20+
let url = format!(
21+
"{}?address=0x{}",
22+
POINTS_API_BASE_URL,
23+
address.trim_start_matches("0x")
24+
);
25+
26+
let client = reqwest::Client::builder()
27+
.user_agent(USER_AGENT)
28+
.build()
29+
.wrap_err("could not create Points client")?;
30+
31+
Ok(Self {
32+
url,
33+
client,
34+
initial: 0.0,
35+
})
36+
}
37+
38+
/// Sets the initial points to the current points.
39+
///
40+
/// If there is an error, it sets to 0.0.
41+
pub async fn initialize(&mut self) {
42+
self.initial = self.get_points().await.map(|p| p.score).unwrap_or_default();
43+
}
44+
45+
pub async fn get_points(&self) -> eyre::Result<DriaPoints> {
46+
let res = self
47+
.client
48+
.get(&self.url)
49+
.send()
50+
.await
51+
.wrap_err("could not make request")?;
52+
res.json::<DriaPoints>()
53+
.await
54+
.wrap_err("could not parse response")
55+
}
56+
}
657

758
#[derive(Debug, serde::Deserialize)]
859
pub struct DriaPoints {
@@ -15,33 +66,15 @@ pub struct DriaPoints {
1566
pub score: f64,
1667
}
1768

18-
/// Returns the points for the given address.
19-
pub async fn get_points(address: &str) -> eyre::Result<DriaPoints> {
20-
// the address can have 0x or not, we add it ourselves here
21-
let url = format!(
22-
"{}?address=0x{}",
23-
POINTS_API_BASE_URL,
24-
address.trim_start_matches("0x")
25-
);
26-
27-
let res = reqwest::get(&url)
28-
.await
29-
.wrap_err("could not make request")?;
30-
res.json::<DriaPoints>()
31-
.await
32-
.wrap_err("could not parse response")
33-
}
34-
3569
#[cfg(test)]
3670
mod tests {
3771
use super::*;
3872

3973
#[tokio::test]
4074
#[ignore = "waiting for API"]
4175
async fn test_get_points() {
42-
let steps = get_points("0xa43536a6032a3907ccf60e8109429ee1047b207c")
43-
.await
44-
.unwrap();
76+
let client = DriaPointsClient::new("0xa43536a6032a3907ccf60e8109429ee1047b207c").unwrap();
77+
let steps = client.get_points().await.unwrap();
4578
assert!(steps.score != 0.0);
4679
}
4780
}

0 commit comments

Comments
 (0)