diff --git a/crates/fiber-lib/src/fiber/amp.rs b/crates/fiber-lib/src/fiber/amp.rs new file mode 100644 index 000000000..942d8a9a7 --- /dev/null +++ b/crates/fiber-lib/src/fiber/amp.rs @@ -0,0 +1,153 @@ +use crate::fiber::{hash_algorithm::HashAlgorithm, types::Hash256}; +use bitcoin::hashes::{sha256::Hash as Sha256, Hash as _}; +use rand::RngCore; +use serde::{Deserialize, Serialize}; + +/// AmpSecret represents an n-of-n sharing of a secret 32-byte value. +/// The secret can be recovered by XORing all n shares together. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct AmpSecret([u8; 32]); + +impl AmpSecret { + /// Create a new AmpSecret from a 32-byte array + pub fn new(bytes: [u8; 32]) -> Self { + Self(bytes) + } + + /// Create a zero AmpSecret + pub fn zero() -> Self { + Self([0u8; 32]) + } + + /// Generate a random AmpSecret + pub fn random() -> Self { + let mut rng = rand::thread_rng(); + let mut bytes = [0u8; 32]; + rng.fill_bytes(&mut bytes); + Self(bytes) + } + + /// XOR this AmpSecret with another AmpSecret, storing the result in self + pub fn xor_assign(&mut self, other: &AmpSecret) { + for (a, b) in self.0.iter_mut().zip(other.0.iter()) { + *a ^= b; + } + } + + /// XOR two shares and return the result + pub fn xor(&self, other: &AmpSecret) -> AmpSecret { + let mut result = *self; + result.xor_assign(other); + result + } + + /// Get the underlying bytes + pub fn as_bytes(&self) -> &[u8; 32] { + &self.0 + } + + /// Convert to bytes + pub fn to_bytes(self) -> [u8; 32] { + self.0 + } + + /// generate a random AmpSecret sequence + pub fn gen_random_sequence(root: AmpSecret, n: u16) -> Vec { + let mut shares: Vec = (0..n - 1).map(|_| AmpSecret::random()).collect(); + + let mut final_secret = root; + for share in &shares { + final_secret.xor_assign(share); + } + shares.push(final_secret); + shares + } +} + +impl From<[u8; 32]> for AmpSecret { + fn from(bytes: [u8; 32]) -> Self { + Self(bytes) + } +} + +impl AsRef<[u8]> for AmpSecret { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +/// AmpChildDesc is the meta data for a child payment derived from the root seed. +/// It contains the index of the child and the share used in the derivation. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AmpChildDesc { + pub index: u16, + pub secret: AmpSecret, +} + +impl AmpChildDesc { + pub fn new(index: u16, secret: AmpSecret) -> Self { + Self { index, secret } + } +} + +/// Child is a payment hash and preimage pair derived from the root seed and ChildDesc. +/// In addition to the derived values, a Child carries all information required in +/// the derivation apart from the root seed (unless n=1). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AmpChild { + /// Preimage is the child payment preimage that can be used to settle the HTLC + pub preimage: Hash256, + + /// Hash is the child payment hash that to be carried by the HTLC + pub hash: Hash256, +} + +impl AmpChild { + pub fn new(preimage: Hash256, hash: Hash256) -> Self { + Self { preimage, hash } + } + + /// DeriveChild computes the child preimage and child hash for a given (root, share, index) tuple. + /// The derivation is defined as: + /// child_preimage = SHA256(root || share || be16(index)) + /// child_hash = SHA256(child_preimage) + pub fn derive_child( + root: AmpSecret, + desc: AmpChildDesc, + hash_algorithm: HashAlgorithm, + ) -> AmpChild { + let index_bytes = desc.index.to_be_bytes(); + + // Compute child_preimage as SHA256(root || share || child_index) + let mut preimage_data = Vec::with_capacity(32 + 32 + 2); + preimage_data.extend_from_slice(root.as_bytes()); + preimage_data.extend_from_slice(desc.secret.as_bytes()); + preimage_data.extend_from_slice(&index_bytes); + + let preimage_hash = Sha256::hash(&preimage_data); + let preimage: Hash256 = preimage_hash.to_byte_array().into(); + + // this is the payment hash for HTLC + let hash: Hash256 = hash_algorithm.hash(preimage.as_ref()).into(); + AmpChild::new(preimage, hash) + } + + /// ReconstructChildren derives the set of children hashes and preimages from the + /// provided descriptors. + pub fn reconstruct_amp_children( + child_descs: &[AmpChildDesc], + hash_algorithm: HashAlgorithm, + ) -> Vec { + // Recompute the root by XORing the provided shares + let mut root = AmpSecret::zero(); + for desc in child_descs { + root.xor_assign(&desc.secret); + } + + // With the root computed, derive the child hashes and preimages + child_descs + .iter() + .map(|data| Self::derive_child(root, data.clone(), hash_algorithm)) + .collect() + } +} diff --git a/crates/fiber-lib/src/fiber/builtin_records.rs b/crates/fiber-lib/src/fiber/builtin_records.rs new file mode 100644 index 000000000..d2e332ed7 --- /dev/null +++ b/crates/fiber-lib/src/fiber/builtin_records.rs @@ -0,0 +1,124 @@ +use serde::{Deserialize, Serialize}; + +use crate::fiber::{ + amp::{AmpChildDesc, AmpSecret}, + network::USER_CUSTOM_RECORDS_MAX_INDEX, + types::Hash256, + PaymentCustomRecords, +}; + +#[derive(Eq, PartialEq, Debug)] +/// Bolt04 basic MPP payment data record +pub struct BasicMppPaymentData { + pub payment_secret: Hash256, + pub total_amount: u128, +} + +impl BasicMppPaymentData { + // record type for payment data record in bolt04 + // custom records key from 65536 is reserved for internal usage + pub const CUSTOM_RECORD_KEY: u32 = USER_CUSTOM_RECORDS_MAX_INDEX + 1; + + pub fn new(payment_secret: Hash256, total_amount: u128) -> Self { + Self { + payment_secret, + total_amount, + } + } + + fn to_vec(&self) -> Vec { + let mut vec = Vec::new(); + vec.extend_from_slice(self.payment_secret.as_ref()); + vec.extend_from_slice(&self.total_amount.to_le_bytes()); + vec + } + + pub fn write(&self, custom_records: &mut PaymentCustomRecords) { + custom_records + .data + .insert(Self::CUSTOM_RECORD_KEY, self.to_vec()); + } + + pub fn read(custom_records: &PaymentCustomRecords) -> Option { + custom_records + .data + .get(&Self::CUSTOM_RECORD_KEY) + .and_then(|data| { + if data.len() != 32 + 16 { + return None; + } + let secret: [u8; 32] = data[..32].try_into().unwrap(); + let payment_secret = Hash256::from(secret); + let total_amount = u128::from_le_bytes(data[32..].try_into().unwrap()); + Some(Self::new(payment_secret, total_amount)) + }) + } +} + +#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] +pub struct AmpPaymentData { + pub total_amp_count: u16, + pub payment_hash: Hash256, + pub child_desc: AmpChildDesc, + pub total_amount: u128, +} + +impl AmpPaymentData { + pub const CUSTOM_RECORD_KEY: u32 = USER_CUSTOM_RECORDS_MAX_INDEX + 2; + + pub fn new( + payment_hash: Hash256, + total_amp_count: u16, + child_desc: AmpChildDesc, + total_amount: u128, + ) -> Self { + Self { + payment_hash, + total_amp_count, + child_desc, + total_amount, + } + } + + fn to_vec(&self) -> Vec { + let mut vec = Vec::new(); + vec.extend_from_slice(self.payment_hash.as_ref()); + vec.extend_from_slice(&self.total_amp_count.to_le_bytes()); + vec.extend_from_slice(&self.child_desc.index.to_le_bytes()); + vec.extend_from_slice(self.child_desc.secret.as_bytes()); + vec.extend_from_slice(&self.total_amount.to_le_bytes()); + vec + } + + pub fn index(&self) -> u16 { + self.child_desc.index + } + + pub fn write(&self, custom_records: &mut PaymentCustomRecords) { + custom_records + .data + .insert(Self::CUSTOM_RECORD_KEY, self.to_vec()); + } + + pub fn read(custom_records: &PaymentCustomRecords) -> Option { + custom_records + .data + .get(&Self::CUSTOM_RECORD_KEY) + .and_then(|data| { + if data.len() != 32 + 4 + 32 + 16 { + return None; + } + let parent_hash: [u8; 32] = data[..32].try_into().unwrap(); + let total_amp_count = u16::from_le_bytes(data[32..34].try_into().unwrap()); + let index = u16::from_le_bytes(data[34..36].try_into().unwrap()); + let secret = AmpSecret::new(data[36..68].try_into().unwrap()); + let total_amount = u128::from_le_bytes(data[68..].try_into().unwrap()); + Some(Self::new( + Hash256::from(parent_hash), + total_amp_count, + AmpChildDesc::new(index, secret), + total_amount, + )) + }) + } +} diff --git a/crates/fiber-lib/src/fiber/channel.rs b/crates/fiber-lib/src/fiber/channel.rs index 0fbf3a859..dfb2b3d73 100644 --- a/crates/fiber-lib/src/fiber/channel.rs +++ b/crates/fiber-lib/src/fiber/channel.rs @@ -5,6 +5,7 @@ use super::{ gossip::SOFT_BROADCAST_MESSAGES_CONSIDERED_STALE_DURATION, graph::ChannelUpdateInfo, types::ForwardTlcResult, }; +use crate::fiber::builtin_records::AmpPaymentData; use crate::fiber::config::MILLI_SECONDS_PER_EPOCH; use crate::fiber::fee::{check_open_channel_parameters, check_tlc_delta_with_epochs}; #[cfg(any(debug_assertions, feature = "bench"))] @@ -904,6 +905,19 @@ where tlc_id: TLCId, ) { let tlc_info = state.get_received_tlc(tlc_id).expect("expect tlc").clone(); + let parent_payment_hash = tlc_info + .parent_payment_hash + .unwrap_or(tlc_info.payment_hash); + + // MPP or AMP + if tlc_info.total_amount.is_some() { + // add to pending settlement tlc set + // the tlc set will be settled by network actor + state + .pending_notify_settle_tlcs + .push((parent_payment_hash, tlc_info.id(), true)); + return; + } let Some(preimage) = self.store.get_preimage(&tlc_info.payment_hash) else { return; @@ -922,13 +936,13 @@ where CkbInvoiceStatus::Expired => { remove_reason = RemoveTlcReason::RemoveTlcFail(TlcErrPacket::new( TlcErr::new(TlcErrorCode::InvoiceExpired), - &tlc.shared_secret, + &tlc_info.shared_secret, )); } CkbInvoiceStatus::Cancelled => { remove_reason = RemoveTlcReason::RemoveTlcFail(TlcErrPacket::new( TlcErr::new(TlcErrorCode::InvoiceCancelled), - &tlc.shared_secret, + &tlc_info.shared_secret, )); } CkbInvoiceStatus::Paid => { @@ -937,28 +951,17 @@ where error!("invoice already paid, ignore"); return; } - _ if invoice.allow_mpp() => { - // add to pending settlement tlc set - // the tlc set will be settled by network actor - state - .pending_notify_settle_tlcs - .push((tlc.payment_hash, tlc.id(), true)); - - // just return, the tlc set will be settled by network actor - return; - } _ => { // single path payment if !is_invoice_fulfilled(invoice, std::slice::from_ref(&tlc)) { remove_reason = RemoveTlcReason::RemoveTlcFail(TlcErrPacket::new( TlcErr::new(TlcErrorCode::IncorrectOrUnknownPaymentDetails), - &tlc.shared_secret, + &tlc_info.shared_secret, )); } } } } - // remove tlc if matches!(remove_reason, RemoveTlcReason::RemoveTlcFulfill(_)) && invoice.is_some() { state @@ -1034,6 +1037,7 @@ where ) -> ProcessingChannelResult { let payment_hash = add_tlc.payment_hash; let forward_amount = peeled_onion_packet.current.amount; + let channel_id = state.get_id(); let tlc = state .tlc_state @@ -1050,7 +1054,16 @@ where return Err(ProcessingChannelError::IncorrectFinalTlcExpiry); } - let invoice = self.store.get_invoice(&payment_hash); + let basic_mpp_payment_data = peeled_onion_packet.basic_mpp_custom_records(); + let atomic_mpp_payment_data = peeled_onion_packet.atomic_mpp_custom_records(); + let is_atomic_mpp = atomic_mpp_payment_data.is_some(); + + let parent_payment_hash = atomic_mpp_payment_data + .as_ref() + .map(|a| a.payment_hash) + .unwrap_or(payment_hash); + + let invoice = self.store.get_invoice(&parent_payment_hash); if let Some(ref invoice) = invoice { let invoice_status = self.get_invoice_status(invoice); if !matches!(invoice_status, CkbInvoiceStatus::Open) { @@ -1066,8 +1079,8 @@ where tlc.is_last = true; // extract MPP total payment fields from onion packet - match (&invoice, peeled_onion_packet.mpp_custom_records()) { - (Some(invoice), Some(record)) => { + match (&invoice, basic_mpp_payment_data, atomic_mpp_payment_data) { + (Some(invoice), Some(record), None) => { if record.total_amount < invoice.amount.unwrap_or_default() { error!( "total amount is less than invoice amount: {:?}", @@ -1092,8 +1105,24 @@ where tlc.payment_secret = Some(record.payment_secret); tlc.total_amount = Some(record.total_amount); } - (Some(invoice), None) => { - if invoice.allow_mpp() { + (_, None, Some(record)) => { + tlc.total_amount = Some(record.total_amount); + tlc.parent_payment_hash = Some(parent_payment_hash); + // now save amp_record with parent payment_hash as key + self.store.insert_atomic_mpp_payment_data( + parent_payment_hash, + channel_id, + tlc.tlc_id.into(), + record.clone(), + ); + } + (Some(invoice), _, None) if invoice.atomic_mpp() => { + return Err(ProcessingChannelError::FinalIncorrectMPPInfo( + "invoice is atomic mpp but no MPP records in onion packet".to_string(), + )); + } + (Some(invoice), None, _) => { + if invoice.basic_mpp() { // FIXME: whether we allow MPP without MPP records in onion packet? // currently we allow it pay with enough amount // TODO: add a unit test of using single path payment pay MPP invoice successfully @@ -1107,7 +1136,7 @@ where return Err(ProcessingChannelError::FinalIncorrectHTLCAmount); } } - (None, Some(_record)) => { + (None, Some(_), _) => { error!("invoice not found for MPP payment: {:?}", payment_hash); return Err(ProcessingChannelError::FinalIncorrectMPPInfo( "invoice not found".to_string(), @@ -1143,6 +1172,8 @@ where // Don't call self.store_preimage here, because it will reveal the preimage to watchtower. self.store.insert_preimage(payment_hash, preimage); + } else if is_atomic_mpp { + // atomic mpp don't have preimage } else if let Some(invoice) = invoice { // The TLC should be held until the invoice is expired or the TLC itself is // expired. @@ -3010,6 +3041,8 @@ pub struct TlcInfo { pub total_amount: Option, /// bolt04 payment secret, only exists for last hop in multi-path payment pub payment_secret: Option, + /// atomic parent payment_hash + pub parent_payment_hash: Option, /// The attempt id associate with the tlc, only on outbound tlc /// only exists for first hop in multi-path payment pub attempt_id: Option, @@ -5795,6 +5828,13 @@ impl ChannelActorState { if is_sent { // local peer can not sent more tlc amount than they have let pending_sent_amount = self.get_offered_tlc_balance(); + debug!( + "add_amount: {}, to_local_amount:{} - pending_sent_amount:{} = {}", + add_amount, + self.to_local_amount, + pending_sent_amount, + self.to_local_amount.saturating_sub(pending_sent_amount) + ); if add_amount > self.to_local_amount.saturating_sub(pending_sent_amount) { return Err(ProcessingChannelError::TlcAmountExceedLimit); } @@ -5814,6 +5854,13 @@ impl ChannelActorState { } else { // remote peer can not sent more tlc amount than they have let pending_recv_amount = self.get_received_tlc_balance(); + debug!( + "remote_amount: {}, to_remote_amount:{} - pending_recv_amount:{} = {}", + add_amount, + self.to_remote_amount, + pending_recv_amount, + self.to_local_amount.saturating_sub(pending_recv_amount) + ); if add_amount > self.to_remote_amount.saturating_sub(pending_recv_amount) { return Err(ProcessingChannelError::TlcAmountExceedLimit); } @@ -5849,6 +5896,7 @@ impl ChannelActorState { attempt_id: command.attempt_id, amount: command.amount, payment_hash: command.payment_hash, + parent_payment_hash: None, expiry: command.expiry, hash_algorithm: command.hash_algorithm, created_at: self.get_current_commitment_numbers(), @@ -5876,6 +5924,7 @@ impl ChannelActorState { channel_id: self.get_id(), amount: message.amount, payment_hash: message.payment_hash, + parent_payment_hash: None, attempt_id: None, expiry: message.expiry, hash_algorithm: message.hash_algorithm, @@ -7602,6 +7651,17 @@ pub trait ChannelActorStateStore { fn remove_payment_hold_tlc(&self, payment_hash: &Hash256, channel_id: &Hash256, tlc_id: u64); fn get_payment_hold_tlcs(&self, payment_hash: Hash256) -> Vec; fn get_node_hold_tlcs(&self) -> HashMap>; + fn insert_atomic_mpp_payment_data( + &self, + payment_hash: Hash256, + channel_id: Hash256, + tlc_id: u64, + payment_data: AmpPaymentData, + ); + fn get_atomic_mpp_payment_data( + &self, + payment_hash: &Hash256, + ) -> Vec<((Hash256, u64), AmpPaymentData)>; } /// A wrapper on CommitmentTransaction that has a partial signature along with diff --git a/crates/fiber-lib/src/fiber/features.rs b/crates/fiber-lib/src/fiber/features.rs index d61bdd96a..5e69d1e73 100644 --- a/crates/fiber-lib/src/fiber/features.rs +++ b/crates/fiber-lib/src/fiber/features.rs @@ -74,6 +74,7 @@ pub mod feature_bits { declare_feature_bits_and_methods! { GOSSIP_QUERIES, 1; BASIC_MPP, 3; + ATOMIC_MPP, 5; // more features, please note that base bit must be defined as increasing odd numbers } } @@ -88,6 +89,7 @@ impl Default for FeatureVector { let mut feature = Self::new(); feature.set_gossip_queries_required(); feature.set_basic_mpp_required(); + feature.set_atomic_mpp_required(); // set other default features here // ... diff --git a/crates/fiber-lib/src/fiber/graph.rs b/crates/fiber-lib/src/fiber/graph.rs index 86cf72f5b..572541037 100644 --- a/crates/fiber-lib/src/fiber/graph.rs +++ b/crates/fiber-lib/src/fiber/graph.rs @@ -2088,4 +2088,9 @@ pub trait NetworkGraphStateStore { fn get_attempts(&self, payment_hash: Hash256) -> Vec; fn delete_attempts(&self, payment_hash: Hash256); fn get_attempts_with_statuses(&self, status: &[AttemptStatus]) -> Vec; + fn get_payment_hash_with_attempt_hash( + &self, + attempt_hash: Hash256, + attempt_id: u64, + ) -> Option; } diff --git a/crates/fiber-lib/src/fiber/mod.rs b/crates/fiber-lib/src/fiber/mod.rs index 1a1c59605..c0ee70d25 100644 --- a/crates/fiber-lib/src/fiber/mod.rs +++ b/crates/fiber-lib/src/fiber/mod.rs @@ -1,3 +1,5 @@ +pub mod amp; +pub mod builtin_records; pub mod channel; pub mod config; pub mod features; diff --git a/crates/fiber-lib/src/fiber/network.rs b/crates/fiber-lib/src/fiber/network.rs index 6e39ee7d9..821691dbc 100644 --- a/crates/fiber-lib/src/fiber/network.rs +++ b/crates/fiber-lib/src/fiber/network.rs @@ -62,9 +62,9 @@ use super::gossip::{GossipActorMessage, GossipMessageStore, GossipMessageUpdates use super::graph::{NetworkGraph, NetworkGraphStateStore, OwnedChannelUpdateEvent, RouterHop}; use super::key::blake2b_hash_with_salt; use super::types::{ - BasicMppPaymentData, BroadcastMessageWithTimestamp, EcdsaSignature, FiberMessage, - ForwardTlcResult, GossipMessage, Hash256, Init, NodeAnnouncement, OpenChannel, Privkey, Pubkey, - RemoveTlcFulfill, RemoveTlcReason, TlcErr, TlcErrData, TlcErrorCode, + BroadcastMessageWithTimestamp, EcdsaSignature, FiberMessage, ForwardTlcResult, GossipMessage, + Hash256, Init, NodeAnnouncement, OpenChannel, Privkey, Pubkey, RemoveTlcFulfill, + RemoveTlcReason, TlcErr, TlcErrData, TlcErrorCode, }; use super::{ FiberConfig, InFlightCkbTxActor, InFlightCkbTxActorArguments, InFlightCkbTxKind, @@ -76,6 +76,8 @@ use crate::ckb::{ CkbChainMessage, FundingError, FundingRequest, FundingTx, GetShutdownTxRequest, GetShutdownTxResponse, }; +use crate::fiber::amp::{AmpChild, AmpChildDesc, AmpSecret}; +use crate::fiber::builtin_records::{AmpPaymentData, BasicMppPaymentData}; use crate::fiber::channel::{ tlc_expiry_delay, AddTlcCommand, AddTlcResponse, ChannelActorState, ChannelEphemeralConfig, ChannelInitializationOperation, RetryableTlcOperation, ShutdownCommand, TxCollaborationCommand, @@ -91,6 +93,7 @@ use crate::fiber::config::{ use crate::fiber::fee::{check_open_channel_parameters, check_tlc_delta_with_epochs}; use crate::fiber::gossip::{GossipConfig, GossipService, SubscribableGossipMessageStore}; use crate::fiber::graph::GraphChannelStat; +use crate::fiber::payment::MppMode; #[cfg(any(debug_assertions, test, feature = "bench"))] use crate::fiber::payment::SessionRoute; use crate::fiber::payment::{Attempt, AttemptStatus, PaymentSession, PaymentStatus}; @@ -102,8 +105,8 @@ use crate::fiber::KeyPair; use crate::invoice::{ CkbInvoice, CkbInvoiceStatus, InvoiceError, InvoiceStore, PreimageStore, SettleInvoiceError, }; -use crate::utils::payment::is_invoice_fulfilled; -use crate::{now_timestamp_as_millis_u64, unwrap_or_return, Error}; +use crate::utils::payment::{is_atomic_mpp_fulfilled, is_invoice_fulfilled}; +use crate::{gen_rand_sha256_hash, now_timestamp_as_millis_u64, unwrap_or_return, Error}; pub const FIBER_PROTOCOL_ID: ProtocolId = ProtocolId::new(42); @@ -401,6 +404,8 @@ pub struct SendPaymentCommand { pub max_parts: Option, // keysend payment, default is false pub keysend: Option, + // allow atomic mpp, default is false, + pub atomic_mpp: Option, // udt type script #[serde_as(as = "Option")] pub udt_type_script: Option