Skip to content

Commit c4059cd

Browse files
authored
Merge pull request #284 from mubarak23/fix-remove-mut-ref
Fix: Remove Unnecessary &mut self Requirements from LightningNode Trait
2 parents 698b318 + 3e7ac44 commit c4059cd

File tree

6 files changed

+94
-88
lines changed

6 files changed

+94
-88
lines changed

simln-lib/src/cln.rs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ use cln_grpc::pb::{
1010
};
1111
use lightning::ln::features::NodeFeatures;
1212
use lightning::ln::PaymentHash;
13-
1413
use serde::{Deserialize, Serialize};
1514
use tokio::fs::File;
1615
use tokio::io::{AsyncReadExt, Error};
16+
use tokio::sync::Mutex;
1717
use tokio::time::{self, Duration};
1818
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
1919
use triggered::Listener;
@@ -38,7 +38,7 @@ pub struct ClnConnection {
3838
}
3939

4040
pub struct ClnNode {
41-
pub client: NodeClient<Channel>,
41+
pub client: Mutex<NodeClient<Channel>>,
4242
info: NodeInfo,
4343
}
4444

@@ -63,7 +63,7 @@ impl ClnNode {
6363
})?,
6464
));
6565

66-
let mut client = NodeClient::new(
66+
let client = Mutex::new(NodeClient::new(
6767
Channel::from_shared(connection.address)
6868
.map_err(|err| LightningError::ConnectionError(err.to_string()))?
6969
.tls_config(tls)
@@ -81,9 +81,11 @@ impl ClnNode {
8181
err
8282
))
8383
})?,
84-
);
84+
));
8585

