Skip to content

Commit d7b6201

Browse files
committed
add combined scorer
1 parent c11ff12 commit d7b6201

File tree

1 file changed

+201
-2
lines changed

1 file changed

+201
-2
lines changed

lightning/src/routing/scoring.rs

Lines changed: 201 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ where L::Target: Logger {
477477
channel_liquidities: ChannelLiquidities,
478478
}
479479
/// ChannelLiquidities contains live and historical liquidity bounds for each channel.
480+
#[derive(Clone)]
480481
pub struct ChannelLiquidities(HashMap<u64, ChannelLiquidity>);
481482

482483
impl ChannelLiquidities {
@@ -860,6 +861,7 @@ impl ProbabilisticScoringDecayParameters {
860861
/// first node in the ordering of the channel's counterparties. Thus, swapping the two liquidity
861862
/// offset fields gives the opposite direction.
862863
#[repr(C)] // Force the fields in memory to be in the order we specify
864+
#[derive(Clone)]
863865
pub struct ChannelLiquidity {
864866
/// Lower channel liquidity bound in terms of an offset from zero.
865867
min_liquidity_offset_msat: u64,
@@ -1130,6 +1132,15 @@ impl ChannelLiquidity {
11301132
}
11311133
}
11321134

1135+
fn merge(&mut self, other: &Self) {
1136+
// Take average for min/max liquidity offsets.
1137+
self.min_liquidity_offset_msat = (self.min_liquidity_offset_msat + other.min_liquidity_offset_msat) / 2;
1138+
self.max_liquidity_offset_msat = (self.max_liquidity_offset_msat + other.max_liquidity_offset_msat) / 2;
1139+
1140+
// Merge historical liquidity data.
1141+
self.liquidity_history.merge(&other.liquidity_history);
1142+
}
1143+
11331144
/// Returns a view of the channel liquidity directed from `source` to `target` assuming
11341145
/// `capacity_msat`.
11351146
fn as_directed(
@@ -1663,6 +1674,91 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for Probabilistic
16631674
}
16641675
}
16651676

1677+
/// A probabilistic scorer that combines local and external information to score channels. This scorer is
1678+
/// shadow-tracking local only scores, so that it becomes possible to cleanly merge external scores when they become
1679+
/// available.
1680+
pub struct CombinedScorer<G: Deref<Target = NetworkGraph<L>>, L: Deref> where L::Target: Logger {
1681+
local_only_scorer: ProbabilisticScorer<G, L>,
1682+
scorer: ProbabilisticScorer<G, L>,
1683+
}
1684+
1685+
impl<G: Deref<Target = NetworkGraph<L>> + Clone, L: Deref + Clone> CombinedScorer<G, L> where L::Target: Logger {
1686+
/// Create a new combined scorer with the given local scorer.
1687+
pub fn new(local_scorer: ProbabilisticScorer<G, L>) -> Self {
1688+
let decay_params = local_scorer.decay_params;
1689+
let network_graph = local_scorer.network_graph.clone();
1690+
let logger = local_scorer.logger.clone();
1691+
let mut scorer = ProbabilisticScorer::new(decay_params, network_graph, logger);
1692+
1693+
scorer.channel_liquidities = local_scorer.channel_liquidities.clone();
1694+
1695+
Self {
1696+
local_only_scorer: local_scorer,
1697+
scorer: scorer,
1698+
}
1699+
}
1700+
1701+
/// Merge external channel liquidity information into the scorer.
1702+
pub fn merge(&mut self, mut external_scores: ChannelLiquidities, duration_since_epoch: Duration) {
1703+
// Decay both sets of scores to make them comparable and mergeable.
1704+
self.local_only_scorer.time_passed(duration_since_epoch);
1705+
external_scores.time_passed(duration_since_epoch, self.local_only_scorer.decay_params);
1706+
1707+
let local_scores = &self.local_only_scorer.channel_liquidities;
1708+
1709+
// For each channel, merge the external liquidity information with the isolated local liquidity information.
1710+
for (scid, mut liquidity) in external_scores.0 {
1711+
if let Some(local_liquidity) = local_scores.get(&scid) {
1712+
liquidity.merge(local_liquidity);
1713+
}
1714+
self.scorer.channel_liquidities.insert(scid, liquidity);
1715+
}
1716+
}
1717+
}
1718+
1719+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for CombinedScorer<G, L> where L::Target: Logger {
1720+
type ScoreParams = ProbabilisticScoringFeeParameters;
1721+
1722+
fn channel_penalty_msat(
1723+
&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &ProbabilisticScoringFeeParameters
1724+
) -> u64 {
1725+
self.scorer.channel_penalty_msat(candidate, usage, score_params)
1726+
}
1727+
}
1728+
1729+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for CombinedScorer<G, L> where L::Target: Logger {
1730+
fn payment_path_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1731+
self.local_only_scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1732+
self.scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1733+
}
1734+
1735+
fn payment_path_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1736+
self.local_only_scorer.payment_path_successful(path, duration_since_epoch);
1737+
self.scorer.payment_path_successful(path, duration_since_epoch);
1738+
}
1739+
1740+
fn probe_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1741+
self.local_only_scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1742+
self.scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1743+
}
1744+
1745+
fn probe_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1746+
self.local_only_scorer.probe_successful(path, duration_since_epoch);
1747+
self.scorer.probe_successful(path, duration_since_epoch);
1748+
}
1749+
1750+
fn time_passed(&mut self,duration_since_epoch:Duration) {
1751+
self.local_only_scorer.time_passed(duration_since_epoch);
1752+
self.scorer.time_passed(duration_since_epoch);
1753+
}
1754+
}
1755+
1756+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for CombinedScorer<G, L> where L::Target: Logger {
1757+
fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
1758+
self.local_only_scorer.write(writer)
1759+
}
1760+
}
1761+
16661762
#[cfg(c_bindings)]
16671763
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Score for ProbabilisticScorer<G, L>
16681764
where L::Target: Logger {}
@@ -1842,6 +1938,13 @@ mod bucketed_history {
18421938
self.buckets[bucket] = self.buckets[bucket].saturating_add(BUCKET_FIXED_POINT_ONE);
18431939
}
18441940
}
1941+
1942+
/// Returns the average of the buckets between the two trackers.
1943+
pub(crate) fn merge(&mut self, other: &Self) -> () {
1944+
for (index, bucket) in self.buckets.iter_mut().enumerate() {
1945+
*bucket = (*bucket + other.buckets[index]) / 2;
1946+
}
1947+
}
18451948
}
18461949

