Skip to content

Commit 58f44a4

Browse files
sangbidachuksys
authored andcommitted
Remove LDK specific scoring from PathFinder trait
1 parent 46a5125 commit 58f44a4

File tree

2 files changed

+85
-81
lines changed

2 files changed

+85
-81
lines changed

sim-cli/src/parsing.rs

Lines changed: 38 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ pub async fn create_simulation_with_network<P: PathFinder + Clone + 'static>(
308308
.map_err(|e| SimulationError::SimulatedNetworkError(format!("{:?}", e)))?,
309309
);
310310

311+
<<<<<<< HEAD
311312
<<<<<<< HEAD
312313
<<<<<<< HEAD
313314
// We want the full set of nodes in our graph to return to the caller so that they can take
@@ -327,6 +328,8 @@ pub async fn create_simulation_with_network<P: PathFinder + Clone + 'static>(
327328
=======
328329
>>>>>>> 079041b (Fix formatting)
329330
// Pass the pathfinder to ln_node_from_graph
331+
=======
332+
>>>>>>> 4a0f276 (Remove LDK specific scoring from PathFinder trait)
330333
let nodes = ln_node_from_graph(simulation_graph.clone(), routing_graph, pathfinder).await;
331334
let validated_activities =
332335
get_validated_activities(&nodes_dyn, nodes_info, sim_params.activity.clone()).await?;
@@ -671,7 +674,7 @@ mod tests {
671674
(secret_key, public_key)
672675
}
673676

