Skip to content

Commit b02bcee

Browse files
committed
add combined scorer
1 parent c11ff12 commit b02bcee

File tree

1 file changed

+198
-2
lines changed

1 file changed

+198
-2
lines changed

lightning/src/routing/scoring.rs

Lines changed: 198 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ where L::Target: Logger {
477477
channel_liquidities: ChannelLiquidities,
478478
}
479479
/// ChannelLiquidities contains live and historical liquidity bounds for each channel.
480+
#[derive(Clone)]
480481
pub struct ChannelLiquidities(HashMap<u64, ChannelLiquidity>);
481482

482483
impl ChannelLiquidities {
@@ -522,6 +523,7 @@ impl DerefMut for ChannelLiquidities {
522523
}
523524
}
524525

526+
// TODO: Avoid extra level of tlv serialization
525527
impl Readable for ChannelLiquidities {
526528
#[inline]
527529
fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
@@ -533,6 +535,7 @@ impl Readable for ChannelLiquidities {
533535
}
534536
}
535537

538+
536539
/// Parameters for configuring [`ProbabilisticScorer`].
537540
///
538541
/// Used to configure base, liquidity, and amount penalties, the sum of which comprises the channel
@@ -860,6 +863,7 @@ impl ProbabilisticScoringDecayParameters {
860863
/// first node in the ordering of the channel's counterparties. Thus, swapping the two liquidity
861864
/// offset fields gives the opposite direction.
862865
#[repr(C)] // Force the fields in memory to be in the order we specify
866+
#[derive(Clone)]
863867
pub struct ChannelLiquidity {
864868
/// Lower channel liquidity bound in terms of an offset from zero.
865869
min_liquidity_offset_msat: u64,
@@ -1130,6 +1134,17 @@ impl ChannelLiquidity {
11301134
}
11311135
}
11321136

1137+
fn merge(&mut self, other: &Self) {
1138+
// Todo: check updated times for equality.
1139+
1140+
// Take average for min/max liquidity offsets.
1141+
self.min_liquidity_offset_msat = (self.min_liquidity_offset_msat + other.min_liquidity_offset_msat) / 2;
1142+
self.max_liquidity_offset_msat = (self.max_liquidity_offset_msat + other.max_liquidity_offset_msat) / 2;
1143+
1144+
// Merge historical liquidity data.
1145+
self.liquidity_history.merge(&other.liquidity_history);
1146+
}
1147+
11331148
/// Returns a view of the channel liquidity directed from `source` to `target` assuming
11341149
/// `capacity_msat`.
11351150
fn as_directed(
@@ -1663,6 +1678,85 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for Probabilistic
16631678
}
16641679
}
16651680

