Skip to content

Commit 1c78c95

Browse files
committed
Remove LDK specific scoring from PathFinder trait
1 parent aaef04d commit 1c78c95

File tree

2 files changed

+78
-89
lines changed

2 files changed

+78
-89
lines changed

sim-cli/src/parsing.rs

Lines changed: 27 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct NodeMapping {
225225
alias_node_map: HashMap<String, NodeInfo>,
226226
}
227227

228-
pub async fn create_simulation_with_network<P: for<'a> PathFinder<'a> + Clone + 'static>(
228+
pub async fn create_simulation_with_network<P: PathFinder + Clone + 'static>(
229229
cli: &Cli,
230230
sim_params: &SimParams,
231231
tasks: TaskTracker,
@@ -661,62 +661,63 @@ mod tests {
661661
#[derive(Clone)]
662662
pub struct AlwaysFailPathFinder;
663663

664-
impl<'a> PathFinder<'a> for AlwaysFailPathFinder {
664+
impl PathFinder for AlwaysFailPathFinder {
665665
fn find_route(
666666
&self,
667667
_source: &PublicKey,
668668
_dest: PublicKey,
669669
_amount_msat: u64,
670-
_pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
671-
_scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
670+
_pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
672671
) -> Result<Route, SimulationError> {
673672
Err(SimulationError::SimulatedNetworkError(
674673
"No route found".to_string(),
675674
))
676675
}
677676
}
678677

679-
/// A pathfinder that only returns single-hop paths
678+
/// A pathfinder that only returns single-hop paths.
680679
#[derive(Clone)]
681680
pub struct SingleHopOnlyPathFinder;
682681

683-
impl<'a> PathFinder<'a> for SingleHopOnlyPathFinder {
682+
impl PathFinder for SingleHopOnlyPathFinder {
684683
fn find_route(
685684
&self,
686685
source: &PublicKey,
687686
dest: PublicKey,
688687
amount_msat: u64,
689-
pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
690-
scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
688+
pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
691689
) -> Result<Route, SimulationError> {
692-
// Try to find a direct route only (single hop)
693-
let route_params = RouteParameters {
694-
payment_params: PaymentParameters::from_node_id(dest, 0)
695-
.with_max_total_cltv_expiry_delta(u32::MAX)
696-
.with_max_path_count(1)
697-
.with_max_channel_saturation_power_of_half(1),
698-
final_value_msat: amount_msat,
699-
max_total_routing_fee_msat: None,
700-
};
701-
702-
// Try to find a route - if it fails or has more than one hop, return an error
690+
let scorer = ProbabilisticScorer::new(
691+
ProbabilisticScoringDecayParameters::default(),
692+
pathfinding_graph,
693+
&WrappedLog {},
694+
);
695+
696+
// Try to find a route - if it fails or has more than one hop, return an error.
703697
match find_route(
704698
source,
705-
&route_params,
699+
&RouteParameters {
700+
payment_params: PaymentParameters::from_node_id(dest, 0)
701+
.with_max_total_cltv_expiry_delta(u32::MAX)
702+
.with_max_path_count(1)
703+
.with_max_channel_saturation_power_of_half(1),
704+
final_value_msat: amount_msat,
705+
max_total_routing_fee_msat: None,
706+
},
706707
pathfinding_graph,
707708
None,
708709
&WrappedLog {},
709-
scorer,
710+
&scorer,
710711
&Default::default(),
711712
&[0; 32],
712713
) {
713714
Ok(route) => {
714-
// Check if the route has exactly one hop
715+
// Only allow single-hop routes.
715716
if route.paths.len() == 1 && route.paths[0].hops.len() == 1 {
716717
Ok(route)
717718
} else {
718719
Err(SimulationError::SimulatedNetworkError(
719-
"No direct route found".to_string(),
720+
"Only single-hop routes allowed".to_string(),
720721
))
721722
}
722723
},
@@ -735,13 +736,7 @@ mod tests {
735736
let source = channels[0].get_node_1_pubkey();
736737
let dest = channels[2].get_node_2_pubkey();
737738

738-
let scorer = ProbabilisticScorer::new(
739-
ProbabilisticScoringDecayParameters::default(),
740-
routing_graph.clone(),
741-
&WrappedLog {},
742-
);
743-
744-
let result = pathfinder.find_route(&source, dest, 100_000, &routing_graph, &scorer);
739+
let result = pathfinder.find_route(&source, dest, 100_000, &routing_graph);
745740

746741
// Should always fail
747742
assert!(result.is_err());
@@ -756,15 +751,9 @@ mod tests {
756751
let pathfinder = SingleHopOnlyPathFinder;
757752
let source = channels[0].get_node_1_pubkey();
758753

759-
let scorer = ProbabilisticScorer::new(
760-
ProbabilisticScoringDecayParameters::default(),
761-
routing_graph.clone(),
762-
&WrappedLog {},
763-
);
764-
765754
// Test direct connection (should work)
766755
let direct_dest = channels[0].get_node_2_pubkey();
767-
let result = pathfinder.find_route(&source, direct_dest, 100_000, &routing_graph, &scorer);
756+
let result = pathfinder.find_route(&source, direct_dest, 100_000, &routing_graph);
768757

769758
if result.is_ok() {
770759
let route = result.unwrap();
@@ -773,8 +762,7 @@ mod tests {
773762

774763
// Test indirect connection (should fail)
775764
let indirect_dest = channels[2].get_node_2_pubkey();
776-
let _result =
777-
pathfinder.find_route(&source, indirect_dest, 100_000, &routing_graph, &scorer);
765+
let _result = pathfinder.find_route(&source, indirect_dest, 100_000, &routing_graph);
778766

779767
// May fail because no direct route exists
780768
// (depends on your test network topology)

simln-lib/src/sim_node.rs

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -495,30 +495,53 @@ pub trait SimNetwork: Send + Sync {
495495
}
496496

497497
/// A trait for custom pathfinding implementations.
498-
pub trait PathFinder<'a>: Send + Sync {
498+
/// Finds a route from the source node to the destination node for the specified amount.
499+
///
500+
/// # Arguments
501+
/// * `source` - The public key of the node initiating the payment.
502+
/// * `dest` - The public key of the destination node to receive the payment.
503+
/// * `amount_msat` - The amount to send in millisatoshis.
504+
/// * `pathfinding_graph` - The network graph containing channel topology and routing information.
505+
///
506+
/// # Returns
507+
/// Returns a `Route` containing the payment path, or a `SimulationError` if no route is found.
508+
509+
pub trait PathFinder: Send + Sync + Clone {
499510
fn find_route(
500511
&self,
501512
source: &PublicKey,
502513
dest: PublicKey,
503514
amount_msat: u64,
504-
pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
505-
scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
515+
pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
506516
) -> Result<Route, SimulationError>;
507517
}
508518

509-
/// Default pathfinder that uses LDK's pathfinding algorithm.
519+
/// The default pathfinding implementation that uses LDK's built-in pathfinding algorithm.
510520
#[derive(Clone)]
511521
pub struct DefaultPathFinder;
512522

513-
impl<'a> PathFinder<'a> for DefaultPathFinder {
523+
impl DefaultPathFinder {
524+
pub fn new() -> Self {
525+
Self
526+
}
527+
}
528+
529+
impl PathFinder for DefaultPathFinder {
514530
fn find_route(
515531
&self,
516532
source: &PublicKey,
517533
dest: PublicKey,
518534
amount_msat: u64,
519-
pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
520-
scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
535+
pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
521536
) -> Result<Route, SimulationError> {
537+
let scorer_graph = NetworkGraph::new(bitcoin::Network::Regtest, &WrappedLog {});
538+
let scorer = ProbabilisticScorer::new(
539+
ProbabilisticScoringDecayParameters::default(),
540+
Arc::new(scorer_graph),
541+
&WrappedLog {},
542+
);
543+
544+
// Call LDK's find_route with the scorer (LDK-specific requirement)
522545
find_route(
523546
source,
524547
&RouteParameters {
@@ -529,10 +552,10 @@ impl<'a> PathFinder<'a> for DefaultPathFinder {
529552
final_value_msat: amount_msat,
530553
max_total_routing_fee_msat: None,
531554
},
532-
pathfinding_graph,
555+
pathfinding_graph, // This is the real network graph used for pathfinding
533556
None,
534557
&WrappedLog {},
535-
scorer,
558+
&scorer, // LDK requires a scorer, so we provide a simple one
536559
&Default::default(),
537560
&[0; 32],
538561
)
@@ -551,38 +574,25 @@ pub struct SimNode<'a, T: SimNetwork, P: PathFinder<'a> = DefaultPathFinder> {
551574
/// Tracks the channel that will provide updates for payments by hash.
552575
in_flight: HashMap<PaymentHash, Receiver<Result<PaymentResult, LightningError>>>,
553576
/// A read-only graph used for pathfinding.
554-
pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
555-
/// Probabilistic scorer used to rank paths through the network for routing. This is reused across
556-
/// multiple payments to maintain scoring state.
557-
scorer: ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
577+
pathfinding_graph: Arc<NetworkGraph<&'static WrappedLog>>,
558578
/// The pathfinder implementation to use for finding routes
559579
pathfinder: P,
560580
}
561581

562-
impl<'a, T: SimNetwork, P: PathFinder<'a>> SimNode<'a, T, P> {
582+
impl<T: SimNetwork, P: PathFinder> SimNode<T, P> {
563583
/// Creates a new simulation node that refers to the high level network coordinator provided to process payments
564584
/// on its behalf. The pathfinding graph is provided separately so that each node can handle its own pathfinding.
565585
pub fn new(
566586
pubkey: PublicKey,
567587
payment_network: Arc<Mutex<T>>,
568-
pathfinding_graph: Arc<NetworkGraph<&'a WrappedLog>>,
588+
pathfinding_graph: Arc<NetworkGraph<&'static WrappedLog>>,
569589
pathfinder: P,
570590
) -> Self {
571-
// Initialize the probabilistic scorer with default parameters for learning from payment
572-
// history. These parameters control how much successful/failed payments affect routing
573-
// scores and how quickly these scores decay over time.
574-
let scorer = ProbabilisticScorer::new(
575-
ProbabilisticScoringDecayParameters::default(),
576-
pathfinding_graph.clone(),
577-
&WrappedLog {},
578-
);
579-
580591
SimNode {
581592
info: node_info(pubkey),
582593
network: payment_network,
583594
in_flight: HashMap::new(),
584595
pathfinding_graph,
585-
scorer,
586596
pathfinder,
587597
}
588598
}
@@ -664,7 +674,7 @@ fn node_info(pubkey: PublicKey) -> NodeInfo {
664674
}
665675

666676
#[async_trait]
667-
impl<'a, T: SimNetwork, P: PathFinder<'a>> LightningNode for SimNode<'a, T, P> {
677+
impl<T: SimNetwork, P: PathFinder> LightningNode for SimNode<T, P> {
668678
fn get_info(&self) -> &NodeInfo {
669679
&self.info
670680
}
@@ -686,7 +696,6 @@ impl<'a, T: SimNetwork, P: PathFinder<'a>> LightningNode for SimNode<'a, T, P> {
686696
dest,
687697
amount_msat,
688698
&self.pathfinding_graph,
689-
&self.scorer,
690699
) {
691700
Ok(route) => route,
692701
Err(e) => {
@@ -1064,7 +1073,7 @@ pub async fn ln_node_from_graph<P>(
10641073
pathfinder: P,
10651074
) -> HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>>
10661075
where
1067-
P: for<'a> PathFinder<'a> + Clone + 'static,
1076+
P: PathFinder + 'static,
10681077
{
10691078
let mut nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>> = HashMap::new();
10701079

@@ -1563,7 +1572,6 @@ mod tests {
15631572
use mockall::mock;
15641573
use ntest::assert_true;
15651574
use std::time::Duration;
1566-
use tokio::sync::oneshot;
15671575
use tokio::time::{self, timeout};
15681576

15691577
/// Creates a test channel policy with its maximum HTLC size set to half of the in flight limit of the channel.
@@ -1953,7 +1961,12 @@ mod tests {
19531961

19541962
// Create a simulated node for the first channel in our network.
19551963
let pk = channels[0].node_1.policy.pubkey;
1956-
let mut node = SimNode::new(pk, sim_network.clone(), Arc::new(graph), DefaultPathFinder);
1964+
let mut node = SimNode::new(
1965+
pk,
1966+
sim_network.clone(),
1967+
Arc::new(graph),
1968+
DefaultPathFinder::new(),
1969+
);
19571970

19581971
// Prime mock to return node info from lookup and assert that we get the pubkey we're expecting.
19591972
let lookup_pk = channels[3].node_1.policy.pubkey;
@@ -2038,16 +2051,15 @@ mod tests {
20382051
}
20392052

20402053
/// Contains elements required to test dispatch_payment functionality.
2041-
struct DispatchPaymentTestKit<'a> {
2054+
struct DispatchPaymentTestKit {
20422055
graph: SimGraph,
20432056
nodes: Vec<PublicKey>,
2044-
routing_graph: Arc<NetworkGraph<&'a WrappedLog>>,
2045-
scorer: ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
2057+
routing_graph: Arc<NetworkGraph<&'static WrappedLog>>,
20462058
shutdown: (Trigger, Listener),
20472059
pathfinder: DefaultPathFinder,
20482060
}
20492061

2050-
impl DispatchPaymentTestKit<'_> {
2062+
impl DispatchPaymentTestKit {
20512063
/// Creates a test graph with a set of nodes connected by three channels, with all the capacity of the channel
20522064
/// on the side of the first node. For example, if called with capacity = 100 it will set up the following
20532065
/// network:
@@ -2065,12 +2077,6 @@ mod tests {
20652077
populate_network_graph(channels.clone(), Arc::new(SystemClock {})).unwrap(),
20662078
);
20672079

2068-
let scorer = ProbabilisticScorer::new(
2069-
ProbabilisticScoringDecayParameters::default(),
2070-
routing_graph.clone(),
2071-
&WrappedLog {},
2072-
);
2073-
20742080
// Collect pubkeys in-order, pushing the last node on separately because they don't have an outgoing
20752081
// channel (they are not node_1 in any channel, only node_2).
20762082
let mut nodes = channels
@@ -2091,9 +2097,8 @@ mod tests {
20912097
.expect("could not create test graph"),
20922098
nodes,
20932099
routing_graph,
2094-
scorer,
20952100
shutdown: shutdown_clone,
2096-
pathfinder: DefaultPathFinder,
2101+
pathfinder: DefaultPathFinder::new(),
20972102
};
20982103

20992104
// Assert that our channel balance is all on the side of the channel opener when we start up.
@@ -2138,18 +2143,14 @@ mod tests {
21382143
) -> (Route, Result<PaymentResult, LightningError>) {
21392144
let route = self
21402145
.pathfinder
2141-
.find_route(&source, dest, amt, &self.routing_graph, &self.scorer)
2146+
.find_route(&source, dest, amt, &self.routing_graph)
21422147
.unwrap();
2148+
let (sender, receiver) = tokio::sync::oneshot::channel();
21432149

2144-
let (sender, receiver) = oneshot::channel();
21452150
self.graph
2146-
.dispatch_payment(source, route.clone(), PaymentHash([1; 32]), sender);
2147-
2148-
let payment_result = timeout(Duration::from_millis(10), receiver).await;
2149-
// Assert that we receive from the channel or fail.
2150-
assert!(payment_result.is_ok());
2151+
.dispatch_payment(source, route.clone(), PaymentHash([0; 32]), sender);
21512152

2152-
(route, payment_result.unwrap().unwrap())
2153+
(route, receiver.await.unwrap())
21532154
}
21542155

21552156
// Sets the balance on the channel to the tuple provided, used to arrange liquidity for testing.

0 commit comments

Comments
 (0)