Skip to content

Commit 177e60b

Browse files
authored
major orchestrator async refactor (#399)
* major orchestrator async refactor
1 parent eb516aa commit 177e60b

File tree

26 files changed

+1959
-904
lines changed

26 files changed

+1959
-904
lines changed

Cargo.lock

Lines changed: 9 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/orchestrator/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ hex = { workspace = true }
2020
log = { workspace = true }
2121
prometheus = "0.14.0"
2222
rand = "0.9.0"
23-
redis = { workspace = true }
23+
redis = { workspace = true, features = ["tokio-comp"] }
2424
redis-test = { workspace = true }
2525
reqwest = { workspace = true }
2626
serde = { workspace = true }

crates/orchestrator/src/api/routes/groups.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ const NODE_REQUEST_TIMEOUT: u64 = 30;
1515

1616
async fn get_groups(app_state: Data<AppState>) -> HttpResponse {
1717
if let Some(node_groups_plugin) = &app_state.node_groups_plugin {
18-
match node_groups_plugin.get_all_groups() {
18+
match node_groups_plugin.get_all_groups().await {
1919
Ok(groups) => {
2020
let groups_with_details: Vec<_> = groups
2121
.into_iter()
@@ -52,7 +52,7 @@ async fn get_groups(app_state: Data<AppState>) -> HttpResponse {
5252
async fn get_configurations(app_state: Data<AppState>) -> HttpResponse {
5353
if let Some(node_groups_plugin) = &app_state.node_groups_plugin {
5454
let all_configs = node_groups_plugin.get_all_configuration_templates();
55-
let available_configs = node_groups_plugin.get_available_configurations();
55+
let available_configs = node_groups_plugin.get_available_configurations().await;
5656

5757
let available_names: std::collections::HashSet<String> =
5858
available_configs.iter().map(|c| c.name.clone()).collect();
@@ -85,7 +85,7 @@ async fn get_configurations(app_state: Data<AppState>) -> HttpResponse {
8585

8686
async fn get_group_logs(group_id: web::Path<String>, app_state: Data<AppState>) -> HttpResponse {
8787
if let Some(node_groups_plugin) = &app_state.node_groups_plugin {
88-
match node_groups_plugin.get_group_by_id(&group_id) {
88+
match node_groups_plugin.get_group_by_id(&group_id).await {
8989
Ok(Some(group)) => {
9090
// Collect all node addresses
9191
let node_addresses: Vec<Address> = group
@@ -143,7 +143,21 @@ async fn get_group_logs(group_id: web::Path<String>, app_state: Data<AppState>)
143143
}
144144

145145
async fn fetch_node_logs(node_address: Address, app_state: Data<AppState>) -> serde_json::Value {
146-
let node = app_state.store_context.node_store.get_node(&node_address);
146+
let node = match app_state
147+
.store_context
148+
.node_store
149+
.get_node(&node_address)
150+
.await
151+
{
152+
Ok(node) => node,
153+
Err(e) => {
154+
error!("Failed to get node {}: {}", node_address, e);
155+
return json!({
156+
"success": false,
157+
"error": format!("Failed to get node: {}", e)
158+
});
159+
}
160+
};
147161

148162
match node {
149163
Some(node) => {
@@ -188,7 +202,8 @@ async fn fetch_node_logs(node_address: Address, app_state: Data<AppState>) -> se
188202
);
189203
headers.insert("x-signature", message_signature.parse().unwrap());
190204

191-
match reqwest::Client::new()
205+
match app_state
206+
.http_client
192207
.get(logs_url)
193208
.timeout(Duration::from_secs(NODE_REQUEST_TIMEOUT))
194209
.headers(headers)

crates/orchestrator/src/api/routes/heartbeat.rs

Lines changed: 98 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use actix_web::{
44
HttpResponse, Scope,
55
};
66
use alloy::primitives::Address;
7+
use log::error;
78
use serde_json::json;
89
use shared::models::{
910
api::ApiResponse,
@@ -18,36 +19,70 @@ async fn heartbeat(
1819
) -> HttpResponse {
1920
let task_info = heartbeat.clone();
2021
let node_address = Address::from_str(&heartbeat.address).unwrap();
21-
let node = app_state.store_context.node_store.get_node(&node_address);
22-
if let Some(node) = node {
23-
if node.status == NodeStatus::Banned {
22+
23+
let node_opt = app_state
24+
.store_context
25+
.node_store
26+
.get_node(&node_address)
27+
.await;
28+
match node_opt {
29+
Ok(Some(node)) => {
30+
if node.status == NodeStatus::Banned {
31+
return HttpResponse::BadRequest().json(json!({
32+
"success": false,
33+
"error": "Node is banned"
34+
}));
35+
}
36+
}
37+
_ => {
2438
return HttpResponse::BadRequest().json(json!({
2539
"success": false,
26-
"error": "Node is banned"
40+
"error": "Node not found"
2741
}));
2842
}
2943
}
30-
31-
app_state.store_context.node_store.update_node_task(
32-
node_address,
33-
task_info.task_id,
34-
task_info.task_state,
35-
);
44+
if let Err(e) = app_state
45+
.store_context
46+
.node_store
47+
.update_node_task(node_address, task_info.task_id, task_info.task_state)
48+
.await
49+
{
50+
error!("Error updating node task: {}", e);
51+
}
3652

3753
if let Some(p2p_id) = &heartbeat.p2p_id {
38-
app_state
54+
if let Err(e) = app_state
3955
.store_context
4056
.node_store
41-
.update_node_p2p_id(&node_address, p2p_id);
57+
.update_node_p2p_id(&node_address, p2p_id)
58+
.await
59+
{
60+
error!("Error updating node p2p id: {}", e);
61+
}
4262
}
4363

44-
app_state.store_context.heartbeat_store.beat(&heartbeat);
64+
if let Err(e) = app_state
65+
.store_context
66+
.heartbeat_store
67+
.beat(&heartbeat)
68+
.await
69+
{
70+
error!("Heartbeat Error: {}", e);
71+
}
4572
if let Some(metrics) = heartbeat.metrics.clone() {
4673
// Get all previously reported metrics for this node
47-
let previous_metrics = app_state
74+
let previous_metrics = match app_state
4875
.store_context
4976
.metrics_store
50-
.get_metrics_for_node(node_address);
77+
.get_metrics_for_node(node_address)
78+
.await
79+
{
80+
Ok(metrics) => metrics,
81+
Err(e) => {
82+
error!("Error getting metrics for node: {}", e);
83+
Default::default()
84+
}
85+
};
5186

5287
// Create a HashSet of new metrics for efficient lookup
5388
let new_metrics_set: HashSet<_> = metrics
@@ -67,20 +102,27 @@ async fn heartbeat(
67102
&label,
68103
);
69104
// Remove from Redis metrics store
70-
app_state.store_context.metrics_store.delete_metric(
71-
&task_id,
72-
&label,
73-
&node_address.to_string(),
74-
);
105+
if let Err(e) = app_state
106+
.store_context
107+
.metrics_store
108+
.delete_metric(&task_id, &label, &node_address.to_string())
109+
.await
110+
{
111+
error!("Error deleting metric: {}", e);
112+
}
75113
}
76114
}
77115
}
78116

79117
// Store new metrics and update Prometheus
80-
app_state
118+
if let Err(e) = app_state
81119
.store_context
82120
.metrics_store
83-
.store_metrics(Some(metrics.clone()), node_address);
121+
.store_metrics(Some(metrics.clone()), node_address)
122+
.await
123+
{
124+
error!("Error storing metrics: {}", e);
125+
}
84126

85127
for metric in metrics {
86128
app_state.metrics.record_compute_task_gauge(
@@ -92,7 +134,7 @@ async fn heartbeat(
92134
}
93135
}
94136

95-
let current_task = app_state.scheduler.get_task_for_node(node_address);
137+
let current_task = app_state.scheduler.get_task_for_node(node_address).await;
96138
match current_task {
97139
Ok(Some(task)) => {
98140
let resp: HttpResponse = ApiResponse::new(
@@ -121,6 +163,7 @@ mod tests {
121163

122164
use super::*;
123165
use crate::api::tests::helper::create_test_app_state;
166+
use crate::models::node::OrchestratorNode;
124167

125168
use actix_web::http::StatusCode;
126169
use actix_web::test;
@@ -141,6 +184,13 @@ mod tests {
141184
.await;
142185

143186
let address = "0x0000000000000000000000000000000000000000".to_string();
187+
let node_address = Address::from_str(&address).unwrap();
188+
let node = OrchestratorNode {
189+
address: node_address,
190+
status: NodeStatus::Healthy,
191+
..Default::default()
192+
};
193+
let _ = app_state.store_context.node_store.add_node(node).await;
144194
let req_payload = json!({"address": address, "metrics": [
145195
{"key": {"task_id": "long-task-1234", "label": "performance/batch_avg_seq_length"}, "value": 1.0},
146196
{"key": {"task_id": "long-task-1234", "label": "performance/batch_min_seq_length"}, "value": 5.0}
@@ -160,10 +210,14 @@ mod tests {
160210
assert_eq!(json["current_task"], serde_json::Value::Null);
161211

162212
let node_address = Address::from_str(&address).unwrap();
213+
163214
let value = app_state
164215
.store_context
165216
.heartbeat_store
166-
.get_heartbeat(&node_address);
217+
.get_heartbeat(&node_address)
218+
.await
219+
.unwrap();
220+
167221
assert_eq!(
168222
value,
169223
Some(HeartbeatRequest {
@@ -218,7 +272,9 @@ mod tests {
218272
let aggregated_metrics = app_state
219273
.store_context
220274
.metrics_store
221-
.get_aggregate_metrics_for_all_tasks();
275+
.get_aggregate_metrics_for_all_tasks()
276+
.await
277+
.unwrap();
222278
assert_eq!(aggregated_metrics.len(), 2);
223279
assert_eq!(aggregated_metrics.get("performance/batch_len"), Some(&10.0));
224280
assert_eq!(
@@ -245,7 +301,9 @@ mod tests {
245301
let aggregated_metrics = app_state
246302
.store_context
247303
.metrics_store
248-
.get_aggregate_metrics_for_all_tasks();
304+
.get_aggregate_metrics_for_all_tasks()
305+
.await
306+
.unwrap();
249307
assert_eq!(aggregated_metrics, HashMap::new());
250308
assert_eq!(metrics, "");
251309
}
@@ -267,20 +325,27 @@ mod tests {
267325
name: "test".to_string(),
268326
..Default::default()
269327
};
328+
329+
let node = OrchestratorNode {
330+
address: Address::from_str(&address).unwrap(),
331+
status: NodeStatus::Healthy,
332+
..Default::default()
333+
};
334+
335+
let _ = app_state.store_context.node_store.add_node(node).await;
336+
270337
let task = match task.try_into() {
271338
Ok(task) => task,
272339
Err(e) => panic!("Failed to convert TaskRequest to Task: {}", e),
273340
};
274-
app_state.store_context.task_store.add_task(task);
341+
let _ = app_state.store_context.task_store.add_task(task).await;
275342

276343
let req = test::TestRequest::post()
277344
.uri("/heartbeat")
278345
.set_json(json!({"address": "0x0000000000000000000000000000000000000000"}))
279346
.to_request();
280347

281348
let resp = test::call_service(&app, req).await;
282-
assert_eq!(resp.status(), StatusCode::OK);
283-
284349
let body = test::read_body(resp).await;
285350
let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
286351
assert_eq!(json["success"], serde_json::Value::Bool(true));
@@ -293,14 +358,15 @@ mod tests {
293358
let value = app_state
294359
.store_context
295360
.heartbeat_store
296-
.get_heartbeat(&node_address);
361+
.get_heartbeat(&node_address)
362+
.await
363+
.unwrap();
297364
// Task has not started yet
298365

299-
let value = value.unwrap();
300366
let heartbeat = HeartbeatRequest {
301367
address: "0x0000000000000000000000000000000000000000".to_string(),
302368
..Default::default()
303369
};
304-
assert_eq!(value, heartbeat);
370+
assert_eq!(value, Some(heartbeat));
305371
}
306372
}

0 commit comments

Comments
 (0)