674-
/// Helper function to create simulated channels for testing
677+
/// Helper function to create simulated channels for testing.
675678
fn create_simulated_channels(num_channels: usize, capacity_msat: u64) -> Vec<SimulatedChannel> {
676679
let mut channels = Vec::new();
677680
for i in 0..num_channels {
@@ -707,66 +710,67 @@ mod tests {
707710
channels
708711
}
709712

710-
/// A pathfinder that always fails to find a path
713+
/// A pathfinder that always fails to find a path.
711714
#[derive(Clone)]
712715
pub struct AlwaysFailPathFinder;
713716

714-
impl<'a> PathFinder<'a> for AlwaysFailPathFinder {
717+
impl PathFinder for AlwaysFailPathFinder {
715718
fn find_route(
716719
&self,
717720
_source: &PublicKey,
718721
_dest: PublicKey,
719722
_amount_msat: u64,
720-
_pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
721-
_scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
723+
_pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
722724
) -> Result<Route, SimulationError> {
723725
Err(SimulationError::SimulatedNetworkError(
724726
"No route found".to_string(),
725727
))
726728
}
727729
}
728730

729-
/// A pathfinder that only returns single-hop paths
731+
/// A pathfinder that only returns single-hop paths.
730732
#[derive(Clone)]
731733
pub struct SingleHopOnlyPathFinder;
732734

733-
impl<'a> PathFinder<'a> for SingleHopOnlyPathFinder {
735+
impl PathFinder for SingleHopOnlyPathFinder {
734736
fn find_route(
735737
&self,
736738
source: &PublicKey,
737739
dest: PublicKey,
738740
amount_msat: u64,
739-
pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
740-
scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
741+
pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
741742
) -> Result<Route, SimulationError> {
742-
// Try to find a direct route only (single hop)
743-
let route_params = RouteParameters {
744-
payment_params: PaymentParameters::from_node_id(dest, 0)
745-
.with_max_total_cltv_expiry_delta(u32::MAX)
746-
.with_max_path_count(1)
747-
.with_max_channel_saturation_power_of_half(1),
748-
final_value_msat: amount_msat,
749-
max_total_routing_fee_msat: None,
750-
};
751-
752-
// Try to find a route - if it fails or has more than one hop, return an error
743+
let scorer = ProbabilisticScorer::new(
744+
ProbabilisticScoringDecayParameters::default(),
745+
pathfinding_graph,
746+
&WrappedLog {},
747+
);
748+
749+
// Try to find a route - if it fails or has more than one hop, return an error.
753750
match find_route(
754751
source,
755-
&route_params,
752+
&RouteParameters {
753+
payment_params: PaymentParameters::from_node_id(dest, 0)
754+
.with_max_total_cltv_expiry_delta(u32::MAX)
755+
.with_max_path_count(1)
756+
.with_max_channel_saturation_power_of_half(1),
757+
final_value_msat: amount_msat,
758+
max_total_routing_fee_msat: None,
759+
},
756760
pathfinding_graph,
757761
None,
758762
&WrappedLog {},
759-
scorer,
763+
&scorer,
760764
&Default::default(),
761765
&[0; 32],
762766
) {
763767
Ok(route) => {
764-
// Check if the route has exactly one hop
768+
// Only allow single-hop routes.
765769
if route.paths.len() == 1 && route.paths[0].hops.len() == 1 {
766770
Ok(route)
767771
} else {
768772
Err(SimulationError::SimulatedNetworkError(
769-
"No direct route found".to_string(),
773+
"Only single-hop routes allowed".to_string(),
770774
))
771775
}
772776
},
@@ -785,15 +789,9 @@ mod tests {
785789
let source = channels[0].get_node_1_pubkey();
786790
let dest = channels[2].get_node_2_pubkey();
787791

788-
let scorer = ProbabilisticScorer::new(
789-
ProbabilisticScoringDecayParameters::default(),
790-
routing_graph.clone(),
791-
&WrappedLog {},
792-
);
793-
794-
let result = pathfinder.find_route(&source, dest, 100_000, &routing_graph, &scorer);
792+
let result = pathfinder.find_route(&source, dest, 100_000, &routing_graph);
795793

796-
// Should always fail
794+
// Should always fail.
797795
assert!(result.is_err());
798796
}
799797

@@ -806,31 +804,24 @@ mod tests {
806804
let pathfinder = SingleHopOnlyPathFinder;
807805
let source = channels[0].get_node_1_pubkey();
808806

809-
let scorer = ProbabilisticScorer::new(
810-
ProbabilisticScoringDecayParameters::default(),
811-
routing_graph.clone(),
812-
&WrappedLog {},
813-
);
814-
815-
// Test direct connection (should work)
807+
// Test direct connection (should work).
816808
let direct_dest = channels[0].get_node_2_pubkey();
817-
let result = pathfinder.find_route(&source, direct_dest, 100_000, &routing_graph, &scorer);
809+
let result = pathfinder.find_route(&source, direct_dest, 100_000, &routing_graph);
818810

819811
if result.is_ok() {
820812
let route = result.unwrap();
821813
assert_eq!(route.paths[0].hops.len(), 1); // Only one hop
822814
}
823815

824-
// Test indirect connection (should fail)
816+
// Test indirect connection (should fail).
825817
let indirect_dest = channels[2].get_node_2_pubkey();
826-
let _result =
827-
pathfinder.find_route(&source, indirect_dest, 100_000, &routing_graph, &scorer);
818+
let _result = pathfinder.find_route(&source, indirect_dest, 100_000, &routing_graph);
828819

829-
// May fail because no direct route exists
820+
// May fail because no direct route exists.
830821
// (depends on your test network topology)
831822
}
832823

833-
/// Test that different pathfinders produce different behavior in payments
824+
/// Test that different pathfinders produce different behavior in payments.
834825
#[tokio::test]
835826
async fn test_pathfinder_affects_payment_behavior() {
836827
let channels = create_simulated_channels(3, 1_000_000_000);
@@ -848,7 +839,7 @@ mod tests {
848839
let routing_graph =
849840
Arc::new(populate_network_graph(channels.clone(), Arc::new(SystemClock {})).unwrap());
850841

851-
// Create nodes with different pathfinders
842+
// Create nodes with different pathfinders.
852843
let nodes_default = ln_node_from_graph(
853844
sim_graph.clone(),
854845
routing_graph.clone(),
@@ -863,7 +854,7 @@ mod tests {
863854
)
864855
.await;
865856

866-
// Both should create the same structure
857+
// Both should create the same structure.
867858
assert_eq!(nodes_default.len(), nodes_fail.len());
868859
}
869860
}

simln-lib/src/sim_node.rs

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,18 @@ pub trait SimNetwork: Send + Sync {
502502
//type LdkNetworkGraph = NetworkGraph<Arc<WrappedLog>>;
503503
type LdkNetworkGraph = NetworkGraph<&'static WrappedLog>;
504504
/// A trait for custom pathfinding implementations.
505-
pub trait PathFinder<'a>: Send + Sync {
505+
/// Finds a route from the source node to the destination node for the specified amount.
506+
///
507+
/// # Arguments
508+
/// * `source` - The public key of the node initiating the payment.
509+
/// * `dest` - The public key of the destination node to receive the payment.
510+
/// * `amount_msat` - The amount to send in millisatoshis.
511+
/// * `pathfinding_graph` - The network graph containing channel topology and routing information.
512+
///
513+
/// # Returns
514+
/// Returns a `Route` containing the payment path, or a `SimulationError` if no route is found.
515+
516+
pub trait PathFinder: Send + Sync + Clone {
506517
fn find_route(
507518
&self,
508519
source: &PublicKey,
@@ -512,19 +523,32 @@ pub trait PathFinder<'a>: Send + Sync {
512523
) -> Result<Route, SimulationError>;
513524
}
514525

515-
/// Default pathfinder that uses LDK's pathfinding algorithm.
526+
/// The default pathfinding implementation that uses LDK's built-in pathfinding algorithm.
516527
#[derive(Clone)]
517528
pub struct DefaultPathFinder;
518529

519-
impl<'a> PathFinder<'a> for DefaultPathFinder {
530+
impl DefaultPathFinder {
531+
pub fn new() -> Self {
532+
Self
533+
}
534+
}
535+
536+
impl PathFinder for DefaultPathFinder {
520537
fn find_route(
521538
&self,
522539
source: &PublicKey,
523540
dest: PublicKey,
524541
amount_msat: u64,
525-
pathfinding_graph: &NetworkGraph<&'a WrappedLog>,
526-
scorer: &ProbabilisticScorer<Arc<NetworkGraph<&'a WrappedLog>>, &'a WrappedLog>,
542+
pathfinding_graph: &NetworkGraph<&'static WrappedLog>,
527543
) -> Result<Route, SimulationError> {
544+
let scorer_graph = NetworkGraph::new(bitcoin::Network::Regtest, &WrappedLog {});
545+
let scorer = ProbabilisticScorer::new(
546+
ProbabilisticScoringDecayParameters::default(),
547+
Arc::new(scorer_graph),
548+
&WrappedLog {},
549+
);
550+
551+
// Call LDK's find_route with the scorer (LDK-specific requirement)
528552
find_route(
529553
source,
530554
&RouteParameters {
@@ -535,10 +559,10 @@ impl<'a> PathFinder<'a> for DefaultPathFinder {
535559
final_value_msat: amount_msat,
536560
max_total_routing_fee_msat: None,
537561
},
538-
pathfinding_graph,
562+
pathfinding_graph, // This is the real network graph used for pathfinding
539563
None,
540564
&WrappedLog {},
541-
scorer,
565+
&scorer, // LDK requires a scorer, so we provide a simple one
542566
&Default::default(),
543567
&[0; 32],
544568
)
@@ -573,7 +597,7 @@ pub struct SimNode<T: SimNetwork, C: Clock, P: PathFinder = DefaultPathFinder> {
573597
pathfinder: P,
574598
}
575599

576-
impl<'a, T: SimNetwork, C: Clock, P: PathFinder<'a>> SimNode<'a, T, C, P> {
600+
impl<T: SimNetwork, C: Clock, P: PathFinder> SimNode<T, C, P> {
577601
/// Creates a new simulation node that refers to the high level network coordinator provided to process payments
578602
/// on its behalf. The pathfinding graph is provided separately so that each node can handle its own pathfinding.
579603
pub fn new(
@@ -583,24 +607,14 @@ impl<'a, T: SimNetwork, C: Clock, P: PathFinder<'a>> SimNode<'a, T, C, P> {
583607
clock: Arc<C>,
584608
pathfinder: P,
585609
) -> Self {
586-
// Initialize the probabilistic scorer with default parameters for learning from payment
587-
// history. These parameters control how much successful/failed payments affect routing
588-
// scores and how quickly these scores decay over time.
589-
let scorer = ProbabilisticScorer::new(
590-
ProbabilisticScoringDecayParameters::default(),
591-
pathfinding_graph.clone(),
592-
Arc::new(WrappedLog {}),
593-
);
594-
595-
Ok(SimNode {
610+
SimNode {
596611
info,
597612
network: payment_network,
598613
in_flight: Mutex::new(HashMap::new()),
599614
pathfinding_graph,
600-
scorer,
601615
clock,
602616
pathfinder,
603-
})
617+
}
604618
}
605619

606620
/// Dispatches a payment to a specified route. If `custom_records` is `Some`, they will be attached to the outgoing
@@ -735,7 +749,6 @@ impl<T: SimNetwork, C: Clock, P: PathFinder> LightningNode for SimNode<T, C, P>
735749
dest,
736750
amount_msat,
737751
&self.pathfinding_graph,
738-
&self.scorer,
739752
) {
740753
Ok(path) => path,
741754
// In the case that we can't find a route for the payment, we still report a successful payment *api call*
@@ -1146,7 +1159,7 @@ pub async fn ln_node_from_graph<C: Clock, P>(
11461159
pathfinder: P,
11471160
) -> Result<HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>>, LightningError>
11481161
where
1149-
P: for<'a> PathFinder<'a> + Clone + 'static,
1162+
P: PathFinder + 'static,
11501163
{
11511164
let mut nodes: HashMap<PublicKey, Arc<Mutex<dyn LightningNode>>> = HashMap::new();
11521165

@@ -1651,7 +1664,6 @@ mod tests {
16511664
use mockall::mock;
16521665
use ntest::assert_true;
16531666
use std::time::Duration;
1654-
use tokio::sync::oneshot;
16551667
use tokio::time::{self, timeout};
16561668

16571669
/// Creates a test channel policy with its maximum HTLC size set to half of the in flight limit of the channel.
@@ -2204,7 +2216,6 @@ mod tests {
22042216
graph: SimGraph,
22052217
nodes: Vec<PublicKey>,
22062218
routing_graph: Arc<LdkNetworkGraph>,
2207-
scorer: Mutex<ProbabilisticScorer<Arc<LdkNetworkGraph>, Arc<WrappedLog>>>,
22082219
shutdown: (Trigger, Listener),
22092220
pathfinder: DefaultPathFinder,
22102221
}
@@ -2227,12 +2238,19 @@ mod tests {
22272238
populate_network_graph(channels.clone(), Arc::new(SystemClock {})).unwrap(),
22282239
);
22292240

2241+
<<<<<<< HEAD
22302242
let scorer = Mutex::new(ProbabilisticScorer::new(
2243+
=======
2244+
<<<<<<< HEAD
2245+
let scorer = ProbabilisticScorer::new(
2246+
>>>>>>> 4a0f276 (Remove LDK specific scoring from PathFinder trait)
22312247
ProbabilisticScoringDecayParameters::default(),
22322248
routing_graph.clone(),
22332249
Arc::new(WrappedLog {}),
22342250
));
22352251

2252+
=======
2253+
>>>>>>> 3e72658 (Remove LDK specific scoring from PathFinder trait)
22362254
// Collect pubkeys in-order, pushing the last node on separately because they don't have an outgoing
22372255
// channel (they are not node_1 in any channel, only node_2).
22382256
let mut nodes = channels
@@ -2253,9 +2271,8 @@ mod tests {
22532271
.expect("could not create test graph"),
22542272
nodes,
22552273
routing_graph,
2256-
scorer,
22572274
shutdown: shutdown_clone,
2258-
pathfinder: DefaultPathFinder,
2275+
pathfinder: DefaultPathFinder::new(),
22592276
};
22602277

22612278
// Assert that our channel balance is all on the side of the channel opener when we start up.
@@ -2300,18 +2317,14 @@ mod tests {
23002317
) -> (Route, Result<PaymentResult, LightningError>) {
23012318
let route = self
23022319
.pathfinder
2303-
.find_route(&source, dest, amt, &self.routing_graph, &self.scorer)
2320+
.find_route(&source, dest, amt, &self.routing_graph)
23042321
.unwrap();
2322+
let (sender, receiver) = tokio::sync::oneshot::channel();
23052323

2306-
let (sender, receiver) = oneshot::channel();
23072324
self.graph
2308-
.dispatch_payment(source, route.clone(), None, PaymentHash([1; 32]), sender);
2309-
2310-
let payment_result = timeout(Duration::from_millis(10), receiver).await;
2311-
// Assert that we receive from the channel or fail.
2312-
assert!(payment_result.is_ok());
2325+
.dispatch_payment(source, route.clone(), PaymentHash([0; 32]), sender);
23132326

2314-
(route, payment_result.unwrap().unwrap())
2327+
(route, receiver.await.unwrap())
23152328
}
23162329

23172330
// Sets the balance on the channel to the tuple provided, used to arrange liquidity for testing.

0 commit comments

Comments
 (0)