18471950
impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) });
@@ -1938,6 +2041,13 @@ mod bucketed_history {
19382041
-> DirectedHistoricalLiquidityTracker<&'a mut HistoricalLiquidityTracker> {
19392042
DirectedHistoricalLiquidityTracker { source_less_than_target, tracker: self }
19402043
}
2044+
2045+
/// Merges the historical liquidity data from another tracker into this one.
2046+
pub fn merge(&mut self, other: &Self) {
2047+
self.min_liquidity_offset_history.merge(&other.min_liquidity_offset_history);
2048+
self.max_liquidity_offset_history.merge(&other.max_liquidity_offset_history);
2049+
self.recalculate_valid_point_count();
2050+
}
19412051
}
19422052

19432053
/// A set of buckets representing the history of where we've seen the minimum- and maximum-
@@ -2096,7 +2206,54 @@ mod bucketed_history {
20962206
Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64)
20972207
}
20982208
}
2209+
2210+
#[cfg(test)]
2211+
mod tests {
2212+
use crate::routing::scoring::ProbabilisticScoringFeeParameters;
2213+
2214+
use super::{HistoricalBucketRangeTracker, HistoricalLiquidityTracker};
2215+
#[test]
2216+
fn historical_liquidity_bucket_merge() {
2217+
let mut bucket1 = HistoricalBucketRangeTracker::new();
2218+
bucket1.track_datapoint(100, 1000);
2219+
assert_eq!(bucket1.buckets, [0u16,0,0,0,0,0,0,0,0,0,0,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2220+
2221+
let mut bucket2 = HistoricalBucketRangeTracker::new();
2222+
bucket2.track_datapoint(0, 1000);
2223+
assert_eq!(bucket2.buckets, [32u16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2224+
2225+
bucket1.merge(&bucket2);
2226+
assert_eq!(bucket1.buckets, [16u16,0,0,0,0,0,0,0,0,0,0,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2227+
}
2228+
2229+
#[test]
2230+
fn historical_liquidity_tracker_merge() {
2231+
let params = ProbabilisticScoringFeeParameters::default();
2232+
2233+
let probability1: Option<u64>;
2234+
let mut tracker1 = HistoricalLiquidityTracker::new();
2235+
{
2236+
let mut directed_tracker1 = tracker1.as_directed_mut(true);
2237+
directed_tracker1.track_datapoint(100, 200, 1000);
2238+
probability1 = directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);
2239+
}
2240+
2241+
let mut tracker2 = HistoricalLiquidityTracker::new();
2242+
{
2243+
let mut directed_tracker2 = tracker2.as_directed_mut(true);
2244+
directed_tracker2.track_datapoint(200, 300, 1000);
2245+
}
2246+
2247+
tracker1.merge(&tracker2);
2248+
2249+
let directed_tracker1 = tracker1.as_directed(true);
2250+
let probability = directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);
2251+
2252+
assert_ne!(probability1, probability);
2253+
}
2254+
}
20992255
}
2256+
21002257
use bucketed_history::{LegacyHistoricalBucketRangeTracker, HistoricalBucketRangeTracker, DirectedHistoricalLiquidityTracker, HistoricalLiquidityTracker};
21012258

21022259
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for ProbabilisticScorer<G, L> where L::Target: Logger {
@@ -2193,15 +2350,15 @@ impl Readable for ChannelLiquidity {
21932350

21942351
#[cfg(test)]
21952352
mod tests {
2196-
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
2353+
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters};
21972354
use crate::blinded_path::BlindedHop;
21982355
use crate::util::config::UserConfig;
21992356

22002357
use crate::ln::channelmanager;
22012358
use crate::ln::msgs::{ChannelAnnouncement, ChannelUpdate, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
22022359
use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
22032360
use crate::routing::router::{BlindedTail, Path, RouteHop, CandidateRouteHop, PublicHopCandidate};
2204-
use crate::routing::scoring::{ChannelUsage, ScoreLookUp, ScoreUpdate};
2361+
use crate::routing::scoring::{ChannelLiquidities, ChannelUsage, CombinedScorer, ScoreLookUp, ScoreUpdate};
22052362
use crate::util::ser::{ReadableArgs, Writeable};
22062363
use crate::util::test_utils::{self, TestLogger};
22072364

@@ -2211,6 +2368,7 @@ mod tests {
22112368
use bitcoin::network::Network;
22122369
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
22132370
use core::time::Duration;
2371+
use std::rc::Rc;
22142372
use crate::io;
22152373

22162374
fn source_privkey() -> SecretKey {
@@ -3702,6 +3860,47 @@ mod tests {
37023860
assert_eq!(scorer.historical_estimated_payment_success_probability(42, &target, amount_msat, &params, false),
37033861
Some(0.0));
37043862
}
3863+
3864+
#[test]
3865+
fn combined_scorer() {
3866+
let logger = TestLogger::new();
3867+
let network_graph = network_graph(&logger);
3868+
let params = ProbabilisticScoringFeeParameters::default();
3869+
let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
3870+
scorer.payment_path_failed(&payment_path_for_amount(600), 42, Duration::ZERO);
3871+
3872+
let mut combined_scorer = CombinedScorer::new(scorer);
3873+
3874+
let source = source_node_id();
3875+
let usage = ChannelUsage {
3876+
amount_msat: 750,
3877+
inflight_htlc_msat: 0,
3878+
effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000, htlc_maximum_msat: 1_000 },
3879+
};
3880+
let network_graph = network_graph.read_only();
3881+
let channel = network_graph.channel(42).unwrap();
3882+
let (info, _) = channel.as_directed_from(&source).unwrap();
3883+
let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
3884+
info,
3885+
short_channel_id: 42,
3886+
});
3887+
let penalty = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3888+
3889+
let mut external_liquidity = ChannelLiquidity::new(Duration::ZERO);
3890+
let logger_rc = Rc::new(&logger); // Why necessary and not above for the network graph?
3891+
external_liquidity.as_directed_mut(&source_node_id(), &target_node_id(), 1_000).
3892+
successful(1000, Duration::ZERO, format_args!("test channel"), logger_rc.as_ref());
3893+
3894+
let mut external_scores = ChannelLiquidities::new();
3895+
3896+
external_scores.insert(42, external_liquidity);
3897+
combined_scorer.merge(external_scores, Duration::ZERO);
3898+
3899+
let penalty_after_merge = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3900+
3901+
// Since the external source observed a successful payment, the penalty should be lower after the merge.
3902+
assert!(penalty_after_merge < penalty);
3903+
}
37053904
}
37063905

37073906
#[cfg(ldk_bench)]

0 commit comments

Comments
 (0)