8686
let (id, mut alias, our_features) = client
87+
.lock()
88+
.await
8789
.getinfo(GetinfoRequest {})
8890
.await
8991
.map(|r| {
@@ -119,7 +121,7 @@ impl ClnNode {
119121
/// Fetch channels belonging to the local node, initiated locally if is_source is true, and initiated remotely if
120122
/// is_source is false. Introduced as a helper function because CLN doesn't have a single API to list all of our
121123
/// node's channels.
122-
async fn node_channels(&mut self, is_source: bool) -> Result<Vec<u64>, LightningError> {
124+
async fn node_channels(&self, is_source: bool) -> Result<Vec<u64>, LightningError> {
123125
let req = if is_source {
124126
ListchannelsRequest {
125127
source: Some(self.info.pubkey.serialize().to_vec()),
@@ -134,6 +136,8 @@ impl ClnNode {
134136

135137
let resp = self
136138
.client
139+
.lock()
140+
.await
137141
.list_channels(req)
138142
.await
139143
.map_err(|err| LightningError::ListChannelsError(err.to_string()))?
@@ -153,9 +157,9 @@ impl LightningNode for ClnNode {
153157
&self.info
154158
}
155159

156-
async fn get_network(&mut self) -> Result<Network, LightningError> {
157-
let info = self
158-
.client
160+
async fn get_network(&self) -> Result<Network, LightningError> {
161+
let mut client = self.client.lock().await;
162+
let info = client
159163
.getinfo(GetinfoRequest {})
160164
.await
161165
.map_err(|err| LightningError::GetInfoError(err.to_string()))?
@@ -166,12 +170,12 @@ impl LightningNode for ClnNode {
166170
}
167171

168172
async fn send_payment(
169-
&mut self,
173+
&self,
170174
dest: PublicKey,
171175
amount_msat: u64,
172176
) -> Result<PaymentHash, LightningError> {
173-
let KeysendResponse { payment_hash, .. } = self
174-
.client
177+
let mut client = self.client.lock().await;
178+
let KeysendResponse { payment_hash, .. } = client
175179
.key_send(KeysendRequest {
176180
destination: dest.serialize().to_vec(),
177181
amount_msat: Some(Amount { msat: amount_msat }),
@@ -200,7 +204,7 @@ impl LightningNode for ClnNode {
200204
}
201205

202206
async fn track_payment(
203-
&mut self,
207+
&self,
204208
hash: &PaymentHash,
205209
shutdown: Listener,
206210
) -> Result<PaymentResult, LightningError> {
@@ -211,8 +215,8 @@ impl LightningNode for ClnNode {
211215
return Err(LightningError::TrackPaymentError("Shutdown before tracking results".to_string()));
212216
},
213217
_ = time::sleep(Duration::from_millis(500)) => {
214-
let ListpaysResponse { pays } = self
215-
.client
218+
let mut client = self.client.lock().await;
219+
let ListpaysResponse { pays } = client
216220
.list_pays(ListpaysRequest {
217221
payment_hash: Some(hash.0.to_vec()),
218222
..Default::default()
@@ -242,9 +246,9 @@ impl LightningNode for ClnNode {
242246
}
243247
}
244248

245-
async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
246-
let mut nodes: Vec<cln_grpc::pb::ListnodesNodes> = self
247-
.client
249+
async fn get_node_info(&self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
250+
let mut client = self.client.lock().await;
251+
let mut nodes: Vec<cln_grpc::pb::ListnodesNodes> = client
248252
.list_nodes(ListnodesRequest {
249253
id: Some(node_id.serialize().to_vec()),
250254
})
@@ -270,15 +274,15 @@ impl LightningNode for ClnNode {
270274
}
271275
}
272276

273-
async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError> {
277+
async fn list_channels(&self) -> Result<Vec<u64>, LightningError> {
274278
let mut node_channels = self.node_channels(true).await?;
275279
node_channels.extend(self.node_channels(false).await?);
276280
Ok(node_channels)
277281
}
278282

279-
async fn get_graph(&mut self) -> Result<Graph, LightningError> {
280-
let nodes: Vec<cln_grpc::pb::ListnodesNodes> = self
281-
.client
283+
async fn get_graph(&self) -> Result<Graph, LightningError> {
284+
let mut client = self.client.lock().await;
285+
let nodes: Vec<cln_grpc::pb::ListnodesNodes> = client
282286
.list_nodes(ListnodesRequest { id: None })
283287
.await
284288
.map_err(|err| LightningError::GetNodeInfoError(err.to_string()))?

simln-lib/src/eclair.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::collections::HashMap;
1515
use std::error::Error;
1616
use std::str::FromStr;
1717
use std::time::Duration;
18+
use tokio::sync::Mutex;
1819
use tokio::time;
1920
use triggered::Listener;
2021

@@ -83,7 +84,7 @@ impl TryInto<EclairClient> for EclairConnection {
8384
}
8485

8586
pub struct EclairNode {
86-
client: EclairClient,
87+
client: Mutex<EclairClient>,
8788
info: NodeInfo,
8889
network: Network,
8990
}
@@ -110,7 +111,7 @@ impl EclairNode {
110111
let features = parse_json_to_node_features(&info.features);
111112

112113
Ok(Self {
113-
client,
114+
client: Mutex::new(client),
114115
info: NodeInfo {
115116
pubkey,
116117
alias: info.alias,
@@ -127,31 +128,30 @@ impl LightningNode for EclairNode {
127128
&self.info
128129
}
129130

130-
async fn get_network(&mut self) -> Result<Network, LightningError> {
131+
async fn get_network(&self) -> Result<Network, LightningError> {
131132
Ok(self.network)
132133
}
133134

134135
async fn send_payment(
135-
&mut self,
136+
&self,
136137
dest: PublicKey,
137138
amount_msat: u64,
138139
) -> Result<PaymentHash, LightningError> {
140+
let client = self.client.lock().await;
139141
let preimage = PaymentPreimage(rand::random()).0;
140142
let mut params = HashMap::new();
141143
params.insert("nodeId".to_string(), hex::encode(dest.serialize()));
142144
params.insert("amountMsat".to_string(), amount_msat.to_string());
143145
params.insert("paymentHash".to_string(), hex::encode(preimage));
144-
let uuid: String = self
145-
.client
146+
let uuid: String = client
146147
.request("sendtonode", Some(params))
147148
.await
148149
.map_err(|err| LightningError::SendPaymentError(err.to_string()))?;
149150

150151
let mut params = HashMap::new();
151152
params.insert("paymentHash".to_string(), hex::encode(preimage));
152153
params.insert("id".to_string(), uuid);
153-
let payment_parts: PaymentInfoResponse = self
154-
.client
154+
let payment_parts: PaymentInfoResponse = client
155155
.request("getsentinfo", Some(params))
156156
.await
157157
.map_err(|_| LightningError::InvalidPaymentHash)?;
@@ -164,7 +164,7 @@ impl LightningNode for EclairNode {
164164
}
165165

166166
async fn track_payment(
167-
&mut self,
167+
&self,
168168
hash: &PaymentHash,
169169
shutdown: Listener,
170170
) -> Result<PaymentResult, LightningError> {
@@ -175,11 +175,11 @@ impl LightningNode for EclairNode {
175175
return Err(LightningError::TrackPaymentError("Shutdown before tracking results".to_string()));
176176
},
177177
_ = time::sleep(Duration::from_millis(500)) => {
178+
let client = self.client.lock().await;
178179
let mut params = HashMap::new();
179180
params.insert("paymentHash".to_string(), hex::encode(hash.0));
180181

181-
let payment_parts: PaymentInfoResponse = self
182-
.client
182+
let payment_parts: PaymentInfoResponse = client
183183
.request("getsentinfo", Some(params))
184184
.await
185185
.map_err(|err| LightningError::TrackPaymentError(err.to_string()))?;
@@ -204,12 +204,12 @@ impl LightningNode for EclairNode {
204204
}
205205
}
206206

207-
async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
207+
async fn get_node_info(&self, node_id: &PublicKey) -> Result<NodeInfo, LightningError> {
208208
let mut params = HashMap::new();
209209
params.insert("nodeId".to_string(), hex::encode(node_id.serialize()));
210210

211-
let node_info: NodeResponse = self
212-
.client
211+
let client = self.client.lock().await;
212+
let node_info: NodeResponse = client
213213
.request("node", Some(params))
214214
.await
215215
.map_err(|err| LightningError::GetNodeInfoError(err.to_string()))?;
@@ -222,9 +222,9 @@ impl LightningNode for EclairNode {
222222
})
223223
}
224224

225-
async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError> {
226-
let channels: ChannelsResponse = self
227-
.client
225+
async fn list_channels(&self) -> Result<Vec<u64>, LightningError> {
226+
let client = self.client.lock().await;
227+
let channels: ChannelsResponse = client
228228
.request("channels", None)
229229
.await
230230
.map_err(|err| LightningError::ListChannelsError(err.to_string()))?;
@@ -245,9 +245,9 @@ impl LightningNode for EclairNode {
245245
Ok(capacities_msat)
246246
}
247247

248-
async fn get_graph(&mut self) -> Result<Graph, LightningError> {
249-
let nodes: NodesResponse = self
250-
.client
248+
async fn get_graph(&self) -> Result<Graph, LightningError> {
249+
let client = self.client.lock().await;
250+
let nodes: NodesResponse = client
251251
.request("nodes", None)
252252
.await
253253
.map_err(|err| LightningError::GetNodeInfoError(err.to_string()))?;

simln-lib/src/lib.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -324,26 +324,26 @@ pub trait LightningNode: Send {
324324
/// Get information about the node.
325325
fn get_info(&self) -> &NodeInfo;
326326
/// Get the network this node is running at.
327-
async fn get_network(&mut self) -> Result<Network, LightningError>;
327+
async fn get_network(&self) -> Result<Network, LightningError>;
328328
/// Keysend payment worth `amount_msat` from a source node to the destination node.
329329
async fn send_payment(
330-
&mut self,
330+
&self,
331331
dest: PublicKey,
332332
amount_msat: u64,
333333
) -> Result<PaymentHash, LightningError>;
334334
/// Track a payment with the specified hash.
335335
async fn track_payment(
336-
&mut self,
336+
&self,
337337
hash: &PaymentHash,
338338
shutdown: Listener,
339339
) -> Result<PaymentResult, LightningError>;
340340
/// Gets information on a specific node.
341-
async fn get_node_info(&mut self, node_id: &PublicKey) -> Result<NodeInfo, LightningError>;
341+
async fn get_node_info(&self, node_id: &PublicKey) -> Result<NodeInfo, LightningError>;
342342
/// Lists all channels, at present only returns a vector of channel capacities in msat because no further
343343
/// information is required.
344-
async fn list_channels(&mut self) -> Result<Vec<u64>, LightningError>;
344+
async fn list_channels(&self) -> Result<Vec<u64>, LightningError>;
345345
/// Get the network graph from the point of view of a given node.
346-
async fn get_graph(&mut self) -> Result<Graph, LightningError>;
346+
async fn get_graph(&self) -> Result<Graph, LightningError>;
347347
}
348348

349349
/// Represents an error that occurs when generating a destination for a payment.
@@ -1149,7 +1149,7 @@ async fn consume_events(
11491149
if let Some(event) = simulation_event {
11501150
match event {
11511151
SimulationEvent::SendPayment(dest, amt_msat) => {
1152-
let mut node = node.lock().await;
1152+
let node = node.lock().await;
11531153

11541154
let mut payment = Payment {
11551155
source: node.get_info().pubkey,
@@ -1501,7 +1501,7 @@ async fn track_payment_result(
15011501
) -> Result<(), SimulationError> {
15021502
log::trace!("Payment result tracker starting.");
15031503

1504-
let mut node = node.lock().await;
1504+
let node = node.lock().await;
15051505

15061506
let res = match payment.hash {
15071507
Some(hash) => {

0 commit comments

Comments
 (0)