1681+
/// A probabilistic scorer that combines local and external information to score channels.
1682+
pub struct CombinedScorer<G: Deref<Target = NetworkGraph<L>>, L: Deref> where L::Target: Logger {
1683+
local_only_scorer: ProbabilisticScorer<G, L>,
1684+
scorer: ProbabilisticScorer<G, L>,
1685+
}
1686+
1687+
impl<G: Deref<Target = NetworkGraph<L>> + Clone, L: Deref + Clone> CombinedScorer<G, L> where L::Target: Logger {
1688+
/// Create a new combined scorer with the given local scorer.
1689+
pub fn new(local_scorer: ProbabilisticScorer<G, L>) -> Self {
1690+
let decay_params = local_scorer.decay_params;
1691+
let network_graph = local_scorer.network_graph.clone();
1692+
let logger = local_scorer.logger.clone();
1693+
let mut scorer = ProbabilisticScorer::new(decay_params, network_graph, logger);
1694+
1695+
scorer.channel_liquidities = local_scorer.channel_liquidities.clone();
1696+
1697+
Self {
1698+
local_only_scorer: local_scorer,
1699+
scorer: scorer,
1700+
}
1701+
}
1702+
1703+
/// Merge external channel liquidity information into the scorer.
1704+
pub fn merge(&mut self, external_scores: ChannelLiquidities) {
1705+
let local_scores = &self.local_only_scorer.channel_liquidities;
1706+
1707+
// For each channel, merge the external liquidity information with the isolated local liquidity information.
1708+
for (scid, mut liquidity) in external_scores.0 {
1709+
if let Some(local_liquidity) = local_scores.get(&scid) {
1710+
liquidity.merge(local_liquidity);
1711+
}
1712+
self.scorer.channel_liquidities.insert(scid, liquidity);
1713+
}
1714+
}
1715+
}
1716+
1717+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for CombinedScorer<G, L> where L::Target: Logger {
1718+
type ScoreParams = ProbabilisticScoringFeeParameters;
1719+
1720+
fn channel_penalty_msat(
1721+
&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &ProbabilisticScoringFeeParameters
1722+
) -> u64 {
1723+
self.scorer.channel_penalty_msat(candidate, usage, score_params)
1724+
}
1725+
}
1726+
1727+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for CombinedScorer<G, L> where L::Target: Logger {
1728+
fn payment_path_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1729+
self.local_only_scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1730+
self.scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
1731+
}
1732+
1733+
fn payment_path_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1734+
self.local_only_scorer.payment_path_successful(path, duration_since_epoch);
1735+
self.scorer.payment_path_successful(path, duration_since_epoch);
1736+
}
1737+
1738+
fn probe_failed(&mut self,path: &Path,short_channel_id:u64,duration_since_epoch:Duration) {
1739+
self.local_only_scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1740+
self.scorer.probe_failed(path, short_channel_id, duration_since_epoch);
1741+
}
1742+
1743+
fn probe_successful(&mut self,path: &Path,duration_since_epoch:Duration) {
1744+
self.local_only_scorer.probe_successful(path, duration_since_epoch);
1745+
self.scorer.probe_successful(path, duration_since_epoch);
1746+
}
1747+
1748+
fn time_passed(&mut self,duration_since_epoch:Duration) {
1749+
self.local_only_scorer.time_passed(duration_since_epoch);
1750+
self.scorer.time_passed(duration_since_epoch);
1751+
}
1752+
}
1753+
1754+
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for CombinedScorer<G, L> where L::Target: Logger {
1755+
fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
1756+
self.local_only_scorer.write(writer)
1757+
}
1758+
}
1759+
16661760
#[cfg(c_bindings)]
16671761
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Score for ProbabilisticScorer<G, L>
16681762
where L::Target: Logger {}
@@ -1842,6 +1936,13 @@ mod bucketed_history {
18421936
self.buckets[bucket] = self.buckets[bucket].saturating_add(BUCKET_FIXED_POINT_ONE);
18431937
}
18441938
}
1939+
1940+
/// Returns the average of the buckets between the two trackers.
1941+
pub(crate) fn merge(&mut self, other: &Self) -> () {
1942+
for (index, bucket) in self.buckets.iter_mut().enumerate() {
1943+
*bucket = (*bucket + other.buckets[index]) / 2;
1944+
}
1945+
}
18451946
}
18461947

18471948
impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) });
@@ -1938,6 +2039,12 @@ mod bucketed_history {
19382039
-> DirectedHistoricalLiquidityTracker<&'a mut HistoricalLiquidityTracker> {
19392040
DirectedHistoricalLiquidityTracker { source_less_than_target, tracker: self }
19402041
}
2042+
2043+
pub fn merge(&mut self, other: &Self) {
2044+
self.min_liquidity_offset_history.merge(&other.min_liquidity_offset_history);
2045+
self.max_liquidity_offset_history.merge(&other.max_liquidity_offset_history);
2046+
self.recalculate_valid_point_count();
2047+
}
19412048
}
19422049

