Skip to content

Commit 519ee7c

Browse files
Move helper methods into macro and use tokio::sync::Mutex
- Inline get_next_endpoint, handle_success, handle_error logic into macro - Replace std::sync::Mutex with tokio::sync::Mutex for async compatibility - Remove now-unused helper methods from RoundRobinState and RpcMultiClient - Update all constructor methods to use tokio mutex - Addresses GitHub PR feedback from ali-behjati Co-Authored-By: Ali <[email protected]>
1 parent 042187f commit 519ee7c

File tree

1 file changed

+49
-66
lines changed

1 file changed

+49
-66
lines changed

src/agent/utils/rpc_multi_client.rs

Lines changed: 49 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,13 @@ use {
1818
},
1919
solana_transaction_status::TransactionStatus,
2020
std::{
21-
sync::{
22-
Arc,
23-
Mutex,
24-
},
21+
sync::Arc,
2522
time::{
2623
Duration,
2724
Instant,
2825
},
2926
},
27+
tokio::sync::Mutex,
3028
url::Url,
3129
};
3230

@@ -36,15 +34,59 @@ macro_rules! retry_rpc_operation {
3634
let max_attempts = $self.rpc_clients.len() * 2;
3735

3836
while attempts < max_attempts {
39-
if let Some(index) = $self.get_next_endpoint() {
37+
let index_option = {
38+
let mut state = $self.round_robin_state.lock().await;
39+
let now = Instant::now();
40+
let start_index = state.current_index;
41+
42+
let mut found_index = None;
43+
for _ in 0..state.endpoint_states.len() {
44+
let index = state.current_index;
45+
state.current_index = (state.current_index + 1) % state.endpoint_states.len();
46+
47+
let endpoint_state = &state.endpoint_states[index];
48+
if endpoint_state.is_healthy
49+
|| endpoint_state.last_failure.map_or(true, |failure_time| {
50+
now.duration_since(failure_time) >= state.cooldown_duration
51+
})
52+
{
53+
found_index = Some(index);
54+
break;
55+
}
56+
}
57+
58+
if found_index.is_none() {
59+
let index = start_index;
60+
state.current_index = (start_index + 1) % state.endpoint_states.len();
61+
found_index = Some(index);
62+
}
63+
found_index
64+
};
65+
66+
if let Some(index) = index_option {
4067
let $client = &$self.rpc_clients[index];
4168
match $operation {
4269
Ok(result) => {
43-
$self.handle_success(index);
70+
let mut state = $self.round_robin_state.lock().await;
71+
if index < state.endpoint_states.len() {
72+
state.endpoint_states[index].is_healthy = true;
73+
state.endpoint_states[index].last_failure = None;
74+
}
4475
return Ok(result);
4576
}
4677
Err(e) => {
47-
$self.handle_error(index, $operation_name, &e);
78+
let client = &$self.rpc_clients[index];
79+
tracing::warn!(
80+
"{} error for rpc endpoint {}: {}",
81+
$operation_name,
82+
client.url(),
83+
e
84+
);
85+
let mut state = $self.round_robin_state.lock().await;
86+
if index < state.endpoint_states.len() {
87+
state.endpoint_states[index].last_failure = Some(Instant::now());
88+
state.endpoint_states[index].is_healthy = false;
89+
}
4890
}
4991
}
5092
}
@@ -87,43 +129,6 @@ impl RoundRobinState {
87129
cooldown_duration,
88130
}
89131
}
90-
91-
fn get_next_healthy_endpoint(&mut self) -> Option<usize> {
92-
let now = Instant::now();
93-
let start_index = self.current_index;
94-
95-
for _ in 0..self.endpoint_states.len() {
96-
let index = self.current_index;
97-
self.current_index = (self.current_index + 1) % self.endpoint_states.len();
98-
99-
let state = &self.endpoint_states[index];
100-
if state.is_healthy
101-
|| state.last_failure.map_or(true, |failure_time| {
102-
now.duration_since(failure_time) >= self.cooldown_duration
103-
})
104-
{
105-
return Some(index);
106-
}
107-
}
108-
109-
let index = start_index;
110-
self.current_index = (start_index + 1) % self.endpoint_states.len();
111-
Some(index)
112-
}
113-
114-
fn mark_endpoint_failed(&mut self, index: usize) {
115-
if index < self.endpoint_states.len() {
116-
self.endpoint_states[index].last_failure = Some(Instant::now());
117-
self.endpoint_states[index].is_healthy = false;
118-
}
119-
}
120-
121-
fn mark_endpoint_healthy(&mut self, index: usize) {
122-
if index < self.endpoint_states.len() {
123-
self.endpoint_states[index].is_healthy = true;
124-
self.endpoint_states[index].last_failure = None;
125-
}
126-
}
127132
}
128133

129134
pub struct RpcMultiClient {
@@ -217,28 +222,6 @@ impl RpcMultiClient {
217222
}
218223
}
219224

220-
fn get_next_endpoint(&self) -> Option<usize> {
221-
let mut state = self.round_robin_state.lock().unwrap();
222-
state.get_next_healthy_endpoint()
223-
}
224-
225-
fn handle_success(&self, index: usize) {
226-
let mut state = self.round_robin_state.lock().unwrap();
227-
state.mark_endpoint_healthy(index);
228-
}
229-
230-
fn handle_error(&self, index: usize, operation_name: &str, error: &dyn std::fmt::Display) {
231-
let client = &self.rpc_clients[index];
232-
tracing::warn!(
233-
"{} error for rpc endpoint {}: {}",
234-
operation_name,
235-
client.url(),
236-
error
237-
);
238-
let mut state = self.round_robin_state.lock().unwrap();
239-
state.mark_endpoint_failed(index);
240-
}
241-
242225

243226
pub async fn get_balance(&self, kp: &Keypair) -> anyhow::Result<u64> {
244227
retry_rpc_operation!(self, "getBalance", client => client.get_balance(&kp.pubkey()).await)

0 commit comments

Comments
 (0)