19432050
/// A set of buckets representing the history of where we've seen the minimum- and maximum-
@@ -2096,7 +2203,54 @@ mod bucketed_history {
20962203
Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64)
20972204
}
20982205
}
2206+
2207+
#[cfg(test)]
2208+
mod tests {
2209+
use crate::routing::scoring::ProbabilisticScoringFeeParameters;
2210+
2211+
use super::{HistoricalBucketRangeTracker, HistoricalLiquidityTracker};
2212+
#[test]
2213+
fn historical_liquidity_bucket_merge() {
2214+
let mut bucket1 = HistoricalBucketRangeTracker::new();
2215+
bucket1.track_datapoint(100, 1000);
2216+
assert_eq!(bucket1.buckets, [0u16,0,0,0,0,0,0,0,0,0,0,32,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2217+
2218+
let mut bucket2 = HistoricalBucketRangeTracker::new();
2219+
bucket2.track_datapoint(0, 1000);
2220+
assert_eq!(bucket2.buckets, [32u16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2221+
2222+
bucket1.merge(&bucket2);
2223+
assert_eq!(bucket1.buckets, [16u16,0,0,0,0,0,0,0,0,0,0,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]);
2224+
}
2225+
2226+
#[test]
2227+
fn historical_liquidity_tracker_merge() {
2228+
let params = ProbabilisticScoringFeeParameters::default();
2229+
2230+
let probability1: Option<u64>;
2231+
let mut tracker1 = HistoricalLiquidityTracker::new();
2232+
{
2233+
let mut directed_tracker1 = tracker1.as_directed_mut(true);
2234+
directed_tracker1.track_datapoint(100, 200, 1000);
2235+
probability1 = directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);
2236+
}
2237+
2238+
let mut tracker2 = HistoricalLiquidityTracker::new();
2239+
{
2240+
let mut directed_tracker2 = tracker2.as_directed_mut(true);
2241+
directed_tracker2.track_datapoint(200, 300, 1000);
2242+
}
2243+
2244+
tracker1.merge(&tracker2);
2245+
2246+
let directed_tracker1 = tracker1.as_directed(true);
2247+
let probability = directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);
2248+
2249+
assert_ne!(probability1, probability);
2250+
}
2251+
}
20992252
}
2253+
21002254
use bucketed_history::{LegacyHistoricalBucketRangeTracker, HistoricalBucketRangeTracker, DirectedHistoricalLiquidityTracker, HistoricalLiquidityTracker};
21012255

21022256
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for ProbabilisticScorer<G, L> where L::Target: Logger {
@@ -2193,15 +2347,15 @@ impl Readable for ChannelLiquidity {
21932347

21942348
#[cfg(test)]
21952349
mod tests {
2196-
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
2350+
use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters};
21972351
use crate::blinded_path::BlindedHop;
21982352
use crate::util::config::UserConfig;
21992353

22002354
use crate::ln::channelmanager;
22012355
use crate::ln::msgs::{ChannelAnnouncement, ChannelUpdate, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
22022356
use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
22032357
use crate::routing::router::{BlindedTail, Path, RouteHop, CandidateRouteHop, PublicHopCandidate};
2204-
use crate::routing::scoring::{ChannelUsage, ScoreLookUp, ScoreUpdate};
2358+
use crate::routing::scoring::{ChannelLiquidities, ChannelUsage, CombinedScorer, ScoreLookUp, ScoreUpdate};
22052359
use crate::util::ser::{ReadableArgs, Writeable};
22062360
use crate::util::test_utils::{self, TestLogger};
22072361

@@ -2211,6 +2365,7 @@ mod tests {
22112365
use bitcoin::network::Network;
22122366
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
22132367
use core::time::Duration;
2368+
use std::rc::Rc;
22142369
use crate::io;
22152370

22162371
fn source_privkey() -> SecretKey {
@@ -3702,6 +3857,47 @@ mod tests {
37023857
assert_eq!(scorer.historical_estimated_payment_success_probability(42, &target, amount_msat, &params, false),
37033858
Some(0.0));
37043859
}
3860+
3861+
#[test]
3862+
fn combined_scorer() {
3863+
let logger = TestLogger::new();
3864+
let network_graph = network_graph(&logger);
3865+
let params = ProbabilisticScoringFeeParameters::default();
3866+
let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
3867+
scorer.payment_path_failed(&payment_path_for_amount(600), 42, Duration::ZERO);
3868+
3869+
let mut combined_scorer = CombinedScorer::new(scorer);
3870+
3871+
let source = source_node_id();
3872+
let usage = ChannelUsage {
3873+
amount_msat: 750,
3874+
inflight_htlc_msat: 0,
3875+
effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000, htlc_maximum_msat: 1_000 },
3876+
};
3877+
let network_graph = network_graph.read_only();
3878+
let channel = network_graph.channel(42).unwrap();
3879+
let (info, _) = channel.as_directed_from(&source).unwrap();
3880+
let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
3881+
info,
3882+
short_channel_id: 42,
3883+
});
3884+
let penalty = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3885+
3886+
let mut external_liquidity = ChannelLiquidity::new(Duration::ZERO);
3887+
let logger_rc = Rc::new(&logger); // Why necessary and not above for the network graph?
3888+
external_liquidity.as_directed_mut(&source_node_id(), &target_node_id(), 1_000).
3889+
successful(1000, Duration::ZERO, format_args!("test channel"), logger_rc.as_ref());
3890+
3891+
let mut external_scores = ChannelLiquidities::new();
3892+
3893+
external_scores.insert(42, external_liquidity);
3894+
combined_scorer.merge(external_scores);
3895+
3896+
let penalty_after_merge = combined_scorer.channel_penalty_msat(&candidate, usage, &params);
3897+
3898+
// Since the external source observed a successful payment, the penalty should be lower after the merge.
3899+
assert!(penalty_after_merge < penalty);
3900+
}
37053901
}
37063902

37073903
#[cfg(ldk_bench)]

0 commit comments

Comments
 (0)