From c8a54d05d97ac78b3caf9023b9e92c0c79abf5fa Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 17 Jan 2025 10:46:27 +0100 Subject: [PATCH 1/6] recipient pays blinded fees When using blinded payment paths, the fees of the blinded path should be paid by the recipient, not the sender. - It is more fair: the sender chooses the non blinded part of the path and pays the corresponding fees, the recipient chooses the blinded part of the path and pays the corresponding fees. - It is more private: the sender does not learn the actual fees for the path and can't use this information to unblind the path. --- .../main/scala/fr/acinq/eclair/Setup.scala | 2 +- .../fr/acinq/eclair/db/DbEventHandler.scala | 2 +- .../fr/acinq/eclair/db/DualDatabases.scala | 12 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 11 +- .../fr/acinq/eclair/db/pg/PgAuditDb.scala | 31 ++- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 59 +++-- .../eclair/db/sqlite/SqliteAuditDb.scala | 31 ++- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 66 +++-- .../main/scala/fr/acinq/eclair/package.scala | 11 + .../acinq/eclair/payment/PaymentEvents.scala | 5 +- .../acinq/eclair/payment/PaymentPacket.scala | 1 - .../eclair/payment/offer/OfferManager.scala | 46 ++-- .../payment/offer/OfferPaymentMetadata.scala | 3 + .../payment/receive/MultiPartHandler.scala | 121 +++------ .../payment/receive/MultiPartPaymentFSM.scala | 13 +- .../eclair/payment/relay/NodeRelay.scala | 8 +- .../acinq/eclair/payment/relay/Relayer.scala | 4 + .../payment/send/PaymentLifecycle.scala | 3 + .../eclair/router/BlindedRouteCreation.scala | 16 +- .../scala/fr/acinq/eclair/router/Router.scala | 10 + .../eclair/wire/protocol/RouteBlinding.scala | 5 +- .../scala/fr/acinq/eclair/PackageSpec.scala | 11 + .../fr/acinq/eclair/db/AuditDbSpec.scala | 4 +- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 36 +-- .../integration/PaymentIntegrationSpec.scala | 73 +++--- .../basic/fixtures/MinimalNodeFixture.scala | 2 +- .../basic/payment/OfferPaymentSpec.scala | 238 +++++++++++++++--- .../eclair/payment/MultiPartHandlerSpec.scala | 128 ++++------ .../payment/MultiPartPaymentFSMSpec.scala | 2 +- .../MultiPartPaymentLifecycleSpec.scala | 2 +- .../eclair/payment/PaymentInitiatorSpec.scala | 2 +- .../eclair/payment/PaymentPacketSpec.scala | 26 +- .../payment/PostRestartHtlcCleanerSpec.scala | 2 +- .../payment/offer/OfferManagerSpec.scala | 92 ++++++- .../payment/receive/InvoicePurgerSpec.scala | 4 +- .../send/BlindedPathsResolverSpec.scala | 12 +- .../acinq/eclair/router/BaseRouterSpec.scala | 2 +- .../router/BlindedRouteCreationSpec.scala | 6 +- .../src/test/resources/api/received-success | 2 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 6 +- 40 files changed, 684 insertions(+), 426 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 2617d7e1e7..9f09263c55 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -357,7 +357,7 @@ class Setup(val datadir: File, } dbEventHandler = system.actorOf(SimpleSupervisor.props(DbEventHandler.props(nodeParams), "db-event-handler", SupervisorStrategy.Resume)) register = system.actorOf(SimpleSupervisor.props(Register.props(), "register", SupervisorStrategy.Resume)) - offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, router, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") + offerManager = system.spawn(Behaviors.supervise(OfferManager(nodeParams, paymentTimeout = 1 minute)).onFailure(typed.SupervisorStrategy.resume), name = "offer-manager") paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler", SupervisorStrategy.Resume)) triggerer = system.spawn(Behaviors.supervise(AsyncPaymentTriggerer()).onFailure(typed.SupervisorStrategy.resume), name = "async-payment-triggerer") peerReadyManager = system.spawn(Behaviors.supervise(PeerReadyManager()).onFailure(typed.SupervisorStrategy.restart), name = "peer-ready-manager") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala index 9356799054..7873bdef3f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala @@ -68,7 +68,7 @@ class DbEventHandler(nodeParams: NodeParams) extends Actor with DiagnosticActorL PaymentMetrics.PaymentFailed.withTag(PaymentTags.Direction, PaymentTags.Directions.Sent).increment() case e: PaymentReceived => - PaymentMetrics.PaymentAmount.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.amount.truncateToSatoshi.toLong) + PaymentMetrics.PaymentAmount.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.realAmount.truncateToSatoshi.toLong) PaymentMetrics.PaymentParts.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.parts.length) auditDb.add(e) e.parts.foreach(p => channelsDb.updateChannelMeta(p.fromChannelId, ChannelEvent.EventType.PaymentReceived)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index 7fb7e56eba..d50acd2f8e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -319,14 +319,14 @@ case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends Pa primary.addIncomingPayment(pr, preimage, paymentType) } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = { - runAsync(secondary.receiveIncomingPayment(paymentHash, amount, receivedAt)) - primary.receiveIncomingPayment(paymentHash, amount, receivedAt) + override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = { + runAsync(secondary.receiveIncomingPayment(paymentHash, virtualAmount, realAmount, receivedAt)) + primary.receiveIncomingPayment(paymentHash, virtualAmount, realAmount, receivedAt) } - override def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = { - runAsync(secondary.receiveIncomingOfferPayment(pr, preimage, amount, receivedAt, paymentType)) - primary.receiveIncomingOfferPayment(pr, preimage, amount, receivedAt, paymentType) + override def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = { + runAsync(secondary.receiveIncomingOfferPayment(pr, preimage, virtualAmount, realAmount, receivedAt, paymentType)) + primary.receiveIncomingOfferPayment(pr, preimage, virtualAmount, realAmount, receivedAt, paymentType) } override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index f5fcdfc3fa..a10af427ab 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -36,13 +36,13 @@ trait IncomingPaymentsDb { * Mark an incoming payment as received (paid). The received amount may exceed the invoice amount. * If there was no matching invoice in the DB, this will return false. */ - def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean + def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean /** * Add a new incoming offer payment as received. * If the invoice is already paid, adds `amount` to the amount paid. */ - def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now(), paymentType: String = PaymentType.Blinded): Unit + def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now(), paymentType: String = PaymentType.Blinded): Unit /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] @@ -150,10 +150,11 @@ object IncomingPaymentStatus { /** * Payment has been successfully received. * - * @param amount amount of the payment received, in milli-satoshis (may exceed the invoice amount). - * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. + * @param virtualAmount amount of the payment received, in milli-satoshis (may exceed the invoice amount). + * @param realAmount amount of the payment received, in milli-satoshis (may be less or more than the invoice amount). + * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. */ - case class Received(amount: MilliSatoshi, receivedAt: TimestampMilli) extends IncomingPaymentStatus + case class Received(virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli) extends IncomingPaymentStatus } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 75cd8c6b83..49f9c79035 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -36,7 +36,7 @@ import javax.sql.DataSource object PgAuditDb { val DB_NAME = "audit" - val CURRENT_VERSION = 12 + val CURRENT_VERSION = 13 } class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { @@ -114,12 +114,20 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON audit.transactions_published(channel_id)") } + def migration1213(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE audit.received RENAME TO received_old") + statement.executeUpdate("CREATE TABLE audit.received (virtual_amount_msat BIGINT NOT NULL, real_amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("INSERT INTO audit.received SELECT amount_msat, amount_msat, payment_hash, from_channel_id, timestamp FROM audit.received_old") + statement.executeUpdate("DROP TABLE audit.received_old") + statement.executeUpdate("CREATE INDEX received_timestamp_idx ON audit.received(timestamp)") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA audit") statement.executeUpdate("CREATE TABLE audit.sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE audit.received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.received (virtual_amount_msat BIGINT NOT NULL, real_amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") @@ -149,7 +157,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON audit.transactions_published(channel_id)") statement.executeUpdate("CREATE INDEX transactions_published_timestamp_idx ON audit.transactions_published(timestamp)") statement.executeUpdate("CREATE INDEX transactions_confirmed_timestamp_idx ON audit.transactions_confirmed(timestamp)") - case Some(v@(4 | 5 | 6 | 7 | 8 | 9 | 10 | 11)) => + case Some(v@(4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 5) { migration45(statement) @@ -175,6 +183,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { if (v < 12) { migration1112(statement) } + if (v < 13) { + migration1213(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -220,12 +231,13 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: PaymentReceived): Unit = withMetrics("audit/add-payment-received", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO audit.received VALUES (?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.received VALUES (?, ?, ?, ?, ?)")) { statement => e.parts.foreach(p => { - statement.setLong(1, p.amount.toLong) - statement.setString(2, e.paymentHash.toHex) - statement.setString(3, p.fromChannelId.toHex) - statement.setTimestamp(4, p.timestamp.toSqlTimestamp) + statement.setLong(1, p.virtualAmount.toLong) + statement.setLong(2, p.realAmount.toLong) + statement.setString(3, e.paymentHash.toHex) + statement.setString(4, p.fromChannelId.toHex) + statement.setTimestamp(5, p.timestamp.toSqlTimestamp) statement.addBatch() }) statement.executeBatch() @@ -404,7 +416,8 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => val paymentHash = rs.getByteVector32FromHex("payment_hash") val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("virtual_amount_msat")), + MilliSatoshi(rs.getLong("real_amount_msat")), rs.getByteVector32FromHex("from_channel_id"), TimestampMilli.fromSqlTimestamp(rs.getTimestamp("timestamp"))) val received = receivedByHash.get(paymentHash) match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index d78883a296..c16076b1d8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -36,7 +36,7 @@ import scala.util.{Failure, Success, Try} object PgPaymentsDb { val DB_NAME = "payments" - val CURRENT_VERSION = 8 + val CURRENT_VERSION = 9 } class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb with Logging { @@ -77,11 +77,19 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON payments.sent(offer_id)") } + def migration89(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE payments.received RENAME TO received_old") + statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, virtual_received_msat BIGINT, real_received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("INSERT INTO payments.received SELECT payment_hash, payment_type, payment_preimage, path_ids, payment_request, received_msat, received_msat, created_at, expire_at, received_at FROM payments.received_old") + statement.executeUpdate("DROP TABLE payments.received_old") + statement.executeUpdate("CREATE INDEX received_created_idx ON payments.received(created_at)") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA payments") - statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, virtual_received_msat BIGINT, real_received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") statement.executeUpdate("CREATE TABLE payments.sent (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, offer_id TEXT, payer_key TEXT, payment_route BYTEA, failures BYTEA, created_at TIMESTAMP WITH TIME ZONE NOT NULL, completed_at TIMESTAMP WITH TIME ZONE)") statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON payments.sent(parent_id)") @@ -89,7 +97,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON payments.sent(offer_id)") statement.executeUpdate("CREATE INDEX sent_created_idx ON payments.sent(created_at)") statement.executeUpdate("CREATE INDEX received_created_idx ON payments.received(created_at)") - case Some(v@(4 | 5 | 6 | 7)) => + case Some(v@(4 | 5 | 6 | 7 | 8)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 5) { migration45(statement) @@ -103,6 +111,9 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit if (v < 8) { migration78(statement) } + if (v < 9) { + migration89(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -268,30 +279,32 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) { + override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, amount.toLong) - update.setTimestamp(2, receivedAt.toSqlTimestamp) - update.setString(3, paymentHash.toHex) + using(pg.prepareStatement("UPDATE payments.received SET (virtual_received_msat, real_received_msat, received_at) = (? + COALESCE(virtual_received_msat, 0), ? + COALESCE(real_received_msat, 0), ?) WHERE payment_hash = ?")) { update => + update.setLong(1, virtualAmount.toLong) + update.setLong(2, realAmount.toLong) + update.setTimestamp(3, receivedAt.toSqlTimestamp) + update.setString(4, paymentHash.toHex) val updated = update.executeUpdate() updated > 0 } } } - override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Postgres) { + override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO payments.received (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + - "ON CONFLICT (payment_hash) DO UPDATE SET (received_msat, received_at) = (payments.received.received_msat + EXCLUDED.received_msat, EXCLUDED.received_at)")) { statement => + using(pg.prepareStatement("INSERT INTO payments.received (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, virtual_received_msat, real_received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" + + "ON CONFLICT (payment_hash) DO UPDATE SET (virtual_received_msat, real_received_msat, received_at) = (payments.received.virtual_received_msat + EXCLUDED.virtual_received_msat, payments.received.real_received_msat + EXCLUDED.real_received_msat, EXCLUDED.received_at)")) { statement => statement.setString(1, invoice.paymentHash.toHex) statement.setString(2, preimage.toHex) statement.setString(3, paymentType) statement.setString(4, invoice.toString) statement.setTimestamp(5, invoice.createdAt.toSqlTimestamp) statement.setTimestamp(6, (invoice.createdAt + invoice.relativeExpiry.toSeconds).toSqlTimestamp) - statement.setLong(7, amount.toLong) - statement.setTimestamp(8, receivedAt.toSqlTimestamp) + statement.setLong(7, virtualAmount.toLong) + statement.setLong(8, realAmount.toLong) + statement.setTimestamp(9, receivedAt.toSqlTimestamp) statement.executeUpdate() } } @@ -304,10 +317,10 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit val createdAt = TimestampMilli.fromSqlTimestamp(rs.getTimestamp("created_at")) Invoice.fromString(invoice) match { case Success(invoice: Bolt11Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) Some(IncomingStandardPayment(invoice, preimage, paymentType, createdAt, status)) case Success(invoice: MinimalBolt12Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) Some(IncomingBlindedPayment(invoice, preimage, paymentType, createdAt, status)) case _ => logger.error(s"could not parse DB invoice=$invoice, this should not happen") @@ -315,11 +328,11 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { - amount_opt match { - case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0 unixms)) - case None if invoice.isExpired() => IncomingPaymentStatus.Expired - case None => IncomingPaymentStatus.Pending + private def buildIncomingPaymentStatus(virtualAmount_opt: Option[MilliSatoshi], realAmount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { + (virtualAmount_opt, realAmount_opt) match { + case (Some(virtualAmount), Some(realAmount)) => IncomingPaymentStatus.Received(virtualAmount, realAmount, receivedAt_opt.getOrElse(0 unixms)) + case _ if invoice.isExpired() => IncomingPaymentStatus.Expired + case _ => IncomingPaymentStatus.Pending } } @@ -366,7 +379,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.executeQuery().flatMap(parseIncomingPayment).toSeq @@ -376,7 +389,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listPendingIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.setTimestamp(3, Timestamp.from(Instant.now())) @@ -387,7 +400,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listExpiredIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.setTimestamp(3, Timestamp.from(Instant.now())) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index c8b8f070df..ede02bdb04 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -34,7 +34,7 @@ import java.util.UUID object SqliteAuditDb { val DB_NAME = "audit" - val CURRENT_VERSION = 9 + val CURRENT_VERSION = 10 } class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { @@ -114,10 +114,18 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON transactions_published(channel_id)") } + def migration910(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE received RENAME TO received_old") + statement.executeUpdate("CREATE TABLE received (virtual_amount_msat INTEGER NOT NULL, real_amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("INSERT INTO received SELECT amount_msat, amount_msat, payment_hash, from_channel_id, timestamp FROM received_old") + statement.executeUpdate("DROP TABLE received_old") + statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") + } + getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE TABLE sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE received (virtual_amount_msat INTEGER NOT NULL, real_amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, next_node_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") @@ -145,7 +153,7 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON transactions_published(channel_id)") statement.executeUpdate("CREATE INDEX transactions_published_timestamp_idx ON transactions_published(timestamp)") statement.executeUpdate("CREATE INDEX transactions_confirmed_timestamp_idx ON transactions_confirmed(timestamp)") - case Some(v@(1 | 2 | 3 | 4 | 5 | 6 | 7 | 8)) => + case Some(v@(1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -171,6 +179,9 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { if (v < 9) { migration89(statement) } + if (v < 10) { + migration910(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -210,12 +221,13 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { } override def add(e: PaymentReceived): Unit = withMetrics("audit/add-payment-received", DbBackends.Sqlite) { - using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => + using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?, ?)")) { statement => e.parts.foreach(p => { - statement.setLong(1, p.amount.toLong) - statement.setBytes(2, e.paymentHash.toArray) - statement.setBytes(3, p.fromChannelId.toArray) - statement.setLong(4, p.timestamp.toLong) + statement.setLong(1, p.virtualAmount.toLong) + statement.setLong(2, p.realAmount.toLong) + statement.setBytes(3, e.paymentHash.toArray) + statement.setBytes(4, p.fromChannelId.toArray) + statement.setLong(5, p.timestamp.toLong) statement.addBatch() }) statement.executeBatch() @@ -374,7 +386,8 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => val paymentHash = rs.getByteVector32("payment_hash") val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("virtual_amount_msat")), + MilliSatoshi(rs.getLong("real_amount_msat")), rs.getByteVector32("from_channel_id"), TimestampMilli(rs.getLong("timestamp"))) val received = receivedByHash.get(paymentHash) match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index d08008388a..96ea34467d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -108,9 +108,17 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON sent_payments(offer_id)") } + def migration67(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE received_payments RENAME TO received_payments_old") + statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, virtual_received_msat INTEGER, real_received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("INSERT INTO received_payments SELECT payment_hash, payment_type, payment_preimage, path_ids, payment_request, received_msat, received_msat, created_at, expire_at, received_at FROM received_payments_old") + statement.executeUpdate("DROP TABLE received_payments_old") + statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, virtual_received_msat INTEGER, real_received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, payment_preimage BLOB, payment_type TEXT NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER, recipient_amount_msat INTEGER NOT NULL, recipient_node_id BLOB NOT NULL, payment_request TEXT, offer_id BLOB, payer_key BLOB, payment_route BLOB, failures BLOB, created_at INTEGER NOT NULL, completed_at INTEGER)") statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON sent_payments(parent_id)") @@ -118,7 +126,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON sent_payments(offer_id)") statement.executeUpdate("CREATE INDEX sent_created_idx ON sent_payments(created_at)") statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") - case Some(v@(1 | 2 | 3 | 4 | 5)) => + case Some(v@(1 | 2 | 3 | 4 | 5 | 6)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -135,6 +143,9 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { if (v < 6) { migration56(statement) } + if (v < 7) { + migration67(statement) + } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -279,32 +290,35 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } } - override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { - using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, amount.toLong) - update.setLong(2, receivedAt.toLong) - update.setBytes(3, paymentHash.toArray) + override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { + using(sqlite.prepareStatement("UPDATE received_payments SET (virtual_received_msat, real_received_msat, received_at) = (? + COALESCE(virtual_received_msat, 0), ? + COALESCE(real_received_msat, 0), ?) WHERE payment_hash = ?")) { update => + update.setLong(1, virtualAmount.toLong) + update.setLong(2, realAmount.toLong) + update.setLong(3, receivedAt.toLong) + update.setBytes(4, paymentHash.toArray) val updated = update.executeUpdate() updated > 0 } } - override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Sqlite) { - if (using(sqlite.prepareStatement("INSERT OR IGNORE INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Sqlite) { + if (using(sqlite.prepareStatement("INSERT OR IGNORE INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, virtual_received_msat, real_received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, invoice.paymentHash.toArray) statement.setBytes(2, preimage.toArray) statement.setString(3, paymentType) statement.setString(4, invoice.toString) statement.setLong(5, invoice.createdAt.toTimestampMilli.toLong) statement.setLong(6, (invoice.createdAt + invoice.relativeExpiry).toLong.seconds.toMillis) - statement.setLong(7, amount.toLong) - statement.setLong(8, receivedAt.toLong) + statement.setLong(7, virtualAmount.toLong) + statement.setLong(8, realAmount.toLong) + statement.setLong(9, receivedAt.toLong) statement.executeUpdate() } == 0) { - using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (received_msat + ?, ?) WHERE payment_hash = ?")) { statement => - statement.setLong(1, amount.toLong) - statement.setLong(2, receivedAt.toLong) - statement.setBytes(3, invoice.paymentHash.toArray) + using(sqlite.prepareStatement("UPDATE received_payments SET (virtual_received_msat, real_received_msat, received_at) = (virtual_received_msat + ?, real_received_msat + ?, ?) WHERE payment_hash = ?")) { statement => + statement.setLong(1, virtualAmount.toLong) + statement.setLong(2, realAmount.toLong) + statement.setLong(3, receivedAt.toLong) + statement.setBytes(4, invoice.paymentHash.toArray) statement.executeUpdate() } } @@ -317,10 +331,10 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { val createdAt = TimestampMilli(rs.getLong("created_at")) Invoice.fromString(invoice) match { case Success(invoice: Bolt11Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) Some(IncomingStandardPayment(invoice, preimage, paymentType, createdAt, status)) case Success(invoice: MinimalBolt12Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) Some(IncomingBlindedPayment(invoice, preimage, paymentType, createdAt, status)) case _ => logger.error(s"could not parse DB invoice=$invoice, this should not happen") @@ -328,11 +342,11 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } } - private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { - amount_opt match { - case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0 unixms)) - case None if invoice.isExpired() => IncomingPaymentStatus.Expired - case None => IncomingPaymentStatus.Pending + private def buildIncomingPaymentStatus(virtualAmount_opt: Option[MilliSatoshi], realAmount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { + (virtualAmount_opt, realAmount_opt) match { + case (Some(virtualAmount), Some(realAmount)) => IncomingPaymentStatus.Received(virtualAmount, realAmount, receivedAt_opt.getOrElse(0 unixms)) + case _ if invoice.isExpired() => IncomingPaymentStatus.Expired + case _ => IncomingPaymentStatus.Pending } } @@ -368,7 +382,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.executeQuery().flatMap(parseIncomingPayment).toSeq @@ -376,7 +390,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listPendingIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.setLong(3, TimestampMilli.now().toLong) @@ -385,7 +399,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listExpiredIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.setLong(3, TimestampMilli.now().toLong) @@ -397,5 +411,5 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { object SqlitePaymentsDb { val DB_NAME = "payments" - val CURRENT_VERSION = 6 + val CURRENT_VERSION = 7 } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala index 92bccca7a3..810a5527f0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/package.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/package.scala @@ -71,6 +71,17 @@ package object eclair { def nodeFee(relayFees: RelayFees, paymentAmount: MilliSatoshi): MilliSatoshi = nodeFee(relayFees.feeBase, relayFees.feeProportionalMillionths, paymentAmount) + /** + * @param baseFee fixed fee + * @param proportionalFee proportional fee (millionths) + * @param incomingAmount incoming payment amount + * @return the amount that a node should forward after paying itself the base and proportional fees + */ + def amountAfterFee(baseFee: MilliSatoshi, proportionalFee: Long, incomingAmount: MilliSatoshi): MilliSatoshi = + ((incomingAmount - baseFee).toLong * 1_000_000 + 1_000_000 + proportionalFee - 1).msat / (1_000_000 + proportionalFee) + + def amountAfterFee(relayFees: RelayFees, incomingAmount: MilliSatoshi): MilliSatoshi = amountAfterFee(relayFees.feeBase, relayFees.feeProportionalMillionths, incomingAmount) + implicit class MilliSatoshiLong(private val n: Long) extends AnyVal { def msat = MilliSatoshi(n) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index c070d9fb7e..e467bf030a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -120,13 +120,14 @@ object PaymentRelayed { case class PaymentReceived(paymentHash: ByteVector32, parts: Seq[PaymentReceived.PartialPayment]) extends PaymentEvent { require(parts.nonEmpty, "must have at least one payment part") - val amount: MilliSatoshi = parts.map(_.amount).sum + val virtualAmount: MilliSatoshi = parts.map(_.virtualAmount).sum + val realAmount: MilliSatoshi = parts.map(_.realAmount).sum val timestamp: TimestampMilli = parts.map(_.timestamp).max // we use max here because we fulfill the payment only once we received all the parts } object PaymentReceived { - case class PartialPayment(amount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) + case class PartialPayment(virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala index bb696d1caa..083373f7e1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentPacket.scala @@ -221,7 +221,6 @@ object IncomingPaymentPacket { case payload if payload.paymentConstraints_opt.exists(c => add.amountMsat < c.minAmount) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if payload.paymentConstraints_opt.exists(c => c.maxCltvExpiry < add.cltvExpiry) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if !Features.areCompatible(Features.empty, payload.allowedFeatures) => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) - case payload if add.amountMsat < payload.amount => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload if add.cltvExpiry < payload.expiry => Left(InvalidOnionBlinding(Sphinx.hash(add.onionRoutingPacket))) case payload => Right(FinalPacket(add, payload)) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala index a742b3d625..06c9242aea 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala @@ -20,6 +20,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.Crypto.PrivateKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} +import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.db.{IncomingBlindedPayment, IncomingPaymentStatus, PaymentType} import fr.acinq.eclair.message.{OnionMessages, Postman} @@ -27,10 +28,13 @@ import fr.acinq.eclair.payment.MinimalBolt12Invoice import fr.acinq.eclair.payment.offer.OfferPaymentMetadata.MinimalInvoiceData import fr.acinq.eclair.payment.receive.MultiPartHandler import fr.acinq.eclair.payment.receive.MultiPartHandler.{CreateInvoiceActor, ReceivingRoute} +import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.BlindedRouteCreation.aggregatePaymentInfo +import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv, Offer} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Logs, MilliSatoshi, NodeParams, TimestampMilli, TimestampSecond, randomBytes32} +import fr.acinq.eclair.{CltvExpiryDelta, Logs, MilliSatoshi, NodeParams, TimestampMilli, TimestampSecond, nodeFee, randomBytes32} import scodec.bits.ByteVector import scala.concurrent.duration.FiniteDuration @@ -61,7 +65,7 @@ object OfferManager { case class RequestInvoice(messagePayload: MessageOnion.InvoiceRequestPayload, blindedKey: PrivateKey, postman: ActorRef[Postman.SendMessage]) extends Command - case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded) extends Command + case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded, realAmount: MilliSatoshi) extends Command /** * Offer handlers must be implemented in separate plugins and respond to these two `HandlerCommand`. @@ -89,15 +93,15 @@ object OfferManager { private case class RegisteredOffer(offer: Offer, nodeKey: Option[PrivateKey], pathId_opt: Option[ByteVector32], handler: ActorRef[HandlerCommand]) - def apply(nodeParams: NodeParams, router: akka.actor.ActorRef, paymentTimeout: FiniteDuration): Behavior[Command] = { + def apply(nodeParams: NodeParams, paymentTimeout: FiniteDuration): Behavior[Command] = { Behaviors.setup { context => Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT))) { - new OfferManager(nodeParams, router, paymentTimeout, context).normal(Map.empty) + new OfferManager(nodeParams, paymentTimeout, context).normal(Map.empty) } } } - private class OfferManager(nodeParams: NodeParams, router: akka.actor.ActorRef, paymentTimeout: FiniteDuration, context: ActorContext[Command]) { + private class OfferManager(nodeParams: NodeParams, paymentTimeout: FiniteDuration, context: ActorContext[Command]) { def normal(registeredOffers: Map[ByteVector32, RegisteredOffer]): Behavior[Command] = { Behaviors.receiveMessage { case RegisterOffer(offer, nodeKey, pathId_opt, handler) => @@ -108,17 +112,19 @@ object OfferManager { registeredOffers.get(messagePayload.invoiceRequest.offer.offerId) match { case Some(registered) if registered.pathId_opt.map(_.bytes) == messagePayload.pathId_opt && messagePayload.invoiceRequest.isValid => context.log.debug("received valid invoice request for offerId={}", messagePayload.invoiceRequest.offer.offerId) - val child = context.spawnAnonymous(InvoiceRequestActor(nodeParams, messagePayload.invoiceRequest, registered.handler, registered.nodeKey.getOrElse(blindedKey), router, messagePayload.replyPath, postman)) + val child = context.spawnAnonymous(InvoiceRequestActor(nodeParams, messagePayload.invoiceRequest, registered.handler, registered.nodeKey.getOrElse(blindedKey), messagePayload.replyPath, postman)) child ! InvoiceRequestActor.RequestInvoice case _ => context.log.debug("offer {} is not registered or invoice request is invalid", messagePayload.invoiceRequest.offer.offerId) } Behaviors.same - case ReceivePayment(replyTo, paymentHash, payload) => + case ReceivePayment(replyTo, paymentHash, payload, realAmount) => MinimalInvoiceData.decode(payload.pathId) match { case Some(signed) => registeredOffers.get(signed.offerId) match { case Some(RegisteredOffer(offer, _, _, handler)) => MinimalInvoiceData.verify(nodeParams.nodeId, signed) match { + case Some(metadata) if realAmount + nodeFee(metadata.hiddenFees, realAmount) < payload.amount => + replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"incorrect amount received for offer ${signed.offerId.toHex}: realAmount=$realAmount, hiddenFees=${metadata.hiddenFees}, virtualAmount=${payload.amount}") case Some(metadata) if Crypto.sha256(metadata.preimage) == paymentHash => val child = context.spawnAnonymous(PaymentActor(nodeParams, replyTo, offer, metadata, paymentTimeout)) handler ! HandlePayment(child, signed.offerId, metadata.pluginData_opt) @@ -145,16 +151,20 @@ object OfferManager { * * @param amount Amount for the invoice (must be the same as the invoice request if it contained an amount). * @param routes Routes to use for the payment. + * @param hideFees If true, fees for the blinded route will be hidden to the payer and paid by the recipient. * @param pluginData_opt Some data for the handler by the handler. It will be sent to the handler when a payment is attempted. * @param additionalTlvs additional TLVs to add to the invoice. * @param customTlvs custom TLVs to add to the invoice. */ case class ApproveRequest(amount: MilliSatoshi, - routes: Seq[ReceivingRoute], + routes: Seq[Route], + hideFees: Boolean, pluginData_opt: Option[ByteVector] = None, additionalTlvs: Set[InvoiceTlv] = Set.empty, customTlvs: Set[GenericTlv] = Set.empty) extends Command + case class Route(hops: Seq[Router.ChannelHop], maxFinalExpiryDelta: CltvExpiryDelta, shortChannelIdDir_opt: Option[ShortChannelIdDir] = None) + /** * Sent by the offer handler to reject the request. For instance because stock has been exhausted. */ @@ -168,7 +178,6 @@ object OfferManager { invoiceRequest: InvoiceRequest, offerHandler: ActorRef[HandleInvoiceRequest], nodeKey: PrivateKey, - router: akka.actor.ActorRef, pathToSender: RouteBlinding.BlindedRoute, postman: ActorRef[Postman.SendMessage]): Behavior[Command] = { Behaviors.setup { context => @@ -176,7 +185,7 @@ object OfferManager { Behaviors.receiveMessagePartial { case RequestInvoice => offerHandler ! HandleInvoiceRequest(context.self, invoiceRequest) - new InvoiceRequestActor(nodeParams, invoiceRequest, nodeKey, router, pathToSender, postman, context).waitForHandler() + new InvoiceRequestActor(nodeParams, invoiceRequest, nodeKey, pathToSender, postman, context).waitForHandler() } } } @@ -185,7 +194,6 @@ object OfferManager { private class InvoiceRequestActor(nodeParams: NodeParams, invoiceRequest: InvoiceRequest, nodeKey: PrivateKey, - router: akka.actor.ActorRef, pathToSender: RouteBlinding.BlindedRoute, postman: ActorRef[Postman.SendMessage], context: ActorContext[Command]) { @@ -195,11 +203,17 @@ object OfferManager { context.log.debug("offer handler rejected invoice request: {}", error) postman ! Postman.SendMessage(OfferTypes.BlindedPath(pathToSender), OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.InvoiceError(TlvStream(OfferTypes.Error(error)))), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) waitForSent() - case ApproveRequest(amount, routes, pluginData_opt, additionalTlvs, customTlvs) => + case ApproveRequest(amount, routes, hideFees, pluginData_opt, additionalTlvs, customTlvs) => val preimage = randomBytes32() - val metadata = MinimalInvoiceData(preimage, invoiceRequest.payerId, TimestampSecond.now(), invoiceRequest.quantity, amount, pluginData_opt) - val pathId = MinimalInvoiceData.encode(nodeParams.privateKey, invoiceRequest.offer.offerId, metadata) - val receivePayment = MultiPartHandler.ReceiveOfferPayment(context.messageAdapter[CreateInvoiceActor.Bolt12InvoiceResponse](WrappedInvoiceResponse), nodeKey, invoiceRequest, routes, router, preimage, pathId, additionalTlvs, customTlvs) + val receivingRoutes = routes.map(route => { + val paymentInfo = aggregatePaymentInfo(amount, route.hops, nodeParams.channelConf.minFinalExpiryDelta) + val hiddenFees = if (hideFees) RelayFees(paymentInfo.feeBase, paymentInfo.feeProportionalMillionths) else RelayFees.zero + val metadata = MinimalInvoiceData(preimage, invoiceRequest.payerId, TimestampSecond.now(), invoiceRequest.quantity, amount, hiddenFees, pluginData_opt) + val pathId = MinimalInvoiceData.encode(nodeParams.privateKey, invoiceRequest.offer.offerId, metadata) + val paymentInfo1 = if (hideFees) paymentInfo.copy(feeBase = MilliSatoshi(0), feeProportionalMillionths = 0) else paymentInfo + ReceivingRoute(route.hops, pathId, route.maxFinalExpiryDelta, paymentInfo1, route.shortChannelIdDir_opt) + }) + val receivePayment = MultiPartHandler.ReceiveOfferPayment(context.messageAdapter[CreateInvoiceActor.Bolt12InvoiceResponse](WrappedInvoiceResponse), nodeKey, invoiceRequest, receivingRoutes, preimage, additionalTlvs, customTlvs) val child = context.spawnAnonymous(CreateInvoiceActor(nodeParams)) child ! CreateInvoiceActor.CreateBolt12Invoice(receivePayment) waitForInvoice() @@ -253,7 +267,7 @@ object OfferManager { case AcceptPayment(additionalTlvs, customTlvs) => val minimalInvoice = MinimalBolt12Invoice(offer, nodeParams.chainHash, metadata.amount, metadata.quantity, Crypto.sha256(metadata.preimage), metadata.payerKey, metadata.createdAt, additionalTlvs, customTlvs) val incomingPayment = IncomingBlindedPayment(minimalInvoice, metadata.preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment) + replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment, metadata.hiddenFees) Behaviors.stopped case RejectPayment(reason) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala index 5edb7c5fd5..fa60b83899 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.payment.offer import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, Crypto} +import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.{MilliSatoshi, TimestampSecond} import scodec.bits.ByteVector @@ -49,6 +50,7 @@ object OfferPaymentMetadata { createdAt: TimestampSecond, quantity: Long, amount: MilliSatoshi, + hiddenFees: RelayFees, pluginData_opt: Option[ByteVector]) /** @@ -69,6 +71,7 @@ object OfferPaymentMetadata { ("createdAt" | timestampSecond) :: ("quantity" | uint64overflow) :: ("amount" | millisatoshi) :: + ("hiddenFees" | (millisatoshi :: int64).as[RelayFees]) :: ("pluginData" | optional(bitsRemaining, bytes))).as[MinimalInvoiceData] private val signedDataCodec: Codec[SignedMinimalInvoiceData] = diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index f749d80ba2..6a3484bc8e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -35,13 +35,14 @@ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager -import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops, createBlindedRouteWithoutHops} +import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops} import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, PaymentRouteResponse} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Bolt11Feature, CltvExpiryDelta, FeatureSupport, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, randomBytes32} +import fr.acinq.eclair.{Bolt11Feature, CltvExpiryDelta, FeatureSupport, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, nodeFee, randomBytes32} import scodec.bits.{ByteVector, HexStringSyntax} import scala.concurrent.duration.DurationInt @@ -63,10 +64,10 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP private def addHtlcPart(ctx: ActorContext, add: UpdateAddHtlc, payload: FinalPayload, payment: IncomingPayment): Unit = { pendingPayments.get(add.paymentHash) match { case Some((_, handler)) => - handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) + handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, payload.amount, add) case None => val handler = ctx.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, payload.totalAmount, ctx.self)) - handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) + handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, payload.amount, add) pendingPayments = pendingPayments + (add.paymentHash -> (payment, handler)) } } @@ -132,14 +133,14 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP } } - case ProcessBlindedPacket(add, payload, payment) if doHandle(add.paymentHash) => + case ProcessBlindedPacket(add, payload, payment, hiddenRelayFees) if doHandle(add.paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(add.paymentHash))) { validateBlindedPayment(nodeParams, add, payload, payment) match { case Some(cmdFail) => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, Tags.FailureType(cmdFail)).increment() PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) case None => - log.debug("received payment for amount={} totalAmount={}", add.amountMsat, payload.totalAmount) + log.debug("received payment for virtualAmount={} realAmount={} totalAmount={}", payload.amount, add.amountMsat, payload.totalAmount) addHtlcPart(ctx, add, payload, payment) } } @@ -152,7 +153,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case MultiPartPaymentFSM.MultiPartPaymentFailed(paymentHash, failure, parts) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, failure.getClass.getSimpleName).increment() - log.warning("payment with paidAmount={} failed ({})", parts.map(_.amount).sum, failure) + log.warning("payment with paidAmount={} failed ({})", parts.map(_.virtualAmount).sum, failure) pendingPayments.get(paymentHash).foreach { case (_, handler: ActorRef) => handler ! PoisonPill } parts.collect { case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(failure), commit = true)) @@ -162,7 +163,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case s@MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { - log.info("received complete payment for amount={}", parts.map(_.amount).sum) + log.info("received complete payment for amount={}", parts.map(_.virtualAmount).sum) pendingPayments.get(paymentHash).foreach { case (payment: IncomingPayment, handler: ActorRef) => handler ! PoisonPill @@ -181,12 +182,12 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP // NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more // expensive code path by fetching the preimage from DB. case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => { - val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil) - if (db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)) { + val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.virtualAmount, p.realAmount, p.htlc.channelId) :: Nil) + if (db.receiveIncomingPayment(paymentHash, p.virtualAmount, p.realAmount, received.timestamp)) { PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) ctx.system.eventStream.publish(received) } else { - val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.virtualAmount, nodeParams.currentBlockHeight)), commit = true) PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) } }) @@ -196,18 +197,18 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case DoFulfill(payment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts)) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { - log.debug("fulfilling payment for amount={}", parts.map(_.amount).sum) + log.debug("fulfilling payment for virtualAmount={}, realAmount={}", parts.map(_.virtualAmount).sum, parts.map(_.realAmount).sum) val received = PaymentReceived(paymentHash, parts.map { - case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) + case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.virtualAmount, p.realAmount, p.htlc.channelId) }) val recordedInDb = payment match { // Incoming offer payments are not stored in the database until they have been paid. case IncomingBlindedPayment(invoice, preimage, paymentType, _, _) => - db.receiveIncomingOfferPayment(invoice, preimage, received.amount, received.timestamp, paymentType) + db.receiveIncomingOfferPayment(invoice, preimage, received.virtualAmount, received.realAmount, received.timestamp, paymentType) true // Incoming standard payments are already stored and need to be marked as received. case _: IncomingStandardPayment => - db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp) + db.receiveIncomingPayment(paymentHash, received.virtualAmount, received.realAmount, received.timestamp) } if (recordedInDb) { parts.collect { @@ -219,7 +220,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP parts.collect { case p: MultiPartPaymentFSM.HtlcPart => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment() - val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.virtualAmount, nodeParams.currentBlockHeight)), commit = true) PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) } } @@ -236,7 +237,7 @@ object MultiPartHandler { // @formatter:off case class ProcessPacket(add: UpdateAddHtlc, payload: FinalPayload.Standard, payment_opt: Option[IncomingStandardPayment]) - case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment) + case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment, hiddenRelayFees: RelayFees) case class RejectPacket(add: UpdateAddHtlc, failure: FailureMessage) case class DoFulfill(payment: IncomingPayment, success: MultiPartPaymentFSM.MultiPartPaymentSucceeded) @@ -265,20 +266,14 @@ object MultiPartHandler { paymentPreimage_opt: Option[ByteVector32] = None, paymentType: String = PaymentType.Standard) extends ReceivePayment - /** - * A dummy blinded hop that will be added at the end of a blinded route. - * The fees and expiry delta should match those of real channels, otherwise it will be obvious that dummy hops are used. - */ - case class DummyBlindedHop(feeBase: MilliSatoshi, feeProportionalMillionths: Long, cltvExpiryDelta: CltvExpiryDelta) - /** * A route that will be blinded and included in a Bolt 12 invoice. * - * @param nodes a valid route ending at our nodeId. + * @param hops hops to reach our node, or the empty sequence if we do not want to hide our node id. + * @param pathId path id for this route. * @param maxFinalExpiryDelta maximum expiry delta that senders can use: the route expiry will be computed based on this value. - * @param dummyHops (optional) dummy hops to add to the blinded route. */ - case class ReceivingRoute(nodes: Seq[PublicKey], maxFinalExpiryDelta: CltvExpiryDelta, dummyHops: Seq[DummyBlindedHop] = Nil, shortChannelIdDir_opt: Option[ShortChannelIdDir] = None) + case class ReceivingRoute(hops: Seq[Router.ChannelHop], pathId: ByteVector, maxFinalExpiryDelta: CltvExpiryDelta, paymentInfo: OfferTypes.PaymentInfo, shortChannelIdDir_opt: Option[ShortChannelIdDir] = None) /** * Use this message to create a Bolt 12 invoice to receive a payment for a given offer. @@ -287,20 +282,15 @@ object MultiPartHandler { * and may be different from our public nodeId. * @param invoiceRequest the request this invoice responds to. * @param routes routes that must be blinded and provided in the invoice. - * @param router router actor. * @param paymentPreimage payment preimage. - * @param pathId path id that will be used for all payment paths. */ case class ReceiveOfferPayment(replyTo: typed.ActorRef[CreateInvoiceActor.Bolt12InvoiceResponse], nodeKey: PrivateKey, invoiceRequest: InvoiceRequest, routes: Seq[ReceivingRoute], - router: ActorRef, paymentPreimage: ByteVector32, - pathId: ByteVector, additionalTlvs: Set[InvoiceTlv] = Set.empty, customTlvs: Set[GenericTlv] = Set.empty) extends ReceivePayment { - require(routes.forall(_.nodes.nonEmpty), "each route must have at least one node") require(invoiceRequest.offer.amount.nonEmpty || invoiceRequest.amount.nonEmpty, "an amount must be specified in the offer or in the invoice request") val amount = invoiceRequest.amount.orElse(invoiceRequest.offer.amount.map(_ * invoiceRequest.quantity)).get @@ -317,7 +307,6 @@ object MultiPartHandler { sealed trait Bolt12InvoiceResponse case class InvoiceCreated(invoice: Bolt12Invoice) extends Bolt12InvoiceResponse sealed trait InvoiceCreationFailed extends Bolt12InvoiceResponse { def message: String } - case object InvalidBlindedRouteRecipient extends InvoiceCreationFailed { override def message: String = "receiving routes must end at our node" } case class BlindedRouteCreationFailed(message: String) extends InvoiceCreationFailed // @formatter:on @@ -351,56 +340,20 @@ object MultiPartHandler { nodeParams.db.payments.addIncomingPayment(invoice, paymentPreimage, r.paymentType) r.replyTo ! invoice Behaviors.stopped - case CreateBolt12Invoice(r) if r.routes.exists(!_.nodes.lastOption.contains(nodeParams.nodeId)) => - r.replyTo ! InvalidBlindedRouteRecipient - Behaviors.stopped case CreateBolt12Invoice(r) => - implicit val ec: ExecutionContextExecutor = context.executionContext - val log = context.log - context.pipeToSelf(Future.sequence(r.routes.map(route => { - val dummyHops = route.dummyHops.map(h => { - // We don't want to restrict HTLC size in dummy hops, so we use htlc_minimum_msat = 1 msat and htlc_maximum_msat = None. - val edge = Invoice.ExtraEdge(nodeParams.nodeId, nodeParams.nodeId, ShortChannelId.toSelf, h.feeBase, h.feeProportionalMillionths, h.cltvExpiryDelta, htlcMinimum = 1 msat, htlcMaximum_opt = None) - ChannelHop(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId, HopRelayParams.FromHint(edge)) - }) - if (route.nodes.length == 1) { - val blindedRoute = if (dummyHops.isEmpty) { - createBlindedRouteWithoutHops(route.nodes.last, r.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) - } else { - createBlindedRouteFromHops(dummyHops, r.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) - } - val contactInfo = route.shortChannelIdDir_opt match { - case Some(shortChannelIdDir) => BlindedRoute(shortChannelIdDir, blindedRoute.route.firstPathKey, blindedRoute.route.blindedHops) - case None => blindedRoute.route - } - val paymentInfo = aggregatePaymentInfo(r.amount, dummyHops, nodeParams.channelConf.minFinalExpiryDelta) - Future.successful(PaymentBlindedRoute(contactInfo, paymentInfo)) - } else { - r.router.toTyped.ask[PaymentRouteResponse](replyTo => Router.FinalizeRoute(replyTo, Router.PredefinedNodeRoute(r.amount, route.nodes)))(10.seconds, context.system.scheduler).mapTo[Router.RouteResponse].map(routeResponse => { - val clearRoute = routeResponse.routes.head - val blindedRoute = createBlindedRouteFromHops(clearRoute.hops ++ dummyHops, r.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) - val contactInfo = route.shortChannelIdDir_opt match { - case Some(shortChannelIdDir) => BlindedRoute(shortChannelIdDir, blindedRoute.route.firstPathKey, blindedRoute.route.blindedHops) - case None => blindedRoute.route - } - val paymentInfo = aggregatePaymentInfo(r.amount, clearRoute.hops ++ dummyHops, nodeParams.channelConf.minFinalExpiryDelta) - PaymentBlindedRoute(contactInfo, paymentInfo) - }) + val paths = r.routes.map(route => { + val blindedRoute = createBlindedRouteFromHops(route.hops, nodeParams.nodeId, route.pathId, nodeParams.channelConf.htlcMinimum, route.maxFinalExpiryDelta.toCltvExpiry(nodeParams.currentBlockHeight)) + val contactInfo = route.shortChannelIdDir_opt match { + case Some(shortChannelIdDir) => BlindedRoute(shortChannelIdDir, blindedRoute.route.firstPathKey, blindedRoute.route.blindedHops) + case None => blindedRoute.route } - })).map(paths => { - val invoiceFeatures = nodeParams.features.bolt12Features() - val invoice = Bolt12Invoice(r.invoiceRequest, r.paymentPreimage, r.nodeKey, nodeParams.invoiceExpiry, invoiceFeatures, paths, r.additionalTlvs, r.customTlvs) - log.debug("generated invoice={} for offer={}", invoice.toString, r.invoiceRequest.offer.toString) - invoice - }))(WrappedInvoiceResult) - Behaviors.receiveMessagePartial { - case WrappedInvoiceResult(result) => - result match { - case Failure(f) => r.replyTo ! BlindedRouteCreationFailed(f.getMessage) - case Success(invoice) => r.replyTo ! InvoiceCreated(invoice) - } - Behaviors.stopped - } + PaymentBlindedRoute(contactInfo, route.paymentInfo) + }) + val invoiceFeatures = nodeParams.features.bolt12Features() + val invoice = Bolt12Invoice(r.invoiceRequest, r.paymentPreimage, r.nodeKey, nodeParams.invoiceExpiry, invoiceFeatures, paths, r.additionalTlvs, r.customTlvs) + context.log.debug("generated invoice={} for offer={}", invoice.toString, r.invoiceRequest.offer.toString) + r.replyTo ! InvoiceCreated(invoice) + Behaviors.stopped } } } @@ -411,7 +364,7 @@ object MultiPartHandler { // @formatter:off sealed trait Command case class GetIncomingPayment(replyTo: ActorRef) extends Command - case class ProcessPayment(payment: IncomingBlindedPayment) extends Command + case class ProcessPayment(payment: IncomingBlindedPayment, hiddenRelayFees: RelayFees) extends Command case class RejectPayment(reason: String) extends Command // @formatter:on @@ -431,7 +384,7 @@ object MultiPartHandler { } Behaviors.stopped case payload: FinalPayload.Blinded => - offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload) + offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload, packet.add.amountMsat) waitForPayment(context, nodeParams, replyTo, packet.add, payload) } } @@ -441,8 +394,8 @@ object MultiPartHandler { private def waitForPayment(context: typed.scaladsl.ActorContext[Command], nodeParams: NodeParams, replyTo: ActorRef, add: UpdateAddHtlc, payload: FinalPayload.Blinded): Behavior[Command] = { Behaviors.receiveMessagePartial { - case ProcessPayment(payment) => - replyTo ! ProcessBlindedPacket(add, payload, payment) + case ProcessPayment(payment, hiddenRelayFees) => + replyTo ! ProcessBlindedPacket(add, payload, payment, hiddenRelayFees) Behaviors.stopped case RejectPayment(reason) => context.log.info("rejecting blinded htlc #{} from channel {}: {}", add.id, add.channelId, reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala index db28e8f4ab..a8be934ec4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala @@ -58,7 +58,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot if (totalAmount != part.totalAmount) { log.warning("multi-part payment total amount mismatch: previously {}, now {}", totalAmount, part.totalAmount) goto(PAYMENT_FAILED) using PaymentFailed(IncorrectOrUnknownPaymentDetails(part.totalAmount, nodeParams.currentBlockHeight), updatedParts) - } else if (d.paidAmount + part.amount >= totalAmount) { + } else if (d.paidAmount + part.virtualAmount >= totalAmount) { goto(PAYMENT_SUCCEEDED) using PaymentSucceeded(updatedParts) } else { stay() using d.copy(parts = updatedParts) @@ -71,7 +71,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot // intermediate nodes will be able to fulfill that htlc anyway. This is a harmless spec violation. case Event(part: PaymentPart, _) => require(part.paymentHash == paymentHash, s"invalid payment hash (expected $paymentHash, received ${part.paymentHash}") - log.info("received extraneous payment part with amount={}", part.amount) + log.info("received extraneous payment part with virtualAmount={}, realAmount={}", part.virtualAmount, part.realAmount) replyTo ! ExtraPaymentReceived(paymentHash, part, None) stay() } @@ -130,13 +130,14 @@ object MultiPartPaymentFSM { /** An incoming payment that we're currently holding until we decide to fulfill or fail it (depending on whether we receive the complete payment). */ sealed trait PaymentPart { def paymentHash: ByteVector32 - def amount: MilliSatoshi + def virtualAmount: MilliSatoshi + def realAmount: MilliSatoshi def totalAmount: MilliSatoshi } /** An incoming HTLC. */ - case class HtlcPart(totalAmount: MilliSatoshi, htlc: UpdateAddHtlc) extends PaymentPart { + case class HtlcPart(totalAmount: MilliSatoshi, virtualAmount: MilliSatoshi, htlc: UpdateAddHtlc) extends PaymentPart { override def paymentHash: ByteVector32 = htlc.paymentHash - override def amount: MilliSatoshi = htlc.amountMsat + override def realAmount: MilliSatoshi = htlc.amountMsat } /** We successfully received all parts of the payment. */ case class MultiPartPaymentSucceeded(paymentHash: ByteVector32, parts: Queue[PaymentPart]) @@ -156,7 +157,7 @@ object MultiPartPaymentFSM { // @formatter:off sealed trait Data { def parts: Queue[PaymentPart] - lazy val paidAmount = parts.map(_.amount).sum + lazy val paidAmount = parts.map(_.virtualAmount).sum } case class WaitingForHtlc(parts: Queue[PaymentPart]) extends Data case class PaymentSucceeded(parts: Queue[PaymentPart]) extends Data diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 98f1708d3f..4fbf1b5dba 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -237,15 +237,15 @@ class NodeRelay private(nodeParams: NodeParams, case Relay(packet: IncomingPaymentPacket.NodeRelayPacket, originNode) => require(packet.outerPayload.paymentSecret == paymentSecret, "payment secret mismatch") context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", packet.add.id, packet.add.channelId) - handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add) + handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add.amountMsat, packet.add) receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now(), originNode), nextPayload, nextPacket_opt, handler) case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) => - context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure) + context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.realAmount).sum, failure) Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline) - parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.amount, Some(failure)) } + parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.realAmount, Some(failure)) } stopping() case WrappedMultiPartPaymentSucceeded(MultiPartPaymentFSM.MultiPartPaymentSucceeded(_, parts)) => - context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.amount).sum) + context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.realAmount).sum) val upstream = Upstream.Hot.Trampoline(htlcs.toList) validateRelay(nodeParams, upstream, nextPayload) match { case Some(failure) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index 04813a94dc..0d69468b06 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -134,6 +134,10 @@ object Relayer extends Logging { require(feeProportionalMillionths >= 0.0, "feeProportionalMillionths must be nonnegative") } + object RelayFees { + val zero: RelayFees = RelayFees(MilliSatoshi(0), 0) + } + case class AsyncPaymentsParams(holdTimeoutBlocks: Int, cancelSafetyBeforeTimeout: CltvExpiryDelta) case class RelayParams(publicChannelFees: RelayFees, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index ab37eeaff0..e570551bc7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -338,6 +338,9 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A // this is most likely a liquidity issue, we remove this edge for our next payment attempt data.recipient.extraEdges.filterNot(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) } + case _: HopRelayParams.Dummy => + log.error("received an update for a dummy hop, this should never happen") + data.recipient.extraEdges } case None => log.error(s"couldn't find node=$nodeId in the route, this should never happen") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala index 37e31c7ff8..8e2e2696c8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BlindedRouteCreation.scala @@ -40,9 +40,8 @@ object BlindedRouteCreation { } } - /** Create a blinded route from a non-empty list of channel hops. */ - def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { - require(hops.nonEmpty, "route must contain at least one hop") + /** Create a blinded route from a list of channel hops. */ + def createBlindedRouteFromHops(hops: Seq[Router.ChannelHop], finalNodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { // We use the same constraints for all nodes so they can't use it to guess their position. val routeExpiry = hops.foldLeft(routeFinalExpiry) { case (expiry, hop) => expiry + hop.cltvExpiryDelta } val routeMinAmount = hops.foldLeft(minAmount) { case (amount, hop) => amount.max(hop.params.htlcMinimum) } @@ -82,19 +81,10 @@ object BlindedRouteCreation { tlvs.copy(records = tlvs.records + RouteBlindingEncryptedDataTlv.Padding(ByteVector.fill(targetLength - payloadLength)(0))) }) val encodedPayloads = paddedPayloads.map(RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(_).require.bytes) :+ finalPayload - val nodeIds = hops.map(_.nodeId) :+ hops.last.nextNodeId + val nodeIds = hops.map(_.nodeId) :+ finalNodeId Sphinx.RouteBlinding.create(randomKey(), nodeIds, encodedPayloads) } - /** Create a blinded route where the recipient is also the introduction point (which reveals the recipient's identity). */ - def createBlindedRouteWithoutHops(nodeId: PublicKey, pathId: ByteVector, minAmount: MilliSatoshi, routeExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { - val finalPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream( - RouteBlindingEncryptedDataTlv.PaymentConstraints(routeExpiry, minAmount), - RouteBlindingEncryptedDataTlv.PathId(pathId), - )).require.bytes - Sphinx.RouteBlinding.create(randomKey(), Seq(nodeId), Seq(finalPayload)) - } - /** Create a blinded route where the recipient is a wallet node. */ def createBlindedRouteToWallet(hop: Router.ChannelHop, pathId: ByteVector, minAmount: MilliSatoshi, routeFinalExpiry: CltvExpiry): Sphinx.RouteBlinding.BlindedRouteDetails = { val routeExpiry = routeFinalExpiry + hop.cltvExpiryDelta diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index a849fb42c0..d95a1d6d22 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -512,6 +512,11 @@ object Router { override val htlcMaximum_opt = extraHop.htlcMaximum_opt } + case class Dummy(relayFees: Relayer.RelayFees, cltvExpiryDelta: CltvExpiryDelta) extends HopRelayParams { + override val htlcMinimum: MilliSatoshi = 1 msat + override val htlcMaximum_opt: Option[MilliSatoshi] = None + } + def areSame(a: HopRelayParams, b: HopRelayParams, ignoreHtlcSize: Boolean = false): Boolean = a.cltvExpiryDelta == b.cltvExpiryDelta && a.relayFees == b.relayFees && @@ -533,6 +538,11 @@ object Router { // @formatter:on } + object ChannelHop { + def dummy(nodeId: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionths: Long, cltvExpiryDelta: CltvExpiryDelta): ChannelHop = + ChannelHop(ShortChannelId.toSelf, nodeId, nodeId, HopRelayParams.Dummy(Relayer.RelayFees(feeBase, feeProportionalMillionths), cltvExpiryDelta)) + } + sealed trait FinalHop extends Hop /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 0ee4626a1e..93e6c9ad61 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -21,7 +21,7 @@ import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.wire.protocol.CommonCodecs.{cltvExpiry, cltvExpiryDelta, featuresCodec} import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} import fr.acinq.eclair.wire.protocol.TlvCodecs.{fixedLengthTlvField, tlvField, tmillisatoshi, tmillisatoshi32} -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Feature, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, UInt64} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, EncodedNodeId, Feature, Features, MilliSatoshi, ShortChannelId, UInt64, amountAfterFee} import scodec.bits.ByteVector import scala.util.{Failure, Success} @@ -106,8 +106,7 @@ object BlindedRouteData { val paymentConstraints: PaymentConstraints = records.get[RouteBlindingEncryptedDataTlv.PaymentConstraints].get val allowedFeatures: Features[Feature] = records.get[RouteBlindingEncryptedDataTlv.AllowedFeatures].map(_.features).getOrElse(Features.empty) - def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = - ((incomingAmount - paymentRelay.feeBase).toLong * 1_000_000 + 1_000_000 + paymentRelay.feeProportionalMillionths - 1).msat / (1_000_000 + paymentRelay.feeProportionalMillionths) + def amountToForward(incomingAmount: MilliSatoshi): MilliSatoshi = amountAfterFee(paymentRelay.feeBase, paymentRelay.feeProportionalMillionths, incomingAmount) def outgoingCltv(incomingCltv: CltvExpiry): CltvExpiry = incomingCltv - paymentRelay.cltvExpiryDelta } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala index 593ad580ed..8929d2fc5e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala @@ -114,4 +114,15 @@ class PackageSpec extends AnyFunSuite { assert(ShortChannelId(Long.MaxValue) < ShortChannelId(Long.MaxValue + 1)) } + test("node fees") { + val rng = new scala.util.Random() + for (_ <- 1 to 100) { + val amount = rng.nextLong(1_000_000_000_000L) msat + val baseFee = rng.nextLong(10_000) msat + val proportionalFee = rng.nextLong(5_000) + val amountWithFees = amount + nodeFee(baseFee, proportionalFee, amount) + assert(amountAfterFee(baseFee, proportionalFee, amountWithFees) == amount) + } + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index d79665a0a0..9fa933daae 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -66,8 +66,8 @@ class AuditDbSpec extends AnyFunSuite { val now = TimestampMilli.now() val e1 = PaymentSent(ZERO_UUID, randomBytes32(), randomBytes32(), 40000 msat, randomKey().publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32(), None) :: Nil) - val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32()) - val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32()) + val pp2a = PaymentReceived.PartialPayment(42000 msat, 42000 msat,randomBytes32()) + val pp2b = PaymentReceived.PartialPayment(42100 msat, 42100 msat,randomBytes32()) val e2 = PaymentReceived(randomBytes32(), pp2a :: pp2b :: Nil) val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32(), randomBytes32(), randomBytes32(), now - 3.seconds, now) val e4a = TransactionPublished(randomBytes32(), randomKey().publicKey, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 829b521612..49d1c5c2ec 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -82,11 +82,11 @@ class PaymentsDbSpec extends AnyFunSuite { // add a few rows val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000 unixms, None, None, OutgoingPaymentStatus.Pending) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = 1 unixsec) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(550 msat, 1100 unixms)) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(550 msat, 550 msat, 1100 unixms)) db.addOutgoingPayment(ps1) db.addIncomingPayment(i1, preimage1) - db.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100 unixms) + db.receiveIncomingPayment(i1.paymentHash, 550 msat, 550 msat, 1100 unixms) assert(db.listIncomingPayments(1 unixms, 1500 unixms, None) == Seq(pr1)) assert(db.listOutgoingPayments(1 unixms, 1500 unixms) == Seq(ps1)) @@ -105,7 +105,7 @@ class PaymentsDbSpec extends AnyFunSuite { val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010 unixms, None, None, OutgoingPaymentStatus.Failed(Nil, 1050 unixms)) val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040 unixms, None, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060 unixms)) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = 1 unixsec) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 1090 unixms)) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 12345678 msat, 1090 unixms)) val i2 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, Left("Another invoice"), CltvExpiryDelta(18), expirySeconds = Some(30), timestamp = 1 unixsec) val pr2 = IncomingStandardPayment(i2, preimage2, PaymentType.Standard, i2.createdAt.toTimestampMilli, IncomingPaymentStatus.Expired) @@ -166,7 +166,7 @@ class PaymentsDbSpec extends AnyFunSuite { statement.setBytes(1, i1.paymentHash.toArray) statement.setBytes(2, pr1.paymentPreimage.toArray) statement.setString(3, i1.toString) - statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) + statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount.toLong) statement.setLong(5, pr1.createdAt.toLong) statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt.toLong) statement.executeUpdate() @@ -298,7 +298,7 @@ class PaymentsDbSpec extends AnyFunSuite { val pendingInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(2500 msat), paymentHash1, bobPriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = now, expirySeconds = Some(30)) val pending = IncomingStandardPayment(pendingInvoice, preimage1, PaymentType.Standard, pendingInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) val paidInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(10_000 msat), paymentHash2, bobPriv, Left("invoice #2"), CltvExpiryDelta(12), timestamp = 250 unixsec, expirySeconds = Some(60)) - val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 300.unixsec.toTimestampMilli)) + val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 11_000 msat, 300.unixsec.toTimestampMilli)) migrationCheck( dbs = dbs, @@ -428,7 +428,7 @@ class PaymentsDbSpec extends AnyFunSuite { val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, TimestampMilli(Instant.parse("2020-05-14T13:47:21.00Z").toEpochMilli), None, None, OutgoingPaymentStatus.Failed(Nil, TimestampMilli(Instant.parse("2021-05-15T04:12:40.00Z").toEpochMilli))) val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, TimestampMilli(Instant.parse("2021-01-28T09:12:05.00Z").toEpochMilli), None, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, TimestampMilli.now())) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = TimestampSecond.now()) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, TimestampMilli.now())) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 12345678 msat, TimestampMilli.now())) val i2 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, Left("Another invoice"), CltvExpiryDelta(18), expirySeconds = Some(24 * 3600), timestamp = TimestampSecond(Instant.parse("2020-12-30T10:00:55.00Z").getEpochSecond)) val pr2 = IncomingStandardPayment(i2, preimage2, PaymentType.Standard, i2.createdAt.toTimestampMilli, IncomingPaymentStatus.Expired) @@ -487,7 +487,7 @@ class PaymentsDbSpec extends AnyFunSuite { } using(connection.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) + update.setLong(1, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount.toLong) update.setLong(2, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt.toLong) update.setString(3, pr1.invoice.paymentHash.toHex) val updated = update.executeUpdate() @@ -526,7 +526,7 @@ class PaymentsDbSpec extends AnyFunSuite { val pendingInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(2500 msat), paymentHash1, bobPriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = now, expirySeconds = Some(30)) val pending = IncomingStandardPayment(pendingInvoice, preimage1, PaymentType.Standard, pendingInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) val paidInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(10_000 msat), paymentHash2, bobPriv, Left("invoice #2"), CltvExpiryDelta(12), timestamp = 250 unixsec, expirySeconds = Some(60)) - val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 300.unixsec.toTimestampMilli)) + val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 11_000 msat, 300.unixsec.toTimestampMilli)) migrationCheck( dbs = dbs, @@ -650,7 +650,7 @@ class PaymentsDbSpec extends AnyFunSuite { // can't receive a payment without an invoice associated with it val unknownPaymentHash = randomBytes32() - assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat)) + assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat, 12345600 msat)) assert(db.getIncomingPayment(unknownPaymentHash).isEmpty) val expiredInvoice1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec) @@ -672,9 +672,9 @@ class PaymentsDbSpec extends AnyFunSuite { val receivedAt2 = TimestampMilli.now() + 2.milli val receivedAt3 = TimestampMilli.now() + 3.milli val receivedAt4 = TimestampMilli.now() + 4.milli - val payment1 = IncomingStandardPayment(paidInvoice1, randomBytes32(), PaymentType.Standard, paidInvoice1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(561 msat, receivedAt2)) - val payment2 = IncomingStandardPayment(paidInvoice2, randomBytes32(), PaymentType.Standard, paidInvoice2.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1111 msat, receivedAt2)) - val payment3 = IncomingBlindedPayment(paidInvoice3, randomBytes32(), PaymentType.Blinded, paidInvoice3.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1730 msat, receivedAt3)) + val payment1 = IncomingStandardPayment(paidInvoice1, randomBytes32(), PaymentType.Standard, paidInvoice1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(561 msat, 560 msat, receivedAt2)) + val payment2 = IncomingStandardPayment(paidInvoice2, randomBytes32(), PaymentType.Standard, paidInvoice2.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1111 msat, 1100 msat, receivedAt2)) + val payment3 = IncomingBlindedPayment(paidInvoice3, randomBytes32(), PaymentType.Blinded, paidInvoice3.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1730 msat, 1720 msat, receivedAt3)) db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage) db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn) @@ -682,7 +682,7 @@ class PaymentsDbSpec extends AnyFunSuite { db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage) db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage) db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage) - db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 1730 msat, receivedAt3) + db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 1730 msat, 1720 msat, receivedAt3) assert(db.getIncomingPayment(pendingInvoice1.paymentHash).contains(pendingPayment1)) assert(db.getIncomingPayment(expiredInvoice2.paymentHash).contains(expiredPayment2)) @@ -695,12 +695,12 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listReceivedIncomingPayments(0 unixms, now, None) == Seq(payment3)) assert(db.listPendingIncomingPayments(0 unixms, now, None) == Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) - db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1) - db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum - db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2) - db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 3400 msat, receivedAt4) + db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, 460 msat, receivedAt1) + db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, 100 msat, receivedAt2) // adding another payment to this invoice should sum + db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, 1100 msat, receivedAt2) + db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 3400 msat, 3400 msat, receivedAt4) - val payment4 = payment3.copy(status = IncomingPaymentStatus.Received(5130 msat, receivedAt4)) + val payment4 = payment3.copy(status = IncomingPaymentStatus.Received(5130 msat, 5120 msat, receivedAt4)) assert(db.getIncomingPayment(paidInvoice1.paymentHash).contains(payment1)) assert(db.getIncomingPayment(paidInvoice3.paymentHash).contains(payment4)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index 73fd0f01f7..cf8177dca8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -37,13 +37,14 @@ import fr.acinq.eclair.db._ import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute} import fr.acinq.eclair.payment._ +import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.payment.offer.OfferManager._ -import fr.acinq.eclair.payment.receive.MultiPartHandler.{DummyBlindedHop, ReceiveStandardPayment, ReceivingRoute} +import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceiveStandardPayment import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendTrampolinePayment} import fr.acinq.eclair.router.Graph.PaymentWeightRatios -import fr.acinq.eclair.router.Router.{GossipDecision, PublicChannel} +import fr.acinq.eclair.router.Router.{ChannelHop, GossipDecision, PublicChannel} import fr.acinq.eclair.router.{Announcements, AnnouncementsBatchValidationSpec, Router} import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, IncorrectOrUnknownPaymentDetails} @@ -372,7 +373,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(sent.head.copy(parts = sent.head.parts.sortBy(_.timestamp)) == paymentSent.copy(parts = paymentSent.parts.map(_.copy(route = None)).sortBy(_.timestamp)), sent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) } @@ -429,7 +430,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentParts.forall(p => p.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid == 0.msat), paymentParts) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) } @@ -481,7 +482,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid == amount * 0.002) // 0.2% awaitCond(nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) awaitCond({ @@ -514,7 +515,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.recipientAmount == amount, paymentSent) awaitCond(nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) eventListener.expectMsg(PaymentMetadataReceived(invoice.paymentHash, invoice.paymentMetadata.get)) @@ -552,7 +553,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.recipientAmount == amount, paymentSent) awaitCond(nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) eventListener.expectMsg(PaymentMetadataReceived(invoice.paymentHash, invoice.paymentMetadata.get)) @@ -624,13 +625,20 @@ class PaymentIntegrationSpec extends IntegrationSpec { val bob = new EclairImpl(nodes("B")) bob.payOfferBlocking(offer, amount, 1, maxAttempts_opt = Some(3))(30 seconds).pipeTo(sender.ref) + nodes("D").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("G").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId))) + val route1 = sender.expectMsgType[Router.RouteResponse].routes.head + nodes("D").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId))) + val route2 = sender.expectMsgType[Router.RouteResponse].routes.head + nodes("D").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("E").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId))) + val route3 = sender.expectMsgType[Router.RouteResponse].routes.head + val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - ReceivingRoute(Seq(nodes("G").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId), CltvExpiryDelta(1000)), - ReceivingRoute(Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId), CltvExpiryDelta(1000)), - ReceivingRoute(Seq(nodes("E").nodeParams.nodeId, nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId), CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route1.hops, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route2.hops, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route3.hops, CltvExpiryDelta(1000)), ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"abcd")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"abcd")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -642,7 +650,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid > 0.msat, paymentSent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -660,10 +668,10 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] // C uses a 0-hop blinded route and signs the invoice with its public nodeId. val receivingRoutes = Seq( - ReceivingRoute(Seq(nodes("C").nodeParams.nodeId), CltvExpiryDelta(1000)), - ReceivingRoute(Seq(nodes("C").nodeParams.nodeId), CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(Nil, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(Nil, CltvExpiryDelta(1000)), ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"0123")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"0123")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -675,7 +683,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid == 0.msat, paymentSent) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount == amount) } @@ -695,9 +703,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - ReceivingRoute(Seq(nodes("A").nodeParams.nodeId), CltvExpiryDelta(1000), Seq(DummyBlindedHop(100 msat, 100, CltvExpiryDelta(48)), DummyBlindedHop(150 msat, 50, CltvExpiryDelta(36)))) + OfferManager.InvoiceRequestActor.Route(Seq(ChannelHop.dummy(nodes("A").nodeParams.nodeId, 100 msat, 100, CltvExpiryDelta(48)), ChannelHop.dummy(nodes("A").nodeParams.nodeId, 150 msat, 50, CltvExpiryDelta(36))), CltvExpiryDelta(1000)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -709,7 +717,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -727,11 +735,14 @@ class PaymentIntegrationSpec extends IntegrationSpec { val bob = new EclairImpl(nodes("B")) bob.payOfferBlocking(offer, amount, 1, maxAttempts_opt = Some(3))(30 seconds).pipeTo(sender.ref) + nodes("C").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - ReceivingRoute(Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId), CltvExpiryDelta(555), Seq(DummyBlindedHop(55 msat, 55, CltvExpiryDelta(55)))) + OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), CltvExpiryDelta(555)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"eff0")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"eff0")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -743,7 +754,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -758,9 +769,12 @@ class PaymentIntegrationSpec extends IntegrationSpec { val alice = new EclairImpl(nodes("A")) alice.payOfferTrampoline(offer, amount, 1, nodes("B").nodeParams.nodeId, maxAttempts_opt = Some(1))(30 seconds).pipeTo(sender.ref) + nodes("D").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] - val receivingRoutes = Seq(ReceivingRoute(Seq(nodes("C").nodeParams.nodeId, nodes("D").nodeParams.nodeId), CltvExpiryDelta(500))) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"0123")) + val receivingRoutes = Seq(OfferManager.InvoiceRequestActor.Route(route.hops, CltvExpiryDelta(500))) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"0123")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -772,7 +786,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -797,6 +811,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { val alice = new EclairImpl(nodes("A")) alice.payOfferBlocking(offer, amount, 1, maxAttempts_opt = Some(3))(30 seconds).pipeTo(sender.ref) + nodes("C").router ! Router.FinalizeRoute(sender.ref, Router.PredefinedNodeRoute(amount, Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val scidDirCB = { probe.send(nodes("B").router, Router.GetChannels) @@ -804,9 +821,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { ShortChannelIdDir(channelBC.nodeId1 == nodes("B").nodeParams.nodeId, channelBC.shortChannelId) } val receivingRoutes = Seq( - ReceivingRoute(Seq(nodes("B").nodeParams.nodeId, nodes("C").nodeParams.nodeId), CltvExpiryDelta(555), Seq(DummyBlindedHop(55 msat, 55, CltvExpiryDelta(55))), Some(scidDirCB)) + OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), CltvExpiryDelta(555), Some(scidDirCB)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -819,7 +836,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(invoice.blindedPaths.forall(_.route.firstNodeId.isInstanceOf[EncodedNodeId.ShortChannelIdDir])) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index 9be6def1bd..30e9091349 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -93,7 +93,7 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val watcherTyped = watcher.ref.toTyped[ZmqWatcher.Command] val register = system.actorOf(Register.props(), "register") val router = system.actorOf(Router.props(nodeParams, watcherTyped), "router") - val offerManager = system.spawn(OfferManager(nodeParams, router, 1 minute), "offer-manager") + val offerManager = system.spawn(OfferManager(nodeParams, 1 minute), "offer-manager") val paymentHandler = system.actorOf(PaymentHandler.props(nodeParams, register, offerManager), "payment-handler") val relayer = system.actorOf(Relayer.props(nodeParams, router, register, paymentHandler), "relayer") val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, bitcoinClient) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index 5dc0ed6dc3..efb37a76f9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.integration.basic.payment import akka.actor.typed.Behavior import akka.actor.typed.scaladsl.Behaviors -import akka.actor.typed.scaladsl.adapter.ClassicActorSystemOps +import akka.actor.typed.scaladsl.adapter.{ClassicActorRefOps, ClassicActorSystemOps} import akka.testkit.TestProbe import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -33,9 +33,11 @@ import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager -import fr.acinq.eclair.payment.receive.MultiPartHandler.{DummyBlindedHop, ReceivingRoute} +import fr.acinq.eclair.payment.offer.OfferManager.InvoiceRequestActor import fr.acinq.eclair.payment.send.OfferPayment import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendSpontaneousPayment} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.testutils.FixtureSpec import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding} @@ -151,10 +153,10 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { } } - def offerHandler(amount: MilliSatoshi, routes: Seq[ReceivingRoute]): Behavior[OfferManager.HandlerCommand] = { + def offerHandler(amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], hideFees: Boolean): Behavior[OfferManager.HandlerCommand] = { Behaviors.receiveMessage { case OfferManager.HandleInvoiceRequest(replyTo, _) => - replyTo ! OfferManager.InvoiceRequestActor.ApproveRequest(amount, routes) + replyTo ! InvoiceRequestActor.ApproveRequest(amount, routes, hideFees) Behaviors.same case OfferManager.HandlePayment(replyTo, _, _) => replyTo ! OfferManager.PaymentActor.AcceptPayment() @@ -162,12 +164,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { } } - def sendOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[ReceivingRoute], maxAttempts: Int = 1): (Offer, PaymentEvent) = { + def sendOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1, hideFees: Boolean = false): (Offer, PaymentEvent) = { import f._ val sender = TestProbe("sender") val offer = Offer(None, Some("test"), recipient.nodeId, Features.empty, recipient.nodeParams.chainHash) - val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) + val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes, hideFees)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipient.nodeParams.privateKey), None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, payer.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts, payer.routeParams, blocking = true) @@ -175,19 +177,18 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { (offer, sender.expectMsgType[PaymentEvent]) } - def sendPrivateOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[ReceivingRoute], maxAttempts: Int = 1): (Offer, PaymentEvent) = { + def sendPrivateOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1, hideFees: Boolean = false): (Offer, PaymentEvent) = { import f._ val sender = TestProbe("sender") val recipientKey = randomKey() val pathId = randomBytes32() val offerPaths = routes.map(route => { - val ourNodeId = route.nodes.last - val intermediateNodes = route.nodes.dropRight(1).map(IntermediateNode(_)) ++ route.dummyHops.map(_ => IntermediateNode(ourNodeId)) - buildRoute(randomKey(), intermediateNodes, Recipient(ourNodeId, Some(pathId))).route + val intermediateNodes = route.hops.map(hop => IntermediateNode(hop.nodeId)) + buildRoute(randomKey(), intermediateNodes, Recipient(recipient.nodeId, Some(pathId))).route }) val offer = Offer(None, Some("test"), recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) - val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) + val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes, hideFees)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipientKey), Some(pathId), handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, payer.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts, payer.routeParams, blocking = true) @@ -195,13 +196,13 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { (offer, sender.expectMsgType[PaymentEvent]) } - def sendOfferPaymentWithInvalidAmount(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, payerAmount: MilliSatoshi, recipientAmount: MilliSatoshi, routes: Seq[ReceivingRoute]): PaymentFailed = { + def sendOfferPaymentWithInvalidAmount(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, payerAmount: MilliSatoshi, recipientAmount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], hideFees: Boolean = false): PaymentFailed = { import f._ val sender = TestProbe("sender") val paymentInterceptor = TestProbe("payment-interceptor") val offer = Offer(None, Some("test"), recipient.nodeId, Features.empty, recipient.nodeParams.chainHash) - val handler = recipient.system.spawnAnonymous(offerHandler(recipientAmount, routes)) + val handler = recipient.system.spawnAnonymous(offerHandler(recipientAmount, routes, hideFees)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipient.nodeParams.privateKey), None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, paymentInterceptor.ref)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, payer.routeParams, blocking = true) @@ -225,30 +226,80 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 25_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) } + test("send blinded payment a->b->c, hidden fees") { f => + import f._ + + val amount = 25_000_000 msat + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.length == 1) + assert(payment.parts.head.amount == amount) + assert(payment.parts.head.feesPaid == 0.msat) + } + test("send blinded multi-part payment a->b->c") { f => import f._ val amount = 125_000_000 msat + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val routes = Seq( - ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta), - ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), ) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) } + test("send blinded multi-part payment a->b->c, hidden fees") { f => + import f._ + + val amount = 125_000_000 msat + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq( + InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), + ) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3, hideFees = true) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.length == 2) + assert(payment.parts.forall(_.feesPaid == 0.msat)) + } + test("send blinded multi-part payment a->b->c (single channel a->b)", Tag(PrivateChannels)) { f => import f._ + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Carol advertises a single blinded path from Bob to herself. - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) // We make a first set of payments to ensure channels have less than 50 000 sat on Bob's side. Seq(50_000_000 msat, 50_000_000 msat).foreach(amount => { @@ -276,9 +327,13 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val channelId_bc_1 = openChannel(bob, carol, 250_000 sat).channelId waitForChannelCreatedBC(f, channelId_bc_1) - val route = ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val amount1 = 150_000_000 msat - val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount1, Seq(route), maxAttempts = 3) + val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount1, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount1, result) assert(payment.parts.length > 1) } @@ -286,30 +341,72 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { test("send blinded payment a->b->c with dummy hops") { f => import f._ + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val amount = 125_000_000 msat val routes = Seq( - ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta, Seq(DummyBlindedHop(150 msat, 0, CltvExpiryDelta(50)))), - ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta, Seq(DummyBlindedHop(50 msat, 0, CltvExpiryDelta(20)), DummyBlindedHop(100 msat, 0, CltvExpiryDelta(30)))), + InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), maxFinalExpiryDelta), ) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) } + test("send blinded payment a->b->c with dummy hops, hidden fees") { f => + import f._ + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val amount = 125_000_000 msat + val routes = Seq( + InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), maxFinalExpiryDelta), + ) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.length == 2) + assert(payment.parts.forall(_.feesPaid == 0.msat)) + } + test("send blinded payment a->b->c through private channels", Tag(PrivateChannels)) { f => import f._ val amount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes) verifyPaymentSuccess(offer, amount, result) } + test("send blinded payment a->b->c through private channels, hidden fees", Tag(PrivateChannels)) { f => + import f._ + + val amount = 50_000_000 msat + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.forall(_.feesPaid == 0.msat)) + } + test("send blinded payment a->b") { f => import f._ val amount = 75_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Nil, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -319,17 +416,33 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 250_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId), maxFinalExpiryDelta, Seq(DummyBlindedHop(10 msat, 25, CltvExpiryDelta(24)), DummyBlindedHop(5 msat, 10, CltvExpiryDelta(36))))) + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) } + test("send blinded payment a->b with dummy hops, hidden fees") { f => + import f._ + + val amount = 250_000_000 msat + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), maxFinalExpiryDelta)) + val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes, hideFees = true) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.length == 1) + assert(payment.parts.forall(_.feesPaid == 0.msat)) + } + test("send fully blinded payment b->c") { f => import f._ val amount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, bob, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -342,8 +455,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val channelId_bc_1 = openChannel(bob, carol, 200_000 sat).channelId waitForChannelCreatedBC(f, channelId_bc_1) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(50_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Carol creates a blinded path using that channel. - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) // We make a payment to ensure that the channel contains less than 150 000 sat on Bob's side. assert(sendPayment(bob, carol, 50_000_000 msat).isRight) @@ -363,7 +480,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta, Seq(DummyBlindedHop(25 msat, 250, CltvExpiryDelta(75))))) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 25 msat, 250, CltvExpiryDelta(75)), maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, bob, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -383,9 +505,13 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // We wait for Carol to receive information about the channel between Alice and Bob. waitForAllChannelUpdates(f, channelsCount = 2) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(alice.nodeId, bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Carol receives a first payment through those channels. { - val routes = Seq(ReceivingRoute(Seq(alice.nodeId, bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val amount1 = 100_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount1, routes) val payment = verifyPaymentSuccess(offer, amount1, result) @@ -401,7 +527,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // Carol receives a second payment that requires using MPP. { - val routes = Seq(ReceivingRoute(Seq(alice.nodeId, bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val amount2 = 200_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount2, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount2, result) @@ -425,8 +551,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // We wait for Carol to receive information about the channel between Alice and Bob. waitForAllChannelUpdates(f, channelsCount = 3) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(alice.nodeId, bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Carol receives a payment that requires using MPP. - val routes = Seq(ReceivingRoute(Seq(alice.nodeId, bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val amount = 300_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) @@ -449,8 +579,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // We wait for Carol to receive information about the channel between Alice and Bob. waitForAllChannelUpdates(f, channelsCount = 3) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(alice.nodeId, bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Carol receives a payment that requires using MPP. - val routes = Seq(ReceivingRoute(Seq(alice.nodeId, bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val amount = 200_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) @@ -470,6 +604,9 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val sender = TestProbe("sender") + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(75_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + // Bob sends payments to Carol to reduce the liquidity on both of his channels. Seq(1, 2).foreach(_ => { sender.send(bob.paymentInitiator, SendSpontaneousPayment(50_000_000 msat, carol.nodeId, randomBytes32(), 1, routeParams = bob.routeParams)) @@ -477,7 +614,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { sender.expectMsgType[PaymentSent] }) // Bob now doesn't have enough funds to relay the payment. - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) val (_, result) = sendOfferPayment(f, alice, carol, 75_000_000 msat, routes) verifyBlindedFailure(result, bob.nodeId) } @@ -485,7 +622,11 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { test("send blinded payment a->b->c using expired route") { f => import f._ - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), CltvExpiryDelta(-500))) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(25_000_000 msat, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, CltvExpiryDelta(-500))) val (_, result) = sendOfferPayment(f, alice, carol, 25_000_000 msat, routes) verifyBlindedFailure(result, bob.nodeId) } @@ -495,7 +636,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 20_000_000 msat val recipientAmount = 25_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(recipientAmount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) // The amount is below what Carol expects. val payment = sendOfferPaymentWithInvalidAmount(f, alice, carol, payerAmount, recipientAmount, routes) verifyBlindedFailure(payment, bob.nodeId) @@ -506,7 +652,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 25_000_000 msat val recipientAmount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Nil, maxFinalExpiryDelta)) // The amount is below what Bob expects: since he is both the introduction node and the final recipient, he sends // back a normal error. val payment = sendOfferPaymentWithInvalidAmount(f, alice, bob, payerAmount, recipientAmount, routes) @@ -522,7 +668,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 25_000_000 msat val recipientAmount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId), maxFinalExpiryDelta, Seq(DummyBlindedHop(1 msat, 100, CltvExpiryDelta(48))))) + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 1 msat, 100, CltvExpiryDelta(48))), maxFinalExpiryDelta)) // The amount is below what Bob expects: since he is both the introduction node and the final recipient, he sends // back a normal error. val payment = sendOfferPaymentWithInvalidAmount(f, alice, bob, payerAmount, recipientAmount, routes) @@ -538,7 +684,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 45_000_000 msat val recipientAmount = 50_000_000 msat - val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) + + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(recipientAmount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) // The amount is below what Carol expects. val payment = sendOfferPaymentWithInvalidAmount(f, bob, carol, payerAmount, recipientAmount, routes) assert(payment.failures.head.isInstanceOf[PaymentFailure]) @@ -559,8 +710,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val compactOffer = Offer(None, Some("test"), recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(compactBlindedRoute)))) assert(compactOffer.toString.length < offer.toString.length) - val receivingRoute = ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta) - val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute))) + val sender = TestProbe() + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + + val receivingRoute = InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta) + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute), hideFees = false)) carol.offerManager ! OfferManager.RegisterOffer(compactOffer, Some(recipientKey), Some(pathId), handler) val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.router, alice.register, alice.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) @@ -576,9 +731,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val pathId = randomBytes32() val amount = 25_000_000 msat + carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) + val route = sender.expectMsgType[Router.RouteResponse].routes.head + val offerPaths = Seq(OnionMessages.buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId)), Recipient(carol.nodeId, Some(pathId))).route) val offer = Offer.withPaths(None, Some("implicit node id"), offerPaths, Features.empty, carol.nodeParams.chainHash) - val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)))) + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)), hideFees = false)) carol.offerManager ! OfferManager.RegisterOffer(offer, None, Some(pathId), handler) val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.router, alice.register, alice.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index ac9df1e86f..a56b9760c0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -16,7 +16,6 @@ package fr.acinq.eclair.payment -import akka.actor.Status import akka.actor.typed.scaladsl.adapter.actorRefAdapter import akka.testkit.{TestActorRef, TestProbe} import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto} @@ -31,8 +30,10 @@ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.payment.receive.MultiPartHandler._ import fr.acinq.eclair.payment.receive.MultiPartPaymentFSM.HtlcPart import fr.acinq.eclair.payment.receive.{MultiPartPaymentFSM, PaymentHandler} +import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.BlindedRouteCreation.aggregatePaymentInfo import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.{PaymentRouteNotFound, RouteResponse} +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer, PaymentInfo} import fr.acinq.eclair.wire.protocol.OnionPaymentPayloadTlv._ import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload @@ -82,7 +83,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike lazy val handlerWithKeySend = TestActorRef[PaymentHandler](PaymentHandler.props(nodeParams.copy(features = featuresWithKeySend), register.ref, offerManager.ref)) lazy val handlerWithRouteBlinding = TestActorRef[PaymentHandler](PaymentHandler.props(nodeParams.copy(features = featuresWithRouteBlinding), register.ref, offerManager.ref)) - def createEmptyReceivingRoute(): Seq[ReceivingRoute] = Seq(ReceivingRoute(Seq(nodeParams.nodeId), CltvExpiryDelta(144))) + def createEmptyReceivingRoute(pathId: ByteVector): Seq[ReceivingRoute] = Seq(ReceivingRoute(Nil, pathId, CltvExpiryDelta(144), PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 0 msat, 1_000_000_000 msat, Features.empty))) } override def withFixture(test: OneArgTest): Outcome = { @@ -120,10 +121,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -136,10 +137,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(add.amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(70_000 msat, add.amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(add.amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(70_000 msat, add.amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -154,10 +155,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -168,7 +169,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val preimage = randomBytes32() val pathId = randomBytes32() val router = TestProbe() - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, createEmptyReceivingRoute(), router.ref, preimage, pathId)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) router.expectNoMessage(50 millis) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice // Offer invoices shouldn't be stored in the DB until we receive a payment for it. @@ -180,14 +181,14 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == finalPacket.add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(finalPacket.add.paymentHash, PartialPayment(amountMsat, finalPacket.add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(finalPacket.add.paymentHash, PartialPayment(amountMsat, amountMsat, finalPacket.add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -266,27 +267,23 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val privKey = randomKey() - val offer = Offer(Some(25_000 msat), Some("a blinded coffee please"), privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) - val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() + val amount = 25_000 msat + val offer = Offer(Some(amount), Some("a blinded coffee please"), privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) + val invoiceReq = InvoiceRequest(offer, amount, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) val (a, b, c, d) = (randomKey().publicKey, randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) - val hop_ab = Router.ChannelHop(ShortChannelId(1), a, b, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(a, b, ShortChannelId(1), 1000 msat, 0, CltvExpiryDelta(100), 1 msat, None))) - val hop_bd = Router.ChannelHop(ShortChannelId(2), b, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(b, d, ShortChannelId(2), 800 msat, 0, CltvExpiryDelta(50), 1 msat, None))) - val hop_cd = Router.ChannelHop(ShortChannelId(3), c, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(c, d, ShortChannelId(3), 0 msat, 0, CltvExpiryDelta(75), 1 msat, None))) + val hop_ab = ChannelHop(ShortChannelId(1), a, b, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(a, b, ShortChannelId(1), 1000 msat, 0, CltvExpiryDelta(100), 1 msat, None))) + val hop_bd = ChannelHop(ShortChannelId(2), b, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(b, d, ShortChannelId(2), 800 msat, 0, CltvExpiryDelta(50), 1 msat, None))) + val hop_cd = ChannelHop(ShortChannelId(3), c, d, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(c, d, ShortChannelId(3), 0 msat, 0, CltvExpiryDelta(75), 1 msat, None))) + val hops1 = Seq(hop_ab, hop_bd, ChannelHop.dummy(d, 150 msat, 0, CltvExpiryDelta(25))) + val hops2 = Seq(hop_cd, ChannelHop.dummy(d, 250 msat, 0, CltvExpiryDelta(10)), ChannelHop.dummy(d, 150 msat, 0, CltvExpiryDelta(80))) val receivingRoutes = Seq( - ReceivingRoute(Seq(a, b, d), CltvExpiryDelta(100), Seq(DummyBlindedHop(150 msat, 0, CltvExpiryDelta(25)))), - ReceivingRoute(Seq(c, d), CltvExpiryDelta(50), Seq(DummyBlindedHop(250 msat, 0, CltvExpiryDelta(10)), DummyBlindedHop(150 msat, 0, CltvExpiryDelta(80)))), - ReceivingRoute(Seq(d), CltvExpiryDelta(250)), + ReceivingRoute(hops1, randomBytes32(), CltvExpiryDelta(100), aggregatePaymentInfo(amount, hops1, nodeParams.channelConf.minFinalExpiryDelta)), + ReceivingRoute(hops2, randomBytes32(), CltvExpiryDelta(50), aggregatePaymentInfo(amount, hops2, nodeParams.channelConf.minFinalExpiryDelta)), + ReceivingRoute(Nil, randomBytes32(), CltvExpiryDelta(250), PaymentInfo(0 msat, 0, nodeParams.channelConf.minFinalExpiryDelta, 0 msat, amount, Features.empty)), ) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, receivingRoutes, router.ref, randomBytes32(), randomBytes32())) - val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(25_000 msat, Seq(a, b, d))) - finalizeRoute1.replyTo ! RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_ab, hop_bd), None))) - val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute2.route == Router.PredefinedNodeRoute(25_000 msat, Seq(c, d))) - finalizeRoute2.replyTo ! RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_cd), None))) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, receivingRoutes, randomBytes32())) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice - assert(invoice.amount == 25_000.msat) + assert(invoice.amount == amount) assert(invoice.nodeId == privKey.publicKey) assert(invoice.blindedPaths.nonEmpty) assert(invoice.description.contains("a blinded coffee please")) @@ -307,29 +304,6 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(invoice.blindedPaths.flatMap(_.route.encryptedPayloads.dropRight(1)).map(_.length).toSet.size == 1) } - test("Invoice generation with route blinding should fail when router returns an error") { f => - import f._ - - val privKey = randomKey() - val offer = Offer(Some(25_000 msat), Some("a blinded coffee please"), privKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) - val invoiceReq = InvoiceRequest(offer, 25_000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) - val router = TestProbe() - val (a, b, c) = (randomKey().publicKey, randomKey().publicKey, nodeParams.nodeId) - val hop_ac = Router.ChannelHop(ShortChannelId(1), a, c, Router.HopRelayParams.FromHint(Invoice.ExtraEdge(a, c, ShortChannelId(1), 100 msat, 0, CltvExpiryDelta(50), 1 msat, None))) - val receivingRoutes = Seq( - ReceivingRoute(Seq(a, c), CltvExpiryDelta(100)), - ReceivingRoute(Seq(b, c), CltvExpiryDelta(100)), - ) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, receivingRoutes, router.ref, randomBytes32(), randomBytes32())) - val finalizeRoute1 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute1.route == Router.PredefinedNodeRoute(25_000 msat, Seq(a, c))) - finalizeRoute1.replyTo ! RouteResponse(Seq(Router.Route(25_000 msat, Seq(hop_ac), None))) - val finalizeRoute2 = router.expectMsgType[Router.FinalizeRoute] - assert(finalizeRoute2.route == Router.PredefinedNodeRoute(25_000 msat, Seq(b, c))) - finalizeRoute2.replyTo ! PaymentRouteNotFound(new IllegalArgumentException("invalid route")) - sender.expectMsgType[CreateInvoiceActor.BlindedRouteCreationFailed] - } - test("Generated invoice contains the provided extra hops") { f => import f._ @@ -487,7 +461,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val nodeKey = randomKey() val offer = Offer(None, Some("a blinded coffee please"), nodeKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref, randomBytes32(), randomBytes32())) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(randomBytes32()), randomBytes32())) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice val add = UpdateAddHtlc(ByteVector32.One, 0, 5000 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None, 1.0, None) @@ -505,7 +479,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) val preimage = randomBytes32() val pathId = randomBytes32() - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref, preimage, pathId)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) @@ -515,7 +489,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status.isInstanceOf[IncomingPaymentStatus.Received]) } @@ -528,7 +502,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val pathId = randomBytes(128) val offer = Offer(None, Some("a blinded coffee please"), nodeKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref, preimage, pathId)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice val packet = createBlindedPacket(5000 msat, invoice.paymentHash, defaultExpiry, CltvExpiry(nodeParams.currentBlockHeight), pathId) @@ -548,7 +522,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) val preimage = randomBytes32() val pathId = randomBytes32() - sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(), TestProbe().ref, preimage, pathId)) + sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice // We test the case where the HTLC's cltv_expiry is lower than expected and doesn't meet the min_final_expiry_delta. @@ -558,7 +532,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(5000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) @@ -596,7 +570,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike }) // Extraneous HTLCs should be failed. - f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(pr1.paymentHash, HtlcPart(1000 msat, UpdateAddHtlc(ByteVector32.One, 42, 200 msat, pr1.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), Some(PaymentTimeout()))) + f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(pr1.paymentHash, HtlcPart(1000 msat, 1000 msat, UpdateAddHtlc(ByteVector32.One, 42, 200 msat, pr1.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), Some(PaymentTimeout()))) f.register.expectMsg(Register.Forward(null, ByteVector32.One, CMD_FAIL_HTLC(42, FailureReason.LocalFailure(PaymentTimeout()), commit = true))) // The payment should still be pending in DB. @@ -627,21 +601,25 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ) val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(800 msat, ByteVector32.One, 0 unixms), PartialPayment(200 msat, ByteVector32.Zeroes, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(800 msat, 800 msat, ByteVector32.One, 0 unixms), PartialPayment(200 msat, 200 msat, ByteVector32.Zeroes, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1000.msat) awaitCond({ f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty }) // Extraneous HTLCs should be fulfilled. - f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(invoice.paymentHash, HtlcPart(1000 msat, UpdateAddHtlc(ByteVector32.One, 44, 200 msat, invoice.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), None)) + f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(invoice.paymentHash, HtlcPart(1000 msat, 200 msat, UpdateAddHtlc(ByteVector32.One, 44, 200 msat, invoice.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), None)) f.register.expectMsg(Register.Forward(null, ByteVector32.One, CMD_FULFILL_HTLC(44, preimage, commit = true))) - assert(f.eventListener.expectMsgType[PaymentReceived].amount == 200.msat) + val paymentReceived2 = f.eventListener.expectMsgType[PaymentReceived] + assert(paymentReceived2.virtualAmount == 200.msat) + assert(paymentReceived2.realAmount == 200.msat) val received2 = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) - assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1200.msat) + assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1200.msat) + assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1200.msat) f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty @@ -666,10 +644,11 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ) val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(1100 msat, add1.channelId, 0 unixms), PartialPayment(500 msat, add2.channelId, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(1100 msat, 1100 msat, add1.channelId, 0 unixms), PartialPayment(500 msat, 500 msat, add2.channelId, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1600.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1600.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1600.msat) } test("PaymentHandler should handle multi-part payment timeout then success") { f => @@ -703,10 +682,11 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] assert(paymentReceived.paymentHash == invoice.paymentHash) - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(300 msat, ByteVector32.One, 0 unixms), PartialPayment(700 msat, ByteVector32.Zeroes, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(300 msat, 300 msat, ByteVector32.One, 0 unixms), PartialPayment(700 msat, 700 msat, ByteVector32.Zeroes, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1000.msat) awaitCond({ f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty @@ -729,10 +709,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) } test("PaymentHandler should handle single-part KeySend payment without payment secret") { f => @@ -750,10 +730,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) } test("PaymentHandler should reject KeySend payment when feature is disabled") { f => @@ -812,7 +792,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None, 1.0, None) val invoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, None, paymentHash, randomKey(), Left("dummy"), CltvExpiryDelta(12)) val incomingPayment = IncomingStandardPayment(invoice, paymentPreimage, PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) - val fulfill = DoFulfill(incomingPayment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add)))) + val fulfill = DoFulfill(incomingPayment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, 1000 msat, add)))) sender.send(handlerWithoutMpp, fulfill) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.id == add.id) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala index dcdd2ef225..e00e142402 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala @@ -233,7 +233,7 @@ object MultiPartPaymentFSMSpec { def createMultiPartHtlc(totalAmount: MilliSatoshi, htlcAmount: MilliSatoshi, htlcId: Long): HtlcPart = { val htlc = UpdateAddHtlc(htlcIdToChannelId(htlcId), htlcId, htlcAmount, paymentHash, CltvExpiry(42), TestConstants.emptyOnionPacket, None, 1.0, None) - HtlcPart(totalAmount, htlc) + HtlcPart(totalAmount, htlcAmount, htlc) } } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala index f4a7375d98..4d791ac85b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentLifecycleSpec.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.payment -import akka.actor.{ActorContext, ActorRef, Status, typed} +import akka.actor.{ActorContext, ActorRef} import akka.testkit.{TestFSMRef, TestProbe} import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, SatoshiLong} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala index 1ef679f196..79011ef984 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentInitiatorSpec.scala @@ -287,7 +287,7 @@ class PaymentInitiatorSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike def createBolt12Invoice(features: Features[Bolt12Feature], payerKey: PrivateKey): Bolt12Invoice = { val offer = Offer(None, Some("Bolt12 r0cks"), e, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, finalAmount, 1, features, payerKey, Block.RegtestGenesisBlock.hash) - val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Nil, e, hex"2a2a2a2a", 1 msat, CltvExpiry(500_000)).route val paymentInfo = OfferTypes.PaymentInfo(1_000 msat, 0, CltvExpiryDelta(24), 0 msat, finalAmount, Features.empty) Bolt12Invoice(invoiceRequest, paymentPreimage, priv_e.privateKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala index 377c7350eb..7bb38b30ed 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PaymentPacketSpec.scala @@ -215,7 +215,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val features = Features[Bolt12Feature](BasicMultiPartPayment -> Optional) val offer = Offer(None, Some("Bolt12 r0cks"), recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) - val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Nil, c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route val paymentInfo = PaymentInfo(0 msat, 0, CltvExpiryDelta(0), 1 msat, amount_bc, Features.empty) val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, recipientKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) val resolvedPaths = invoice.blindedPaths.map(path => { @@ -484,7 +484,7 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { val offer = Offer(None, Some("Bolt12 r0cks"), c, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, amount_bc, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) // We send the wrong blinded payload to the introduction node. - val tmpBlindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Seq(channelHopFromUpdate(b, c, channelUpdate_bc)), hex"deadbeef", 1 msat, CltvExpiry(500_000)).route + val tmpBlindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Seq(channelHopFromUpdate(b, c, channelUpdate_bc)), c, hex"deadbeef", 1 msat, CltvExpiry(500_000)).route val blindedRoute = tmpBlindedRoute.copy(blindedHops = tmpBlindedRoute.blindedHops.reverse) val paymentInfo = OfferTypes.PaymentInfo(fee_b, 0, channelUpdate_bc.cltvExpiryDelta, 0 msat, amount_bc, Features.empty) val invoice = Bolt12Invoice(invoiceRequest, paymentPreimage, priv_c.privateKey, 300 seconds, features, Seq(PaymentBlindedRoute(blindedRoute, paymentInfo))) @@ -531,26 +531,6 @@ class PaymentPacketSpec extends AnyFunSuite with BeforeAndAfterAll { assert(failure == FinalIncorrectCltvExpiry(payment.cmd.cltvExpiry - CltvExpiryDelta(12))) } - test("fail to decrypt blinded payment at the final node when amount is too low") { - val (route, recipient) = shortBlindedHops() - val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, route, recipient, 1.0) - assert(payment.outgoingChannel == channelUpdate_cd.shortChannelId) - assert(payment.cmd.amount == amount_cd) - - // A smaller amount is sent to d, who doesn't know that it's invalid. - val add_d = UpdateAddHtlc(randomBytes32(), 0, amount_de, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextPathKey_opt, 1.0, payment.cmd.fundingFee_opt) - val Right(relay_d@ChannelRelayPacket(_, payload_d, packet_e)) = decrypt(add_d, priv_d.privateKey, Features(RouteBlinding -> Optional)) - assert(payload_d.outgoing.contains(channelUpdate_de.shortChannelId)) - assert(relay_d.amountToForward < amount_de) - assert(payload_d.isInstanceOf[IntermediatePayload.ChannelRelay.Blinded]) - val pathKey_e = payload_d.asInstanceOf[IntermediatePayload.ChannelRelay.Blinded].nextPathKey - - // When e receives a smaller amount than expected, it rejects the payment. - val add_e = UpdateAddHtlc(randomBytes32(), 0, relay_d.amountToForward, paymentHash, relay_d.outgoingCltv, packet_e, Some(pathKey_e), 1.0, None) - val Left(failure) = decrypt(add_e, priv_e.privateKey, Features(RouteBlinding -> Optional)) - assert(failure.isInstanceOf[InvalidOnionBlinding]) - } - test("fail to decrypt blinded payment at the final node when expiry is too low") { val (route, recipient) = shortBlindedHops() val Right(payment) = buildOutgoingPayment(TestConstants.emptyOrigin, paymentHash, route, recipient, 1.0) @@ -786,7 +766,7 @@ object PaymentPacketSpec { val amount_ab = amount_bc + fee_b def buildOutgoingBlindedPaymentAB(paymentHash: ByteVector32, routeExpiry: CltvExpiry = CltvExpiry(500_000)): Either[OutgoingPaymentError, OutgoingPaymentPacket] = { - val blindedRoute = BlindedRouteCreation.createBlindedRouteWithoutHops(b, hex"deadbeef", 1.msat, routeExpiry).route + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(Nil, b, hex"deadbeef", 1.msat, routeExpiry).route val finalPayload = NodePayload(blindedRoute.firstNode.blindedPublicKey, OutgoingBlindedPerHopPayload.createFinalPayload(finalAmount, finalAmount, finalExpiry, blindedRoute.firstNode.encryptedPayload)) val onion = buildOnion(Seq(finalPayload), paymentHash, Some(PaymentOnionCodecs.paymentOnionPayloadLength)).toOption.get // BOLT 2 requires that associatedData == paymentHash val cmd = CMD_ADD_HTLC(ActorRef.noSender, finalAmount, paymentHash, finalExpiry, onion.packet, Some(blindedRoute.firstPathKey), 1.0, None, TestConstants.emptyOrigin, commit = true) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 8133fafc34..301ad9f966 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -197,7 +197,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit val paymentHash = Crypto.sha256(preimage) val invoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(500 msat), paymentHash, TestConstants.Bob.nodeKeyManager.nodeKey.privateKey, Left("Some invoice"), CltvExpiryDelta(18)) nodeParams.db.payments.addIncomingPayment(invoice, preimage) - nodeParams.db.payments.receiveIncomingPayment(paymentHash, 5000 msat) + nodeParams.db.payments.receiveIncomingPayment(paymentHash, 5000 msat, 5000 msat) val htlc_ab_1 = Seq( buildFinalHtlc(0, channelId_ab_1, randomBytes32()), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala index bfc9008bca..f6bc3b256d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala @@ -29,10 +29,12 @@ import fr.acinq.eclair.payment.offer.OfferManager._ import fr.acinq.eclair.payment.receive.MultiPartHandler import fr.acinq.eclair.payment.receive.MultiPartHandler.GetIncomingPaymentActor.{ProcessPayment, RejectPayment} import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivingRoute +import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, amountAfterFee, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} @@ -49,7 +51,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val nodeParams = TestConstants.Alice.nodeParams val router = akka.testkit.TestProbe()(system.toClassic) val paymentTimeout = if (test.tags.contains(ShortPaymentTimeout)) 100 millis else 5 seconds - val offerManager = testKit.spawn(OfferManager(nodeParams, router.ref, paymentTimeout)) + val offerManager = testKit.spawn(OfferManager(nodeParams, paymentTimeout)) val postman = TestProbe[Postman.Command]() val paymentHandler = TestProbe[MultiPartHandler.GetIncomingPaymentActor.Command]() try { @@ -69,13 +71,13 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app offerManager ! RequestInvoice(messagePayload, offerKey, postman) } - def receiveInvoice(f: FixtureParam, amount: MilliSatoshi, payerKey: PrivateKey, pathNodeId: PublicKey, handler: TestProbe[HandlerCommand], pluginData_opt: Option[ByteVector] = None): Bolt12Invoice = { + def receiveInvoice(f: FixtureParam, amount: MilliSatoshi, payerKey: PrivateKey, pathNodeId: PublicKey, handler: TestProbe[HandlerCommand], pluginData_opt: Option[ByteVector] = None, hops: Seq[ChannelHop] = Nil, hideFees: Boolean = false): Bolt12Invoice = { import f._ val handleInvoiceRequest = handler.expectMessageType[HandleInvoiceRequest] assert(handleInvoiceRequest.invoiceRequest.isValid) assert(handleInvoiceRequest.invoiceRequest.payerId == payerKey.publicKey) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, Seq(ReceivingRoute(Seq(nodeParams.nodeId), CltvExpiryDelta(1000), Nil)), pluginData_opt) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, Seq(InvoiceRequestActor.Route(hops, CltvExpiryDelta(1000))), hideFees, pluginData_opt) val invoiceMessage = postman.expectMessageType[Postman.SendMessage] val Right(invoice) = Bolt12Invoice.validate(invoiceMessage.message.get[OnionMessagePayloadTlv.Invoice].get.tlvs) assert(invoice.validateFor(handleInvoiceRequest.invoiceRequest, pathNodeId).isRight) @@ -89,7 +91,14 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(invoice.blindedPaths.length == 1) val blindedPath = invoice.blindedPaths.head.route - val Right(RouteBlindingDecryptedData(encryptedDataTlvs, _)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, blindedPath.firstPathKey, blindedPath.encryptedPayloads.head) + val Right(RouteBlindingDecryptedData(tlvs, nextPathKey)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, blindedPath.firstPathKey, blindedPath.encryptedPayloads.head) + var encryptedDataTlvs = tlvs + var pathKey = nextPathKey + for (encryptedPayload <- blindedPath.encryptedPayloads.drop(1)) { + val Right(RouteBlindingDecryptedData(tlvs, nextPathKey)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, pathKey, encryptedPayload) + encryptedDataTlvs = tlvs + pathKey = nextPathKey + } val paymentTlvs = TlvStream[OnionPaymentPayloadTlv]( OnionPaymentPayloadTlv.AmountToForward(invoice.amount), OnionPaymentPayloadTlv.TotalAmount(invoice.amount), @@ -111,15 +120,16 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler, pluginData_opt = Some(hex"deadbeef")) // Pay invoice. val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount) val handlePayment = handler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) assert(handlePayment.pluginData_opt.contains(hex"deadbeef")) handlePayment.replyTo ! PaymentActor.AcceptPayment() - val ProcessPayment(incomingPayment) = paymentHandler.expectMessageType[ProcessPayment] + val ProcessPayment(incomingPayment, hiddenRelayFees) = paymentHandler.expectMessageType[ProcessPayment] assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) + assert(hiddenRelayFees == RelayFees.zero) } test("pay offer without path_id") { f => @@ -205,7 +215,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invoice2 = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler) // Try paying invoice #1 with data from invoice #2. val paymentPayload = createPaymentPayload(f, invoice2) - offerManager ! ReceivePayment(paymentHandler.ref, invoice1.paymentHash, paymentPayload) + offerManager ! ReceivePayment(paymentHandler.ref, invoice1.paymentHash, paymentPayload, amount) paymentHandler.expectMessageType[RejectPayment] handler.expectNoMessage(50 millis) } @@ -228,7 +238,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invalidPaymentPayload = paymentPayload.copy( blindedRecords = TlvStream(paymentPayload.blindedRecords.records.filterNot(_.isInstanceOf[RouteBlindingEncryptedDataTlv.PathId]) + RouteBlindingEncryptedDataTlv.PathId(invalidPathId)) ) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, invalidPaymentPayload) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, invalidPaymentPayload, amount) paymentHandler.expectMessageType[RejectPayment] handler.expectNoMessage(50 millis) } @@ -246,7 +256,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler) // Try paying the invoice, but the plugin handler doesn't respond. val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount) handler.expectMessageType[HandlePayment] assert(paymentHandler.expectMessageType[RejectPayment].reason == "plugin timeout") } @@ -264,10 +274,70 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler) // Try paying the invoice, but the plugin handler rejects the payment. val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount) val handlePayment = handler.expectMessageType[HandlePayment] handlePayment.replyTo ! PaymentActor.RejectPayment("internal error") assert(paymentHandler.expectMessageType[RejectPayment].reason == "internal error") } + test("invalid payment (incorrect amount)") { f => + import f._ + + val handler = TestProbe[HandlerCommand]() + val amount = 10_000_000 msat + val offer = Offer(Some(amount), Some("offer"), nodeParams.nodeId, Features.empty, nodeParams.chainHash) + offerManager ! RegisterOffer(offer, Some(nodeParams.privateKey), None, handler.ref) + // Request invoice. + val payerKey = randomKey() + requestInvoice(payerKey, offer, nodeParams.privateKey, amount, offerManager, postman.ref) + val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler) + // Try sending 1 msat less than needed + val paymentPayload = createPaymentPayload(f, invoice) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount - 1.msat) + paymentHandler.expectMessageType[RejectPayment] + handler.expectNoMessage(50 millis) + } + + test("pay offer with hidden fees") { f => + import f._ + + val handler = TestProbe[HandlerCommand]() + val amount = 10_000_000 msat + val offer = Offer(Some(amount), Some("offer"), nodeParams.nodeId, Features.empty, nodeParams.chainHash) + offerManager ! RegisterOffer(offer, Some(nodeParams.privateKey), None, handler.ref) + // Request invoice. + val payerKey = randomKey() + requestInvoice(payerKey, offer, nodeParams.privateKey, amount, offerManager, postman.ref) + val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler, hops = List(ChannelHop.dummy(nodeParams.nodeId, 1000 msat, 200, CltvExpiryDelta(144))), hideFees = true) + // Sending less than the full amount as fees are paid by the recipient + val paymentPayload = createPaymentPayload(f, invoice) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amountAfterFee(1000 msat, 200, amount)) + + val handlePayment = handler.expectMessageType[HandlePayment] + assert(handlePayment.offerId == offer.offerId) + handlePayment.replyTo ! PaymentActor.AcceptPayment() + val ProcessPayment(incomingPayment, hiddenRelayFees) = paymentHandler.expectMessageType[ProcessPayment] + assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) + assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) + assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) + assert(hiddenRelayFees == RelayFees(1000 msat, 200)) + } + + test("invalid payment (incorrect amount with hidden fee)") { f => + import f._ + + val handler = TestProbe[HandlerCommand]() + val amount = 10_000_000 msat + val offer = Offer(Some(amount), Some("offer"), nodeParams.nodeId, Features.empty, nodeParams.chainHash) + offerManager ! RegisterOffer(offer, Some(nodeParams.privateKey), None, handler.ref) + // Request invoice. + val payerKey = randomKey() + requestInvoice(payerKey, offer, nodeParams.privateKey, amount, offerManager, postman.ref) + val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler, hops = List(ChannelHop.dummy(nodeParams.nodeId, 1000 msat, 200, CltvExpiryDelta(144))), hideFees = true) + // Try sending 1 msat less than needed + val paymentPayload = createPaymentPayload(f, invoice) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amountAfterFee(1000 msat, 200, amount) - 1.msat) + paymentHandler.expectMessageType[RejectPayment] + handler.expectNoMessage(50 millis) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala index 866cca0210..ba0b97c3e3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala @@ -52,11 +52,11 @@ class InvoicePurgerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("ap // create paid invoices val receivedAt = TimestampMilli.now() + 1.milli val paidInvoices = Seq.fill(count)(Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(100 msat), randomBytes32(), alicePriv, Left("paid invoice"), CltvExpiryDelta(18))) - val paidPayments = paidInvoices.map(invoice => IncomingStandardPayment(invoice, randomBytes32(), PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(100 msat, receivedAt))) + val paidPayments = paidInvoices.map(invoice => IncomingStandardPayment(invoice, randomBytes32(), PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(100 msat, 100 msat, receivedAt))) paidPayments.foreach(payment => { db.addIncomingPayment(payment.invoice, payment.paymentPreimage) // receive payment - db.receiveIncomingPayment(payment.invoice.paymentHash, 100 msat, receivedAt) + db.receiveIncomingPayment(payment.invoice.paymentHash, 100 msat, 100 msat, receivedAt) }) val now = TimestampMilli.now() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala index 9603b81af3..ab2a0ac59d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/BlindedPathsResolverSpec.scala @@ -114,7 +114,7 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l ExtraEdge(nextNodeId, randomKey().publicKey, RealShortChannelId(BlockHeight(700_000), 1, 0), 750_000 msat, 150, CltvExpiryDelta(48), 1 msat, None), ) val hops = edges.map(e => ChannelHop(e.shortChannelId, e.sourceNodeId, e.targetNodeId, HopRelayParams.FromHint(e))) - val route = BlindedRouteCreation.createBlindedRouteFromHops(hops, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route + val route = BlindedRouteCreation.createBlindedRouteFromHops(hops, hops.last.nextNodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(100_000_000 msat, hops, CltvExpiryDelta(12)) Seq(true, false).foreach { useScidDir => val toResolve = if (useScidDir) { @@ -183,7 +183,7 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l val scid = RealShortChannelId(BlockHeight(750_000), 3, 7) val edge = ExtraEdge(nodeParams.nodeId, randomKey().publicKey, scid, 600_000 msat, 100, CltvExpiryDelta(144), 1 msat, None) val hop = ChannelHop(edge.shortChannelId, edge.sourceNodeId, edge.targetNodeId, HopRelayParams.FromHint(edge)) - val route = BlindedRouteCreation.createBlindedRouteFromHops(Seq(hop), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route + val route = BlindedRouteCreation.createBlindedRouteFromHops(Seq(hop), edge.targetNodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(50_000_000 msat, Seq(hop), CltvExpiryDelta(12)) val toResolve = Seq( PaymentBlindedRoute(route.copy(firstNodeId = EncodedNodeId.ShortChannelIdDir(isNode1 = true, scid)), paymentInfo), @@ -211,15 +211,15 @@ class BlindedPathsResolverSpec extends ScalaTestWithActorTestKit(ConfigFactory.l val edgeLowExpiryDelta = ExtraEdge(nodeParams.nodeId, nextNodeId, scid, 600_000 msat, 100, CltvExpiryDelta(36), 1 msat, None) val toResolve = Seq( // We don't allow paying blinded routes to ourselves. - BlindedRouteCreation.createBlindedRouteWithoutHops(nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + BlindedRouteCreation.createBlindedRouteFromHops(Nil, nodeParams.nodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low fees. - BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), edgeLowFees.targetNodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low cltv_expiry_delta. - BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, + BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowExpiryDelta.targetNodeId, HopRelayParams.FromHint(edgeLowExpiryDelta))), edgeLowExpiryDelta.targetNodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes with low fees, even when the next node seems to be a wallet node. BlindedRouteCreation.createBlindedRouteToWallet(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees)), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route, // We reject blinded routes that cannot be decrypted. - BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(firstPathKey = randomKey().publicKey) + BlindedRouteCreation.createBlindedRouteFromHops(Seq(ChannelHop(scid, nodeParams.nodeId, edgeLowFees.targetNodeId, HopRelayParams.FromHint(edgeLowFees))), edgeLowFees.targetNodeId, hex"deadbeef", 1 msat, CltvExpiry(800_000)).route.copy(firstPathKey = randomKey().publicKey) ).map(r => PaymentBlindedRoute(r, PaymentInfo(1_000_000 msat, 2500, CltvExpiryDelta(300), 1 msat, 500_000_000 msat, Features.empty))) resolver ! Resolve(probe.ref, toResolve) // The routes with low fees or expiry require resolving the next node. diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala index 9d21c490a4..f216c580be 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BaseRouterSpec.scala @@ -309,7 +309,7 @@ object BaseRouterSpec { val offer = Offer(None, Some("Bolt12 r0cks"), recipientKey.publicKey, features, Block.RegtestGenesisBlock.hash) val invoiceRequest = InvoiceRequest(offer, amount, 1, features, randomKey(), Block.RegtestGenesisBlock.hash) val blindedRoutes = paths.map(hops => { - val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(hops, pathId, 1 msat, routeExpiry).route + val blindedRoute = BlindedRouteCreation.createBlindedRouteFromHops(hops, hops.last.nextNodeId, pathId, 1 msat, routeExpiry).route val paymentInfo = BlindedRouteCreation.aggregatePaymentInfo(amount, hops, Channel.MIN_CLTV_EXPIRY_DELTA) PaymentBlindedRoute(blindedRoute, paymentInfo) }) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala index 55833d5feb..e0cef13f5a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BlindedRouteCreationSpec.scala @@ -30,7 +30,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { test("create blinded route without hops") { val a = randomKey() val pathId = randomBytes32() - val route = createBlindedRouteWithoutHops(a.publicKey, pathId, 1 msat, CltvExpiry(500)) + val route = createBlindedRouteFromHops(Nil, a.publicKey, pathId, 1 msat, CltvExpiry(500)) assert(route.route.firstNodeId == EncodedNodeId(a.publicKey)) assert(route.route.encryptedPayloads.length == 1) assert(route.route.firstPathKey == route.lastPathKey) @@ -47,7 +47,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { ChannelHop(scid1, a.publicKey, b.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid1, a.publicKey, b.publicKey, 10 msat, 300, cltvDelta = CltvExpiryDelta(200)))), ChannelHop(scid2, b.publicKey, c.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid2, b.publicKey, c.publicKey, 20 msat, 150, cltvDelta = CltvExpiryDelta(600)))), ) - val route = createBlindedRouteFromHops(hops, pathId, 1 msat, CltvExpiry(500)) + val route = createBlindedRouteFromHops(hops, c.publicKey, pathId, 1 msat, CltvExpiry(500)) assert(route.route.firstNodeId == EncodedNodeId(a.publicKey)) assert(route.route.encryptedPayloads.length == 3) val Right(decoded1) = RouteBlindingEncryptedDataCodecs.decode(a, route.route.firstPathKey, route.route.encryptedPayloads(0)) @@ -103,7 +103,7 @@ class BlindedRouteCreationSpec extends AnyFunSuite with ParallelTestExecution { ChannelHop(scid4, d.publicKey, e.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid4, d.publicKey, e.publicKey, 100000 msat, 100000, cltvDelta = CltvExpiryDelta(60000)))), ChannelHop(scid5, e.publicKey, f.publicKey, HopRelayParams.FromAnnouncement(makeUpdateShort(scid5, e.publicKey, f.publicKey, 999999999 msat, 999999999, cltvDelta = CltvExpiryDelta(65000)))), ) - val route = createBlindedRouteFromHops(hops, randomBytes32(), 0 msat, CltvExpiry(0)) + val route = createBlindedRouteFromHops(hops, f.publicKey, randomBytes32(), 0 msat, CltvExpiry(0)) assert(route.route.encryptedPayloads.dropRight(1).forall(_.length == 54)) } diff --git a/eclair-node/src/test/resources/api/received-success b/eclair-node/src/test/resources/api/received-success index c70299259c..66cebe4d02 100644 --- a/eclair-node/src/test/resources/api/received-success +++ b/eclair-node/src/test/resources/api/received-success @@ -1 +1 @@ -{"invoice":{"prefix":"lnbc","timestamp":1496314658,"nodeId":"03779dc8b593b74509fab7c8accebc7a9b91d85d9df456d5b885464a34e5751d52","serialized":"lnbc2500u1pvjluezsp5cssgls5lpvunj7zallxsn3v8g3f9wqfs75hsdmkrtxwgkafers0spp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4jsxqzpuaztrnwngzn3kdzw5hydlzf03qdgm2hdq27cqv3agm2awhz5se903vruatfhq77w3ls4evs3ch9zw97j25emudupq63nyw24cg27h2rspzma2k0","description":"1 cup coffee","paymentHash":"0001020304050607080900010203040506070809000102030405060708090102","expiry":60,"amount":250000000,"features":{"activated":{},"unknown":[]},"routingInfo":[]},"paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","paymentType":"Standard","createdAt":{"iso":"1970-01-01T00:00:00.042Z","unix":0},"status":{"type":"received","amount":42,"receivedAt":{"iso":"2021-10-05T13:12:23.777Z","unix":1633439543}}} \ No newline at end of file +{"invoice":{"prefix":"lnbc","timestamp":1496314658,"nodeId":"03779dc8b593b74509fab7c8accebc7a9b91d85d9df456d5b885464a34e5751d52","serialized":"lnbc2500u1pvjluezsp5cssgls5lpvunj7zallxsn3v8g3f9wqfs75hsdmkrtxwgkafers0spp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4jsxqzpuaztrnwngzn3kdzw5hydlzf03qdgm2hdq27cqv3agm2awhz5se903vruatfhq77w3ls4evs3ch9zw97j25emudupq63nyw24cg27h2rspzma2k0","description":"1 cup coffee","paymentHash":"0001020304050607080900010203040506070809000102030405060708090102","expiry":60,"amount":250000000,"features":{"activated":{},"unknown":[]},"routingInfo":[]},"paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","paymentType":"Standard","createdAt":{"iso":"1970-01-01T00:00:00.042Z","unix":0},"status":{"type":"received","virtualAmount":42,"realAmount":42,"receivedAt":{"iso":"2021-10-05T13:12:23.777Z","unix":1633439543}}} \ No newline at end of file diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index 348d4b7d07..d191e462a1 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -908,7 +908,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val defaultPayment = IncomingStandardPayment(Bolt11Invoice.fromString(invoice).get, ByteVector32.One, PaymentType.Standard, 42 unixms, IncomingPaymentStatus.Pending) val eclair = mock[Eclair] val received = randomBytes32() - eclair.receivedInfo(received)(any) returns Future.successful(Some(defaultPayment.copy(status = IncomingPaymentStatus.Received(42 msat, TimestampMilli(1633439543777L))))) + eclair.receivedInfo(received)(any) returns Future.successful(Some(defaultPayment.copy(status = IncomingPaymentStatus.Received(42 msat, 42 msat, TimestampMilli(1633439543777L))))) val mockService = new MockService(eclair) Post("/getreceivedinfo", FormData("paymentHash" -> received.toHex).toEntity) ~> @@ -1198,8 +1198,8 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM system.eventStream.publish(ptrel) wsClient.expectMessage(expectedSerializedPtrel) - val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, ByteVector32.Zeroes, TimestampMilli(1553784963659L)))) - val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"amount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":{"iso":"2019-03-28T14:56:03.659Z","unix":1553784963}}]}""" + val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, 21 msat, ByteVector32.Zeroes, TimestampMilli(1553784963659L)))) + val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"virtualAmount":21,"realAmount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":{"iso":"2019-03-28T14:56:03.659Z","unix":1553784963}}]}""" assert(serialization.write(precv) == expectedSerializedPrecv) system.eventStream.publish(precv) wsClient.expectMessage(expectedSerializedPrecv) From 223762e12f57f859d69117055d0a1bb494e54523 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 14 Feb 2025 09:31:01 +0100 Subject: [PATCH 2/6] Fuzz test --- eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala index 8929d2fc5e..8e7b1b5fb7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/PackageSpec.scala @@ -20,6 +20,7 @@ import fr.acinq.bitcoin.BitcoinError.ChainHashMismatch import fr.acinq.bitcoin.scalacompat.Crypto.PrivateKey import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, Crypto, Script, TxHash, TxId, addressToPublicKeyScript} import fr.acinq.bitcoin.{Base58, Base58Check, Bech32} +import org.scalatest.Tag import org.scalatest.funsuite.AnyFunSuite import scodec.bits._ @@ -114,14 +115,14 @@ class PackageSpec extends AnyFunSuite { assert(ShortChannelId(Long.MaxValue) < ShortChannelId(Long.MaxValue + 1)) } - test("node fees") { + test("node fees", Tag("fuzzy")) { val rng = new scala.util.Random() for (_ <- 1 to 100) { val amount = rng.nextLong(1_000_000_000_000L) msat val baseFee = rng.nextLong(10_000) msat val proportionalFee = rng.nextLong(5_000) val amountWithFees = amount + nodeFee(baseFee, proportionalFee, amount) - assert(amountAfterFee(baseFee, proportionalFee, amountWithFees) == amount) + assert(amountAfterFee(baseFee, proportionalFee, amountWithFees) == amount, s"amount=$amount baseFee=$baseFee proportionalFee=$proportionalFee") } } From 8b4920b46b8600b4525af39510f8274564f90a11 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Fri, 14 Feb 2025 16:12:03 +0100 Subject: [PATCH 3/6] Remove virtual amount --- .../fr/acinq/eclair/db/DbEventHandler.scala | 2 +- .../fr/acinq/eclair/db/DualDatabases.scala | 12 ++-- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 11 ++-- .../fr/acinq/eclair/db/pg/PgAuditDb.scala | 31 +++------ .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 59 +++++++---------- .../eclair/db/sqlite/SqliteAuditDb.scala | 31 +++------ .../eclair/db/sqlite/SqlitePaymentsDb.scala | 66 ++++++++----------- .../acinq/eclair/payment/PaymentEvents.scala | 5 +- .../eclair/payment/offer/OfferManager.scala | 2 +- .../payment/receive/MultiPartHandler.scala | 49 ++++++++------ .../payment/receive/MultiPartPaymentFSM.scala | 15 +++-- .../eclair/payment/relay/NodeRelay.scala | 8 +-- .../fr/acinq/eclair/db/AuditDbSpec.scala | 4 +- .../fr/acinq/eclair/db/PaymentsDbSpec.scala | 36 +++++----- .../integration/PaymentIntegrationSpec.scala | 22 +++---- .../eclair/payment/MultiPartHandlerSpec.scala | 58 ++++++++-------- .../payment/MultiPartPaymentFSMSpec.scala | 2 +- .../payment/PostRestartHtlcCleanerSpec.scala | 2 +- .../payment/offer/OfferManagerSpec.scala | 6 +- .../payment/receive/InvoicePurgerSpec.scala | 4 +- .../src/test/resources/api/received-success | 2 +- .../fr/acinq/eclair/api/ApiServiceSpec.scala | 6 +- 22 files changed, 191 insertions(+), 242 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala index 7873bdef3f..9356799054 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala @@ -68,7 +68,7 @@ class DbEventHandler(nodeParams: NodeParams) extends Actor with DiagnosticActorL PaymentMetrics.PaymentFailed.withTag(PaymentTags.Direction, PaymentTags.Directions.Sent).increment() case e: PaymentReceived => - PaymentMetrics.PaymentAmount.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.realAmount.truncateToSatoshi.toLong) + PaymentMetrics.PaymentAmount.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.amount.truncateToSatoshi.toLong) PaymentMetrics.PaymentParts.withTag(PaymentTags.Direction, PaymentTags.Directions.Received).record(e.parts.length) auditDb.add(e) e.parts.foreach(p => channelsDb.updateChannelMeta(p.fromChannelId, ChannelEvent.EventType.PaymentReceived)) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala index d50acd2f8e..7fb7e56eba 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DualDatabases.scala @@ -319,14 +319,14 @@ case class DualPaymentsDb(primary: PaymentsDb, secondary: PaymentsDb) extends Pa primary.addIncomingPayment(pr, preimage, paymentType) } - override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = { - runAsync(secondary.receiveIncomingPayment(paymentHash, virtualAmount, realAmount, receivedAt)) - primary.receiveIncomingPayment(paymentHash, virtualAmount, realAmount, receivedAt) + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = { + runAsync(secondary.receiveIncomingPayment(paymentHash, amount, receivedAt)) + primary.receiveIncomingPayment(paymentHash, amount, receivedAt) } - override def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = { - runAsync(secondary.receiveIncomingOfferPayment(pr, preimage, virtualAmount, realAmount, receivedAt, paymentType)) - primary.receiveIncomingOfferPayment(pr, preimage, virtualAmount, realAmount, receivedAt, paymentType) + override def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = { + runAsync(secondary.receiveIncomingOfferPayment(pr, preimage, amount, receivedAt, paymentType)) + primary.receiveIncomingOfferPayment(pr, preimage, amount, receivedAt, paymentType) } override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index a10af427ab..f5fcdfc3fa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -36,13 +36,13 @@ trait IncomingPaymentsDb { * Mark an incoming payment as received (paid). The received amount may exceed the invoice amount. * If there was no matching invoice in the DB, this will return false. */ - def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean + def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now()): Boolean /** * Add a new incoming offer payment as received. * If the invoice is already paid, adds `amount` to the amount paid. */ - def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now(), paymentType: String = PaymentType.Blinded): Unit + def receiveIncomingOfferPayment(pr: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli = TimestampMilli.now(), paymentType: String = PaymentType.Blinded): Unit /** Get information about the incoming payment (paid or not) for the given payment hash, if any. */ def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] @@ -150,11 +150,10 @@ object IncomingPaymentStatus { /** * Payment has been successfully received. * - * @param virtualAmount amount of the payment received, in milli-satoshis (may exceed the invoice amount). - * @param realAmount amount of the payment received, in milli-satoshis (may be less or more than the invoice amount). - * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. + * @param amount amount of the payment received, in milli-satoshis (may exceed the invoice amount). + * @param receivedAt absolute time in milli-seconds since UNIX epoch when the payment was received. */ - case class Received(virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, receivedAt: TimestampMilli) extends IncomingPaymentStatus + case class Received(amount: MilliSatoshi, receivedAt: TimestampMilli) extends IncomingPaymentStatus } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 49f9c79035..75cd8c6b83 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -36,7 +36,7 @@ import javax.sql.DataSource object PgAuditDb { val DB_NAME = "audit" - val CURRENT_VERSION = 13 + val CURRENT_VERSION = 12 } class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { @@ -114,20 +114,12 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON audit.transactions_published(channel_id)") } - def migration1213(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE audit.received RENAME TO received_old") - statement.executeUpdate("CREATE TABLE audit.received (virtual_amount_msat BIGINT NOT NULL, real_amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("INSERT INTO audit.received SELECT amount_msat, amount_msat, payment_hash, from_channel_id, timestamp FROM audit.received_old") - statement.executeUpdate("DROP TABLE audit.received_old") - statement.executeUpdate("CREATE INDEX received_timestamp_idx ON audit.received(timestamp)") - } - getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA audit") statement.executeUpdate("CREATE TABLE audit.sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") - statement.executeUpdate("CREATE TABLE audit.received (virtual_amount_msat BIGINT NOT NULL, real_amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE audit.received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE TABLE audit.channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") @@ -157,7 +149,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON audit.transactions_published(channel_id)") statement.executeUpdate("CREATE INDEX transactions_published_timestamp_idx ON audit.transactions_published(timestamp)") statement.executeUpdate("CREATE INDEX transactions_confirmed_timestamp_idx ON audit.transactions_confirmed(timestamp)") - case Some(v@(4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12)) => + case Some(v@(4 | 5 | 6 | 7 | 8 | 9 | 10 | 11)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 5) { migration45(statement) @@ -183,9 +175,6 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { if (v < 12) { migration1112(statement) } - if (v < 13) { - migration1213(statement) - } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -231,13 +220,12 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def add(e: PaymentReceived): Unit = withMetrics("audit/add-payment-received", DbBackends.Postgres) { inTransaction { pg => - using(pg.prepareStatement("INSERT INTO audit.received VALUES (?, ?, ?, ?, ?)")) { statement => + using(pg.prepareStatement("INSERT INTO audit.received VALUES (?, ?, ?, ?)")) { statement => e.parts.foreach(p => { - statement.setLong(1, p.virtualAmount.toLong) - statement.setLong(2, p.realAmount.toLong) - statement.setString(3, e.paymentHash.toHex) - statement.setString(4, p.fromChannelId.toHex) - statement.setTimestamp(5, p.timestamp.toSqlTimestamp) + statement.setLong(1, p.amount.toLong) + statement.setString(2, e.paymentHash.toHex) + statement.setString(3, p.fromChannelId.toHex) + statement.setTimestamp(4, p.timestamp.toSqlTimestamp) statement.addBatch() }) statement.executeBatch() @@ -416,8 +404,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => val paymentHash = rs.getByteVector32FromHex("payment_hash") val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("virtual_amount_msat")), - MilliSatoshi(rs.getLong("real_amount_msat")), + MilliSatoshi(rs.getLong("amount_msat")), rs.getByteVector32FromHex("from_channel_id"), TimestampMilli.fromSqlTimestamp(rs.getTimestamp("timestamp"))) val received = receivedByHash.get(paymentHash) match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index c16076b1d8..d78883a296 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -36,7 +36,7 @@ import scala.util.{Failure, Success, Try} object PgPaymentsDb { val DB_NAME = "payments" - val CURRENT_VERSION = 9 + val CURRENT_VERSION = 8 } class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb with Logging { @@ -77,19 +77,11 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON payments.sent(offer_id)") } - def migration89(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE payments.received RENAME TO received_old") - statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, virtual_received_msat BIGINT, real_received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") - statement.executeUpdate("INSERT INTO payments.received SELECT payment_hash, payment_type, payment_preimage, path_ids, payment_request, received_msat, received_msat, created_at, expire_at, received_at FROM payments.received_old") - statement.executeUpdate("DROP TABLE payments.received_old") - statement.executeUpdate("CREATE INDEX received_created_idx ON payments.received(created_at)") - } - getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE SCHEMA payments") - statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, virtual_received_msat BIGINT, real_received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE payments.received (payment_hash TEXT NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage TEXT NOT NULL, path_ids BYTEA, payment_request TEXT NOT NULL, received_msat BIGINT, created_at TIMESTAMP WITH TIME ZONE NOT NULL, expire_at TIMESTAMP WITH TIME ZONE NOT NULL, received_at TIMESTAMP WITH TIME ZONE)") statement.executeUpdate("CREATE TABLE payments.sent (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash TEXT NOT NULL, payment_preimage TEXT, payment_type TEXT NOT NULL, amount_msat BIGINT NOT NULL, fees_msat BIGINT, recipient_amount_msat BIGINT NOT NULL, recipient_node_id TEXT NOT NULL, payment_request TEXT, offer_id TEXT, payer_key TEXT, payment_route BYTEA, failures BYTEA, created_at TIMESTAMP WITH TIME ZONE NOT NULL, completed_at TIMESTAMP WITH TIME ZONE)") statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON payments.sent(parent_id)") @@ -97,7 +89,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON payments.sent(offer_id)") statement.executeUpdate("CREATE INDEX sent_created_idx ON payments.sent(created_at)") statement.executeUpdate("CREATE INDEX received_created_idx ON payments.received(created_at)") - case Some(v@(4 | 5 | 6 | 7 | 8)) => + case Some(v@(4 | 5 | 6 | 7)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 5) { migration45(statement) @@ -111,9 +103,6 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit if (v < 8) { migration78(statement) } - if (v < 9) { - migration89(statement) - } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -279,32 +268,30 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) { + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("UPDATE payments.received SET (virtual_received_msat, real_received_msat, received_at) = (? + COALESCE(virtual_received_msat, 0), ? + COALESCE(real_received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, virtualAmount.toLong) - update.setLong(2, realAmount.toLong) - update.setTimestamp(3, receivedAt.toSqlTimestamp) - update.setString(4, paymentHash.toHex) + using(pg.prepareStatement("UPDATE payments.received SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => + update.setLong(1, amount.toLong) + update.setTimestamp(2, receivedAt.toSqlTimestamp) + update.setString(3, paymentHash.toHex) val updated = update.executeUpdate() updated > 0 } } } - override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Postgres) { + override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO payments.received (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, virtual_received_msat, real_received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" + - "ON CONFLICT (payment_hash) DO UPDATE SET (virtual_received_msat, real_received_msat, received_at) = (payments.received.virtual_received_msat + EXCLUDED.virtual_received_msat, payments.received.real_received_msat + EXCLUDED.real_received_msat, EXCLUDED.received_at)")) { statement => + using(pg.prepareStatement("INSERT INTO payments.received (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)" + + "ON CONFLICT (payment_hash) DO UPDATE SET (received_msat, received_at) = (payments.received.received_msat + EXCLUDED.received_msat, EXCLUDED.received_at)")) { statement => statement.setString(1, invoice.paymentHash.toHex) statement.setString(2, preimage.toHex) statement.setString(3, paymentType) statement.setString(4, invoice.toString) statement.setTimestamp(5, invoice.createdAt.toSqlTimestamp) statement.setTimestamp(6, (invoice.createdAt + invoice.relativeExpiry.toSeconds).toSqlTimestamp) - statement.setLong(7, virtualAmount.toLong) - statement.setLong(8, realAmount.toLong) - statement.setTimestamp(9, receivedAt.toSqlTimestamp) + statement.setLong(7, amount.toLong) + statement.setTimestamp(8, receivedAt.toSqlTimestamp) statement.executeUpdate() } } @@ -317,10 +304,10 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit val createdAt = TimestampMilli.fromSqlTimestamp(rs.getTimestamp("created_at")) Invoice.fromString(invoice) match { case Success(invoice: Bolt11Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) Some(IncomingStandardPayment(invoice, preimage, paymentType, createdAt, status)) case Success(invoice: MinimalBolt12Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getTimestampNullable("received_at").map(TimestampMilli.fromSqlTimestamp)) Some(IncomingBlindedPayment(invoice, preimage, paymentType, createdAt, status)) case _ => logger.error(s"could not parse DB invoice=$invoice, this should not happen") @@ -328,11 +315,11 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit } } - private def buildIncomingPaymentStatus(virtualAmount_opt: Option[MilliSatoshi], realAmount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { - (virtualAmount_opt, realAmount_opt) match { - case (Some(virtualAmount), Some(realAmount)) => IncomingPaymentStatus.Received(virtualAmount, realAmount, receivedAt_opt.getOrElse(0 unixms)) - case _ if invoice.isExpired() => IncomingPaymentStatus.Expired - case _ => IncomingPaymentStatus.Pending + private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { + amount_opt match { + case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0 unixms)) + case None if invoice.isExpired() => IncomingPaymentStatus.Expired + case None => IncomingPaymentStatus.Pending } } @@ -379,7 +366,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.executeQuery().flatMap(parseIncomingPayment).toSeq @@ -389,7 +376,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listPendingIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.setTimestamp(3, Timestamp.from(Instant.now())) @@ -400,7 +387,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit override def listExpiredIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(pg.prepareStatement(limited("SELECT * FROM payments.received WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setTimestamp(1, from.toSqlTimestamp) statement.setTimestamp(2, to.toSqlTimestamp) statement.setTimestamp(3, Timestamp.from(Instant.now())) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index ede02bdb04..c8b8f070df 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -34,7 +34,7 @@ import java.util.UUID object SqliteAuditDb { val DB_NAME = "audit" - val CURRENT_VERSION = 10 + val CURRENT_VERSION = 9 } class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { @@ -114,18 +114,10 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON transactions_published(channel_id)") } - def migration910(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE received RENAME TO received_old") - statement.executeUpdate("CREATE TABLE received (virtual_amount_msat INTEGER NOT NULL, real_amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("INSERT INTO received SELECT amount_msat, amount_msat, payment_hash, from_channel_id, timestamp FROM received_old") - statement.executeUpdate("DROP TABLE received_old") - statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") - } - getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE TABLE sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE received (virtual_amount_msat INTEGER NOT NULL, real_amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, next_node_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") statement.executeUpdate("CREATE TABLE channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") @@ -153,7 +145,7 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX transactions_published_channel_id_idx ON transactions_published(channel_id)") statement.executeUpdate("CREATE INDEX transactions_published_timestamp_idx ON transactions_published(timestamp)") statement.executeUpdate("CREATE INDEX transactions_confirmed_timestamp_idx ON transactions_confirmed(timestamp)") - case Some(v@(1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9)) => + case Some(v@(1 | 2 | 3 | 4 | 5 | 6 | 7 | 8)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -179,9 +171,6 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { if (v < 9) { migration89(statement) } - if (v < 10) { - migration910(statement) - } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -221,13 +210,12 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { } override def add(e: PaymentReceived): Unit = withMetrics("audit/add-payment-received", DbBackends.Sqlite) { - using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?, ?)")) { statement => + using(sqlite.prepareStatement("INSERT INTO received VALUES (?, ?, ?, ?)")) { statement => e.parts.foreach(p => { - statement.setLong(1, p.virtualAmount.toLong) - statement.setLong(2, p.realAmount.toLong) - statement.setBytes(3, e.paymentHash.toArray) - statement.setBytes(4, p.fromChannelId.toArray) - statement.setLong(5, p.timestamp.toLong) + statement.setLong(1, p.amount.toLong) + statement.setBytes(2, e.paymentHash.toArray) + statement.setBytes(3, p.fromChannelId.toArray) + statement.setLong(4, p.timestamp.toLong) statement.addBatch() }) statement.executeBatch() @@ -386,8 +374,7 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging { .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => val paymentHash = rs.getByteVector32("payment_hash") val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("virtual_amount_msat")), - MilliSatoshi(rs.getLong("real_amount_msat")), + MilliSatoshi(rs.getLong("amount_msat")), rs.getByteVector32("from_channel_id"), TimestampMilli(rs.getLong("timestamp"))) val received = receivedByHash.get(paymentHash) match { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 96ea34467d..d08008388a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -108,17 +108,9 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON sent_payments(offer_id)") } - def migration67(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE received_payments RENAME TO received_payments_old") - statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, virtual_received_msat INTEGER, real_received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") - statement.executeUpdate("INSERT INTO received_payments SELECT payment_hash, payment_type, payment_preimage, path_ids, payment_request, received_msat, received_msat, created_at, expire_at, received_at FROM received_payments_old") - statement.executeUpdate("DROP TABLE received_payments_old") - statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") - } - getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, virtual_received_msat INTEGER, real_received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") + statement.executeUpdate("CREATE TABLE received_payments (payment_hash BLOB NOT NULL PRIMARY KEY, payment_type TEXT NOT NULL, payment_preimage BLOB NOT NULL, path_ids BLOB, payment_request TEXT NOT NULL, received_msat INTEGER, created_at INTEGER NOT NULL, expire_at INTEGER NOT NULL, received_at INTEGER)") statement.executeUpdate("CREATE TABLE sent_payments (id TEXT NOT NULL PRIMARY KEY, parent_id TEXT NOT NULL, external_id TEXT, payment_hash BLOB NOT NULL, payment_preimage BLOB, payment_type TEXT NOT NULL, amount_msat INTEGER NOT NULL, fees_msat INTEGER, recipient_amount_msat INTEGER NOT NULL, recipient_node_id BLOB NOT NULL, payment_request TEXT, offer_id BLOB, payer_key BLOB, payment_route BLOB, failures BLOB, created_at INTEGER NOT NULL, completed_at INTEGER)") statement.executeUpdate("CREATE INDEX sent_parent_id_idx ON sent_payments(parent_id)") @@ -126,7 +118,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { statement.executeUpdate("CREATE INDEX sent_payment_offer_idx ON sent_payments(offer_id)") statement.executeUpdate("CREATE INDEX sent_created_idx ON sent_payments(created_at)") statement.executeUpdate("CREATE INDEX received_created_idx ON received_payments(created_at)") - case Some(v@(1 | 2 | 3 | 4 | 5 | 6)) => + case Some(v@(1 | 2 | 3 | 4 | 5)) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") if (v < 2) { migration12(statement) @@ -143,9 +135,6 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { if (v < 6) { migration56(statement) } - if (v < 7) { - migration67(statement) - } case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -290,35 +279,32 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } } - override def receiveIncomingPayment(paymentHash: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { - using(sqlite.prepareStatement("UPDATE received_payments SET (virtual_received_msat, real_received_msat, received_at) = (? + COALESCE(virtual_received_msat, 0), ? + COALESCE(real_received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, virtualAmount.toLong) - update.setLong(2, realAmount.toLong) - update.setLong(3, receivedAt.toLong) - update.setBytes(4, paymentHash.toArray) + override def receiveIncomingPayment(paymentHash: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli): Boolean = withMetrics("payments/receive-incoming", DbBackends.Sqlite) { + using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => + update.setLong(1, amount.toLong) + update.setLong(2, receivedAt.toLong) + update.setBytes(3, paymentHash.toArray) val updated = update.executeUpdate() updated > 0 } } - override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, virtualAmount: fr.acinq.eclair.MilliSatoshi, realAmount: fr.acinq.eclair.MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Sqlite) { - if (using(sqlite.prepareStatement("INSERT OR IGNORE INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, virtual_received_msat, real_received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + override def receiveIncomingOfferPayment(invoice: MinimalBolt12Invoice, preimage: ByteVector32, amount: MilliSatoshi, receivedAt: TimestampMilli, paymentType: String): Unit = withMetrics("payments/receive-incoming-offer", DbBackends.Sqlite) { + if (using(sqlite.prepareStatement("INSERT OR IGNORE INTO received_payments (payment_hash, payment_preimage, payment_type, payment_request, created_at, expire_at, received_msat, received_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setBytes(1, invoice.paymentHash.toArray) statement.setBytes(2, preimage.toArray) statement.setString(3, paymentType) statement.setString(4, invoice.toString) statement.setLong(5, invoice.createdAt.toTimestampMilli.toLong) statement.setLong(6, (invoice.createdAt + invoice.relativeExpiry).toLong.seconds.toMillis) - statement.setLong(7, virtualAmount.toLong) - statement.setLong(8, realAmount.toLong) - statement.setLong(9, receivedAt.toLong) + statement.setLong(7, amount.toLong) + statement.setLong(8, receivedAt.toLong) statement.executeUpdate() } == 0) { - using(sqlite.prepareStatement("UPDATE received_payments SET (virtual_received_msat, real_received_msat, received_at) = (virtual_received_msat + ?, real_received_msat + ?, ?) WHERE payment_hash = ?")) { statement => - statement.setLong(1, virtualAmount.toLong) - statement.setLong(2, realAmount.toLong) - statement.setLong(3, receivedAt.toLong) - statement.setBytes(4, invoice.paymentHash.toArray) + using(sqlite.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (received_msat + ?, ?) WHERE payment_hash = ?")) { statement => + statement.setLong(1, amount.toLong) + statement.setLong(2, receivedAt.toLong) + statement.setBytes(3, invoice.paymentHash.toArray) statement.executeUpdate() } } @@ -331,10 +317,10 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { val createdAt = TimestampMilli(rs.getLong("created_at")) Invoice.fromString(invoice) match { case Success(invoice: Bolt11Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) Some(IncomingStandardPayment(invoice, preimage, paymentType, createdAt, status)) case Success(invoice: MinimalBolt12Invoice) => - val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("virtual_received_msat"), rs.getMilliSatoshiNullable("real_received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) + val status = buildIncomingPaymentStatus(rs.getMilliSatoshiNullable("received_msat"), invoice, rs.getLongNullable("received_at").map(TimestampMilli(_))) Some(IncomingBlindedPayment(invoice, preimage, paymentType, createdAt, status)) case _ => logger.error(s"could not parse DB invoice=$invoice, this should not happen") @@ -342,11 +328,11 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } } - private def buildIncomingPaymentStatus(virtualAmount_opt: Option[MilliSatoshi], realAmount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { - (virtualAmount_opt, realAmount_opt) match { - case (Some(virtualAmount), Some(realAmount)) => IncomingPaymentStatus.Received(virtualAmount, realAmount, receivedAt_opt.getOrElse(0 unixms)) - case _ if invoice.isExpired() => IncomingPaymentStatus.Expired - case _ => IncomingPaymentStatus.Pending + private def buildIncomingPaymentStatus(amount_opt: Option[MilliSatoshi], invoice: Invoice, receivedAt_opt: Option[TimestampMilli]): IncomingPaymentStatus = { + amount_opt match { + case Some(amount) => IncomingPaymentStatus.Received(amount, receivedAt_opt.getOrElse(0 unixms)) + case None if invoice.isExpired() => IncomingPaymentStatus.Expired + case None => IncomingPaymentStatus.Pending } } @@ -382,7 +368,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listReceivedIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-received", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.executeQuery().flatMap(parseIncomingPayment).toSeq @@ -390,7 +376,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listPendingIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-pending", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at > ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.setLong(3, TimestampMilli.now().toLong) @@ -399,7 +385,7 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { } override def listExpiredIncomingPayments(from: TimestampMilli, to: TimestampMilli, paginated_opt: Option[Paginated]): Seq[IncomingPayment] = withMetrics("payments/list-incoming-expired", DbBackends.Sqlite) { - using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE virtual_received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => + using(sqlite.prepareStatement(limited("SELECT * FROM received_payments WHERE received_msat IS NULL AND created_at > ? AND created_at < ? AND expire_at < ? ORDER BY created_at", paginated_opt))) { statement => statement.setLong(1, from.toLong) statement.setLong(2, to.toLong) statement.setLong(3, TimestampMilli.now().toLong) @@ -411,5 +397,5 @@ class SqlitePaymentsDb(val sqlite: Connection) extends PaymentsDb with Logging { object SqlitePaymentsDb { val DB_NAME = "payments" - val CURRENT_VERSION = 7 + val CURRENT_VERSION = 6 } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala index e467bf030a..c070d9fb7e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/PaymentEvents.scala @@ -120,14 +120,13 @@ object PaymentRelayed { case class PaymentReceived(paymentHash: ByteVector32, parts: Seq[PaymentReceived.PartialPayment]) extends PaymentEvent { require(parts.nonEmpty, "must have at least one payment part") - val virtualAmount: MilliSatoshi = parts.map(_.virtualAmount).sum - val realAmount: MilliSatoshi = parts.map(_.realAmount).sum + val amount: MilliSatoshi = parts.map(_.amount).sum val timestamp: TimestampMilli = parts.map(_.timestamp).max // we use max here because we fulfill the payment only once we received all the parts } object PaymentReceived { - case class PartialPayment(virtualAmount: MilliSatoshi, realAmount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) + case class PartialPayment(amount: MilliSatoshi, fromChannelId: ByteVector32, timestamp: TimestampMilli = TimestampMilli.now()) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala index 06c9242aea..b31b4849a2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala @@ -267,7 +267,7 @@ object OfferManager { case AcceptPayment(additionalTlvs, customTlvs) => val minimalInvoice = MinimalBolt12Invoice(offer, nodeParams.chainHash, metadata.amount, metadata.quantity, Crypto.sha256(metadata.preimage), metadata.payerKey, metadata.createdAt, additionalTlvs, customTlvs) val incomingPayment = IncomingBlindedPayment(minimalInvoice, metadata.preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment, metadata.hiddenFees) + replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment) Behaviors.stopped case RejectPayment(reason) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 6a3484bc8e..6ab1b38d98 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -62,13 +62,21 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP private var pendingPayments: Map[ByteVector32, (IncomingPayment, ActorRef)] = Map.empty private def addHtlcPart(ctx: ActorContext, add: UpdateAddHtlc, payload: FinalPayload, payment: IncomingPayment): Unit = { - pendingPayments.get(add.paymentHash) match { + val handler = pendingPayments.get(add.paymentHash) match { case Some((_, handler)) => - handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, payload.amount, add) + handler case None => val handler = ctx.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, payload.totalAmount, ctx.self)) - handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, payload.amount, add) pendingPayments = pendingPayments + (add.paymentHash -> (payment, handler)) + handler + } + handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) + payload match { + case payload: FinalPayload.Blinded if payload.amount - add.amountMsat > 0.msat => + val hiddenFee = payload.amount - add.amountMsat + handler ! MultiPartPaymentFSM.HiddenFeePart(add.paymentHash, hiddenFee, payload.totalAmount) + case _: FinalPayload.Blinded => () + case _: FinalPayload.Standard => () } } @@ -133,7 +141,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP } } - case ProcessBlindedPacket(add, payload, payment, hiddenRelayFees) if doHandle(add.paymentHash) => + case ProcessBlindedPacket(add, payload, payment) if doHandle(add.paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(add.paymentHash))) { validateBlindedPayment(nodeParams, add, payload, payment) match { case Some(cmdFail) => @@ -153,7 +161,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case MultiPartPaymentFSM.MultiPartPaymentFailed(paymentHash, failure, parts) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, failure.getClass.getSimpleName).increment() - log.warning("payment with paidAmount={} failed ({})", parts.map(_.virtualAmount).sum, failure) + log.warning("payment with paidAmount={} failed ({})", parts.map(_.amount).sum, failure) pendingPayments.get(paymentHash).foreach { case (_, handler: ActorRef) => handler ! PoisonPill } parts.collect { case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(failure), commit = true)) @@ -163,7 +171,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case s@MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { - log.info("received complete payment for amount={}", parts.map(_.virtualAmount).sum) + log.info("received complete payment for amount={}", parts.map(_.amount).sum) pendingPayments.get(paymentHash).foreach { case (payment: IncomingPayment, handler: ActorRef) => handler ! PoisonPill @@ -177,38 +185,41 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP failure match { case Some(failure) => p match { case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(failure), commit = true)) + case _: MultiPartPaymentFSM.HiddenFeePart => () } case None => p match { // NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more // expensive code path by fetching the preimage from DB. case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => { - val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.virtualAmount, p.realAmount, p.htlc.channelId) :: Nil) - if (db.receiveIncomingPayment(paymentHash, p.virtualAmount, p.realAmount, received.timestamp)) { + val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil) + if (db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp)) { PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) ctx.system.eventStream.publish(received) } else { - val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.virtualAmount, nodeParams.currentBlockHeight)), commit = true) + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) } }) + case _: MultiPartPaymentFSM.HiddenFeePart => () } } } case DoFulfill(payment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, parts)) if doHandle(paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { - log.debug("fulfilling payment for virtualAmount={}, realAmount={}", parts.map(_.virtualAmount).sum, parts.map(_.realAmount).sum) - val received = PaymentReceived(paymentHash, parts.map { - case p: MultiPartPaymentFSM.HtlcPart => PaymentReceived.PartialPayment(p.virtualAmount, p.realAmount, p.htlc.channelId) + log.debug("fulfilling payment for amount={}", parts.map(_.amount).sum) + val received = PaymentReceived(paymentHash, parts.flatMap { + case p: MultiPartPaymentFSM.HtlcPart => Some(PaymentReceived.PartialPayment(p.amount, p.htlc.channelId)) + case _: MultiPartPaymentFSM.HiddenFeePart => None }) val recordedInDb = payment match { // Incoming offer payments are not stored in the database until they have been paid. case IncomingBlindedPayment(invoice, preimage, paymentType, _, _) => - db.receiveIncomingOfferPayment(invoice, preimage, received.virtualAmount, received.realAmount, received.timestamp, paymentType) + db.receiveIncomingOfferPayment(invoice, preimage, received.amount, received.timestamp, paymentType) true // Incoming standard payments are already stored and need to be marked as received. case _: IncomingStandardPayment => - db.receiveIncomingPayment(paymentHash, received.virtualAmount, received.realAmount, received.timestamp) + db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp) } if (recordedInDb) { parts.collect { @@ -220,7 +231,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP parts.collect { case p: MultiPartPaymentFSM.HtlcPart => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment() - val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.virtualAmount, nodeParams.currentBlockHeight)), commit = true) + val cmdFail = CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(received.amount, nodeParams.currentBlockHeight)), commit = true) PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) } } @@ -237,7 +248,7 @@ object MultiPartHandler { // @formatter:off case class ProcessPacket(add: UpdateAddHtlc, payload: FinalPayload.Standard, payment_opt: Option[IncomingStandardPayment]) - case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment, hiddenRelayFees: RelayFees) + case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment) case class RejectPacket(add: UpdateAddHtlc, failure: FailureMessage) case class DoFulfill(payment: IncomingPayment, success: MultiPartPaymentFSM.MultiPartPaymentSucceeded) @@ -364,7 +375,7 @@ object MultiPartHandler { // @formatter:off sealed trait Command case class GetIncomingPayment(replyTo: ActorRef) extends Command - case class ProcessPayment(payment: IncomingBlindedPayment, hiddenRelayFees: RelayFees) extends Command + case class ProcessPayment(payment: IncomingBlindedPayment) extends Command case class RejectPayment(reason: String) extends Command // @formatter:on @@ -394,8 +405,8 @@ object MultiPartHandler { private def waitForPayment(context: typed.scaladsl.ActorContext[Command], nodeParams: NodeParams, replyTo: ActorRef, add: UpdateAddHtlc, payload: FinalPayload.Blinded): Behavior[Command] = { Behaviors.receiveMessagePartial { - case ProcessPayment(payment, hiddenRelayFees) => - replyTo ! ProcessBlindedPacket(add, payload, payment, hiddenRelayFees) + case ProcessPayment(payment) => + replyTo ! ProcessBlindedPacket(add, payload, payment) Behaviors.stopped case RejectPayment(reason) => context.log.info("rejecting blinded htlc #{} from channel {}: {}", add.id, add.channelId, reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala index a8be934ec4..dd1110f6f4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala @@ -58,7 +58,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot if (totalAmount != part.totalAmount) { log.warning("multi-part payment total amount mismatch: previously {}, now {}", totalAmount, part.totalAmount) goto(PAYMENT_FAILED) using PaymentFailed(IncorrectOrUnknownPaymentDetails(part.totalAmount, nodeParams.currentBlockHeight), updatedParts) - } else if (d.paidAmount + part.virtualAmount >= totalAmount) { + } else if (d.paidAmount + part.amount >= totalAmount) { goto(PAYMENT_SUCCEEDED) using PaymentSucceeded(updatedParts) } else { stay() using d.copy(parts = updatedParts) @@ -71,7 +71,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot // intermediate nodes will be able to fulfill that htlc anyway. This is a harmless spec violation. case Event(part: PaymentPart, _) => require(part.paymentHash == paymentHash, s"invalid payment hash (expected $paymentHash, received ${part.paymentHash}") - log.info("received extraneous payment part with virtualAmount={}, realAmount={}", part.virtualAmount, part.realAmount) + log.info("received extraneous payment part with amount={}", part.amount) replyTo ! ExtraPaymentReceived(paymentHash, part, None) stay() } @@ -130,15 +130,16 @@ object MultiPartPaymentFSM { /** An incoming payment that we're currently holding until we decide to fulfill or fail it (depending on whether we receive the complete payment). */ sealed trait PaymentPart { def paymentHash: ByteVector32 - def virtualAmount: MilliSatoshi - def realAmount: MilliSatoshi + def amount: MilliSatoshi def totalAmount: MilliSatoshi } /** An incoming HTLC. */ - case class HtlcPart(totalAmount: MilliSatoshi, virtualAmount: MilliSatoshi, htlc: UpdateAddHtlc) extends PaymentPart { + case class HtlcPart(totalAmount: MilliSatoshi, htlc: UpdateAddHtlc) extends PaymentPart { override def paymentHash: ByteVector32 = htlc.paymentHash - override def realAmount: MilliSatoshi = htlc.amountMsat + override def amount: MilliSatoshi = htlc.amountMsat } + /** The fee of a blinded route paid by the receiver (us). */ + case class HiddenFeePart(paymentHash: ByteVector32, amount: MilliSatoshi, totalAmount: MilliSatoshi) extends PaymentPart /** We successfully received all parts of the payment. */ case class MultiPartPaymentSucceeded(paymentHash: ByteVector32, parts: Queue[PaymentPart]) /** We aborted the payment because of an inconsistency in the payment set or because we didn't receive the total amount in reasonable time. */ @@ -157,7 +158,7 @@ object MultiPartPaymentFSM { // @formatter:off sealed trait Data { def parts: Queue[PaymentPart] - lazy val paidAmount = parts.map(_.virtualAmount).sum + lazy val paidAmount = parts.map(_.amount).sum } case class WaitingForHtlc(parts: Queue[PaymentPart]) extends Data case class PaymentSucceeded(parts: Queue[PaymentPart]) extends Data diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index 4fbf1b5dba..98f1708d3f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -237,15 +237,15 @@ class NodeRelay private(nodeParams: NodeParams, case Relay(packet: IncomingPaymentPacket.NodeRelayPacket, originNode) => require(packet.outerPayload.paymentSecret == paymentSecret, "payment secret mismatch") context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", packet.add.id, packet.add.channelId) - handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add.amountMsat, packet.add) + handler ! MultiPartPaymentFSM.HtlcPart(packet.outerPayload.totalAmount, packet.add) receiving(htlcs :+ Upstream.Hot.Channel(packet.add.removeUnknownTlvs(), TimestampMilli.now(), originNode), nextPayload, nextPacket_opt, handler) case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) => - context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.realAmount).sum, failure) + context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure) Metrics.recordPaymentRelayFailed(failure.getClass.getSimpleName, Tags.RelayType.Trampoline) - parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.realAmount, Some(failure)) } + parts.collect { case p: MultiPartPaymentFSM.HtlcPart => rejectHtlc(p.htlc.id, p.htlc.channelId, p.amount, Some(failure)) } stopping() case WrappedMultiPartPaymentSucceeded(MultiPartPaymentFSM.MultiPartPaymentSucceeded(_, parts)) => - context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.realAmount).sum) + context.log.info("completed incoming multi-part payment with parts={} paidAmount={}", parts.size, parts.map(_.amount).sum) val upstream = Upstream.Hot.Trampoline(htlcs.toList) validateRelay(nodeParams, upstream, nextPayload) match { case Some(failure) => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 9fa933daae..d79665a0a0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -66,8 +66,8 @@ class AuditDbSpec extends AnyFunSuite { val now = TimestampMilli.now() val e1 = PaymentSent(ZERO_UUID, randomBytes32(), randomBytes32(), 40000 msat, randomKey().publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32(), None) :: Nil) - val pp2a = PaymentReceived.PartialPayment(42000 msat, 42000 msat,randomBytes32()) - val pp2b = PaymentReceived.PartialPayment(42100 msat, 42100 msat,randomBytes32()) + val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32()) + val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32()) val e2 = PaymentReceived(randomBytes32(), pp2a :: pp2b :: Nil) val e3 = ChannelPaymentRelayed(42000 msat, 1000 msat, randomBytes32(), randomBytes32(), randomBytes32(), now - 3.seconds, now) val e4a = TransactionPublished(randomBytes32(), randomKey().publicKey, Transaction(0, Seq.empty, Seq.empty, 0), 42 sat, "mutual") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala index 49d1c5c2ec..829b521612 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PaymentsDbSpec.scala @@ -82,11 +82,11 @@ class PaymentsDbSpec extends AnyFunSuite { // add a few rows val ps1 = OutgoingPayment(UUID.randomUUID(), UUID.randomUUID(), None, paymentHash1, PaymentType.Standard, 12345 msat, 12345 msat, alice, 1000 unixms, None, None, OutgoingPaymentStatus.Pending) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(500 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = 1 unixsec) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(550 msat, 550 msat, 1100 unixms)) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(550 msat, 1100 unixms)) db.addOutgoingPayment(ps1) db.addIncomingPayment(i1, preimage1) - db.receiveIncomingPayment(i1.paymentHash, 550 msat, 550 msat, 1100 unixms) + db.receiveIncomingPayment(i1.paymentHash, 550 msat, 1100 unixms) assert(db.listIncomingPayments(1 unixms, 1500 unixms, None) == Seq(pr1)) assert(db.listOutgoingPayments(1 unixms, 1500 unixms) == Seq(ps1)) @@ -105,7 +105,7 @@ class PaymentsDbSpec extends AnyFunSuite { val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, 1010 unixms, None, None, OutgoingPaymentStatus.Failed(Nil, 1050 unixms)) val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, 1040 unixms, None, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, 1060 unixms)) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = 1 unixsec) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 12345678 msat, 1090 unixms)) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 1090 unixms)) val i2 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, Left("Another invoice"), CltvExpiryDelta(18), expirySeconds = Some(30), timestamp = 1 unixsec) val pr2 = IncomingStandardPayment(i2, preimage2, PaymentType.Standard, i2.createdAt.toTimestampMilli, IncomingPaymentStatus.Expired) @@ -166,7 +166,7 @@ class PaymentsDbSpec extends AnyFunSuite { statement.setBytes(1, i1.paymentHash.toArray) statement.setBytes(2, pr1.paymentPreimage.toArray) statement.setString(3, i1.toString) - statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount.toLong) + statement.setLong(4, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) statement.setLong(5, pr1.createdAt.toLong) statement.setLong(6, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt.toLong) statement.executeUpdate() @@ -298,7 +298,7 @@ class PaymentsDbSpec extends AnyFunSuite { val pendingInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(2500 msat), paymentHash1, bobPriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = now, expirySeconds = Some(30)) val pending = IncomingStandardPayment(pendingInvoice, preimage1, PaymentType.Standard, pendingInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) val paidInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(10_000 msat), paymentHash2, bobPriv, Left("invoice #2"), CltvExpiryDelta(12), timestamp = 250 unixsec, expirySeconds = Some(60)) - val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 11_000 msat, 300.unixsec.toTimestampMilli)) + val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 300.unixsec.toTimestampMilli)) migrationCheck( dbs = dbs, @@ -428,7 +428,7 @@ class PaymentsDbSpec extends AnyFunSuite { val ps2 = OutgoingPayment(id2, id2, None, randomBytes32(), PaymentType.Standard, 1105 msat, 1105 msat, PrivateKey(ByteVector32.One).publicKey, TimestampMilli(Instant.parse("2020-05-14T13:47:21.00Z").toEpochMilli), None, None, OutgoingPaymentStatus.Failed(Nil, TimestampMilli(Instant.parse("2021-05-15T04:12:40.00Z").toEpochMilli))) val ps3 = OutgoingPayment(id3, id3, None, paymentHash1, PaymentType.Standard, 1729 msat, 1729 msat, PrivateKey(ByteVector32.One).publicKey, TimestampMilli(Instant.parse("2021-01-28T09:12:05.00Z").toEpochMilli), None, None, OutgoingPaymentStatus.Succeeded(preimage1, 0 msat, Nil, TimestampMilli.now())) val i1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash1, davePriv, Left("Some invoice"), CltvExpiryDelta(18), expirySeconds = None, timestamp = TimestampSecond.now()) - val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, 12345678 msat, TimestampMilli.now())) + val pr1 = IncomingStandardPayment(i1, preimage1, PaymentType.Standard, i1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(12345678 msat, TimestampMilli.now())) val i2 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(12345678 msat), paymentHash2, carolPriv, Left("Another invoice"), CltvExpiryDelta(18), expirySeconds = Some(24 * 3600), timestamp = TimestampSecond(Instant.parse("2020-12-30T10:00:55.00Z").getEpochSecond)) val pr2 = IncomingStandardPayment(i2, preimage2, PaymentType.Standard, i2.createdAt.toTimestampMilli, IncomingPaymentStatus.Expired) @@ -487,7 +487,7 @@ class PaymentsDbSpec extends AnyFunSuite { } using(connection.prepareStatement("UPDATE received_payments SET (received_msat, received_at) = (? + COALESCE(received_msat, 0), ?) WHERE payment_hash = ?")) { update => - update.setLong(1, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount.toLong) + update.setLong(1, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].amount.toLong) update.setLong(2, pr1.status.asInstanceOf[IncomingPaymentStatus.Received].receivedAt.toLong) update.setString(3, pr1.invoice.paymentHash.toHex) val updated = update.executeUpdate() @@ -526,7 +526,7 @@ class PaymentsDbSpec extends AnyFunSuite { val pendingInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(2500 msat), paymentHash1, bobPriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = now, expirySeconds = Some(30)) val pending = IncomingStandardPayment(pendingInvoice, preimage1, PaymentType.Standard, pendingInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) val paidInvoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(10_000 msat), paymentHash2, bobPriv, Left("invoice #2"), CltvExpiryDelta(12), timestamp = 250 unixsec, expirySeconds = Some(60)) - val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 11_000 msat, 300.unixsec.toTimestampMilli)) + val paid = IncomingStandardPayment(paidInvoice, preimage2, PaymentType.Standard, paidInvoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(11_000 msat, 300.unixsec.toTimestampMilli)) migrationCheck( dbs = dbs, @@ -650,7 +650,7 @@ class PaymentsDbSpec extends AnyFunSuite { // can't receive a payment without an invoice associated with it val unknownPaymentHash = randomBytes32() - assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat, 12345600 msat)) + assert(!db.receiveIncomingPayment(unknownPaymentHash, 12345678 msat)) assert(db.getIncomingPayment(unknownPaymentHash).isEmpty) val expiredInvoice1 = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(561 msat), randomBytes32(), alicePriv, Left("invoice #1"), CltvExpiryDelta(18), timestamp = 1 unixsec) @@ -672,9 +672,9 @@ class PaymentsDbSpec extends AnyFunSuite { val receivedAt2 = TimestampMilli.now() + 2.milli val receivedAt3 = TimestampMilli.now() + 3.milli val receivedAt4 = TimestampMilli.now() + 4.milli - val payment1 = IncomingStandardPayment(paidInvoice1, randomBytes32(), PaymentType.Standard, paidInvoice1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(561 msat, 560 msat, receivedAt2)) - val payment2 = IncomingStandardPayment(paidInvoice2, randomBytes32(), PaymentType.Standard, paidInvoice2.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1111 msat, 1100 msat, receivedAt2)) - val payment3 = IncomingBlindedPayment(paidInvoice3, randomBytes32(), PaymentType.Blinded, paidInvoice3.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1730 msat, 1720 msat, receivedAt3)) + val payment1 = IncomingStandardPayment(paidInvoice1, randomBytes32(), PaymentType.Standard, paidInvoice1.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(561 msat, receivedAt2)) + val payment2 = IncomingStandardPayment(paidInvoice2, randomBytes32(), PaymentType.Standard, paidInvoice2.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1111 msat, receivedAt2)) + val payment3 = IncomingBlindedPayment(paidInvoice3, randomBytes32(), PaymentType.Blinded, paidInvoice3.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(1730 msat, receivedAt3)) db.addIncomingPayment(pendingInvoice1, pendingPayment1.paymentPreimage) db.addIncomingPayment(pendingInvoice2, pendingPayment2.paymentPreimage, PaymentType.SwapIn) @@ -682,7 +682,7 @@ class PaymentsDbSpec extends AnyFunSuite { db.addIncomingPayment(expiredInvoice2, expiredPayment2.paymentPreimage) db.addIncomingPayment(paidInvoice1, payment1.paymentPreimage) db.addIncomingPayment(paidInvoice2, payment2.paymentPreimage) - db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 1730 msat, 1720 msat, receivedAt3) + db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 1730 msat, receivedAt3) assert(db.getIncomingPayment(pendingInvoice1.paymentHash).contains(pendingPayment1)) assert(db.getIncomingPayment(expiredInvoice2.paymentHash).contains(expiredPayment2)) @@ -695,12 +695,12 @@ class PaymentsDbSpec extends AnyFunSuite { assert(db.listReceivedIncomingPayments(0 unixms, now, None) == Seq(payment3)) assert(db.listPendingIncomingPayments(0 unixms, now, None) == Seq(pendingPayment1, pendingPayment2, payment1.copy(status = IncomingPaymentStatus.Pending), payment2.copy(status = IncomingPaymentStatus.Pending))) - db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, 460 msat, receivedAt1) - db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, 100 msat, receivedAt2) // adding another payment to this invoice should sum - db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, 1100 msat, receivedAt2) - db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 3400 msat, 3400 msat, receivedAt4) + db.receiveIncomingPayment(paidInvoice1.paymentHash, 461 msat, receivedAt1) + db.receiveIncomingPayment(paidInvoice1.paymentHash, 100 msat, receivedAt2) // adding another payment to this invoice should sum + db.receiveIncomingPayment(paidInvoice2.paymentHash, 1111 msat, receivedAt2) + db.receiveIncomingOfferPayment(paidInvoice3, payment3.paymentPreimage, 3400 msat, receivedAt4) - val payment4 = payment3.copy(status = IncomingPaymentStatus.Received(5130 msat, 5120 msat, receivedAt4)) + val payment4 = payment3.copy(status = IncomingPaymentStatus.Received(5130 msat, receivedAt4)) assert(db.getIncomingPayment(paidInvoice1.paymentHash).contains(payment1)) assert(db.getIncomingPayment(paidInvoice3.paymentHash).contains(payment4)) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index cf8177dca8..51fe1a78c7 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -373,7 +373,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(sent.head.copy(parts = sent.head.parts.sortBy(_.timestamp)) == paymentSent.copy(parts = paymentSent.parts.map(_.copy(route = None)).sortBy(_.timestamp)), sent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) } @@ -430,7 +430,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentParts.forall(p => p.status.asInstanceOf[OutgoingPaymentStatus.Succeeded].feesPaid == 0.msat), paymentParts) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) } @@ -482,7 +482,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid == amount * 0.002) // 0.2% awaitCond(nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("F").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) awaitCond({ @@ -515,7 +515,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.recipientAmount == amount, paymentSent) awaitCond(nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("B").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) eventListener.expectMsg(PaymentMetadataReceived(invoice.paymentHash, invoice.paymentMetadata.get)) @@ -553,7 +553,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.recipientAmount == amount, paymentSent) awaitCond(nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) + val Some(IncomingStandardPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(receivedAmount == amount) eventListener.expectMsg(PaymentMetadataReceived(invoice.paymentHash, invoice.paymentMetadata.get)) @@ -650,7 +650,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid > 0.msat, paymentSent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -683,7 +683,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid == 0.msat, paymentSent) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount == amount) } @@ -717,7 +717,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("A").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -754,7 +754,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -786,7 +786,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(paymentSent.feesPaid >= 0.msat, paymentSent) awaitCond(nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("D").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } @@ -836,7 +836,7 @@ class PaymentIntegrationSpec extends IntegrationSpec { assert(invoice.blindedPaths.forall(_.route.firstNodeId.isInstanceOf[EncodedNodeId.ShortChannelIdDir])) awaitCond(nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash).exists(_.status.isInstanceOf[IncomingPaymentStatus.Received])) - val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) + val Some(IncomingBlindedPayment(_, _, _, _, IncomingPaymentStatus.Received(receivedAmount, _))) = nodes("C").nodeParams.db.payments.getIncomingPayment(paymentSent.paymentHash) assert(receivedAmount >= amount) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index a56b9760c0..53d84b359d 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -121,10 +121,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -137,10 +137,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(70_000 msat, add.amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(add.amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(70_000 msat, add.amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(add.amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -155,10 +155,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -181,14 +181,14 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == finalPacket.add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(finalPacket.add.paymentHash, PartialPayment(amountMsat, amountMsat, finalPacket.add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(finalPacket.add.paymentHash, PartialPayment(amountMsat, finalPacket.add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) sender.expectNoMessage(50 millis) } @@ -489,7 +489,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status.isInstanceOf[IncomingPaymentStatus.Received]) } @@ -532,7 +532,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, RelayFees.zero) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(5000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) @@ -570,7 +570,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike }) // Extraneous HTLCs should be failed. - f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(pr1.paymentHash, HtlcPart(1000 msat, 1000 msat, UpdateAddHtlc(ByteVector32.One, 42, 200 msat, pr1.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), Some(PaymentTimeout()))) + f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(pr1.paymentHash, HtlcPart(1000 msat, UpdateAddHtlc(ByteVector32.One, 42, 200 msat, pr1.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), Some(PaymentTimeout()))) f.register.expectMsg(Register.Forward(null, ByteVector32.One, CMD_FAIL_HTLC(42, FailureReason.LocalFailure(PaymentTimeout()), commit = true))) // The payment should still be pending in DB. @@ -601,25 +601,21 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ) val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(800 msat, 800 msat, ByteVector32.One, 0 unixms), PartialPayment(200 msat, 200 msat, ByteVector32.Zeroes, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(800 msat, ByteVector32.One, 0 unixms), PartialPayment(200 msat, ByteVector32.Zeroes, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1000.msat) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1000.msat) awaitCond({ f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty }) // Extraneous HTLCs should be fulfilled. - f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(invoice.paymentHash, HtlcPart(1000 msat, 200 msat, UpdateAddHtlc(ByteVector32.One, 44, 200 msat, invoice.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), None)) + f.sender.send(handler, MultiPartPaymentFSM.ExtraPaymentReceived(invoice.paymentHash, HtlcPart(1000 msat, UpdateAddHtlc(ByteVector32.One, 44, 200 msat, invoice.paymentHash, add1.cltvExpiry, add1.onionRoutingPacket, None, 1.0, None)), None)) f.register.expectMsg(Register.Forward(null, ByteVector32.One, CMD_FULFILL_HTLC(44, preimage, commit = true))) - val paymentReceived2 = f.eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived2.virtualAmount == 200.msat) - assert(paymentReceived2.realAmount == 200.msat) + assert(f.eventListener.expectMsgType[PaymentReceived].amount == 200.msat) val received2 = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) - assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1200.msat) - assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1200.msat) + assert(received2.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1200.msat) f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty @@ -644,11 +640,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ) val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(1100 msat, 1100 msat, add1.channelId, 0 unixms), PartialPayment(500 msat, 500 msat, add2.channelId, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(1100 msat, add1.channelId, 0 unixms), PartialPayment(500 msat, add2.channelId, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1600.msat) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1600.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1600.msat) } test("PaymentHandler should handle multi-part payment timeout then success") { f => @@ -682,11 +677,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val paymentReceived = f.eventListener.expectMsgType[PaymentReceived] assert(paymentReceived.paymentHash == invoice.paymentHash) - assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(300 msat, 300 msat, ByteVector32.One, 0 unixms), PartialPayment(700 msat, 700 msat, ByteVector32.Zeroes, 0 unixms))) + assert(paymentReceived.parts.map(_.copy(timestamp = 0 unixms)).toSet == Set(PartialPayment(300 msat, ByteVector32.One, 0 unixms), PartialPayment(700 msat, ByteVector32.Zeroes, 0 unixms))) val received = nodeParams.db.payments.getIncomingPayment(invoice.paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].virtualAmount == 1000.msat) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].realAmount == 1000.msat) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].amount == 1000.msat) awaitCond({ f.sender.send(handler, GetPendingPayments) f.sender.expectMsgType[PendingPayments].paymentHashes.isEmpty @@ -709,10 +703,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) } test("PaymentHandler should handle single-part KeySend payment without payment secret") { f => @@ -730,10 +724,10 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] val paymentReceived = eventListener.expectMsgType[PaymentReceived] - assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) + assert(paymentReceived.copy(parts = paymentReceived.parts.map(_.copy(timestamp = 0 unixms))) == PaymentReceived(add.paymentHash, PartialPayment(amountMsat, add.channelId, timestamp = 0 unixms) :: Nil)) val received = nodeParams.db.payments.getIncomingPayment(paymentHash) assert(received.isDefined && received.get.status.isInstanceOf[IncomingPaymentStatus.Received]) - assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, amountMsat, 0 unixms)) + assert(received.get.status.asInstanceOf[IncomingPaymentStatus.Received].copy(receivedAt = 0 unixms) == IncomingPaymentStatus.Received(amountMsat, 0 unixms)) } test("PaymentHandler should reject KeySend payment when feature is disabled") { f => @@ -792,7 +786,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val add = UpdateAddHtlc(ByteVector32.One, 0, 1000 msat, paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None, 1.0, None) val invoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, None, paymentHash, randomKey(), Left("dummy"), CltvExpiryDelta(12)) val incomingPayment = IncomingStandardPayment(invoice, paymentPreimage, PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Pending) - val fulfill = DoFulfill(incomingPayment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, 1000 msat, add)))) + val fulfill = DoFulfill(incomingPayment, MultiPartPaymentFSM.MultiPartPaymentSucceeded(paymentHash, Queue(HtlcPart(1000 msat, add)))) sender.send(handlerWithoutMpp, fulfill) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.id == add.id) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala index e00e142402..dcdd2ef225 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartPaymentFSMSpec.scala @@ -233,7 +233,7 @@ object MultiPartPaymentFSMSpec { def createMultiPartHtlc(totalAmount: MilliSatoshi, htlcAmount: MilliSatoshi, htlcId: Long): HtlcPart = { val htlc = UpdateAddHtlc(htlcIdToChannelId(htlcId), htlcId, htlcAmount, paymentHash, CltvExpiry(42), TestConstants.emptyOnionPacket, None, 1.0, None) - HtlcPart(totalAmount, htlcAmount, htlc) + HtlcPart(totalAmount, htlc) } } \ No newline at end of file diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index 301ad9f966..8133fafc34 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -197,7 +197,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit val paymentHash = Crypto.sha256(preimage) val invoice = Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(500 msat), paymentHash, TestConstants.Bob.nodeKeyManager.nodeKey.privateKey, Left("Some invoice"), CltvExpiryDelta(18)) nodeParams.db.payments.addIncomingPayment(invoice, preimage) - nodeParams.db.payments.receiveIncomingPayment(paymentHash, 5000 msat, 5000 msat) + nodeParams.db.payments.receiveIncomingPayment(paymentHash, 5000 msat) val htlc_ab_1 = Seq( buildFinalHtlc(0, channelId_ab_1, randomBytes32()), diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala index f6bc3b256d..81e1cd90c4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala @@ -125,11 +125,10 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(handlePayment.offerId == offer.offerId) assert(handlePayment.pluginData_opt.contains(hex"deadbeef")) handlePayment.replyTo ! PaymentActor.AcceptPayment() - val ProcessPayment(incomingPayment, hiddenRelayFees) = paymentHandler.expectMessageType[ProcessPayment] + val ProcessPayment(incomingPayment) = paymentHandler.expectMessageType[ProcessPayment] assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) - assert(hiddenRelayFees == RelayFees.zero) } test("pay offer without path_id") { f => @@ -316,11 +315,10 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val handlePayment = handler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) handlePayment.replyTo ! PaymentActor.AcceptPayment() - val ProcessPayment(incomingPayment, hiddenRelayFees) = paymentHandler.expectMessageType[ProcessPayment] + val ProcessPayment(incomingPayment) = paymentHandler.expectMessageType[ProcessPayment] assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) - assert(hiddenRelayFees == RelayFees(1000 msat, 200)) } test("invalid payment (incorrect amount with hidden fee)") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala index ba0b97c3e3..866cca0210 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/receive/InvoicePurgerSpec.scala @@ -52,11 +52,11 @@ class InvoicePurgerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("ap // create paid invoices val receivedAt = TimestampMilli.now() + 1.milli val paidInvoices = Seq.fill(count)(Bolt11Invoice(Block.Testnet3GenesisBlock.hash, Some(100 msat), randomBytes32(), alicePriv, Left("paid invoice"), CltvExpiryDelta(18))) - val paidPayments = paidInvoices.map(invoice => IncomingStandardPayment(invoice, randomBytes32(), PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(100 msat, 100 msat, receivedAt))) + val paidPayments = paidInvoices.map(invoice => IncomingStandardPayment(invoice, randomBytes32(), PaymentType.Standard, invoice.createdAt.toTimestampMilli, IncomingPaymentStatus.Received(100 msat, receivedAt))) paidPayments.foreach(payment => { db.addIncomingPayment(payment.invoice, payment.paymentPreimage) // receive payment - db.receiveIncomingPayment(payment.invoice.paymentHash, 100 msat, 100 msat, receivedAt) + db.receiveIncomingPayment(payment.invoice.paymentHash, 100 msat, receivedAt) }) val now = TimestampMilli.now() diff --git a/eclair-node/src/test/resources/api/received-success b/eclair-node/src/test/resources/api/received-success index 66cebe4d02..c70299259c 100644 --- a/eclair-node/src/test/resources/api/received-success +++ b/eclair-node/src/test/resources/api/received-success @@ -1 +1 @@ -{"invoice":{"prefix":"lnbc","timestamp":1496314658,"nodeId":"03779dc8b593b74509fab7c8accebc7a9b91d85d9df456d5b885464a34e5751d52","serialized":"lnbc2500u1pvjluezsp5cssgls5lpvunj7zallxsn3v8g3f9wqfs75hsdmkrtxwgkafers0spp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4jsxqzpuaztrnwngzn3kdzw5hydlzf03qdgm2hdq27cqv3agm2awhz5se903vruatfhq77w3ls4evs3ch9zw97j25emudupq63nyw24cg27h2rspzma2k0","description":"1 cup coffee","paymentHash":"0001020304050607080900010203040506070809000102030405060708090102","expiry":60,"amount":250000000,"features":{"activated":{},"unknown":[]},"routingInfo":[]},"paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","paymentType":"Standard","createdAt":{"iso":"1970-01-01T00:00:00.042Z","unix":0},"status":{"type":"received","virtualAmount":42,"realAmount":42,"receivedAt":{"iso":"2021-10-05T13:12:23.777Z","unix":1633439543}}} \ No newline at end of file +{"invoice":{"prefix":"lnbc","timestamp":1496314658,"nodeId":"03779dc8b593b74509fab7c8accebc7a9b91d85d9df456d5b885464a34e5751d52","serialized":"lnbc2500u1pvjluezsp5cssgls5lpvunj7zallxsn3v8g3f9wqfs75hsdmkrtxwgkafers0spp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdq5xysxxatsyp3k7enxv4jsxqzpuaztrnwngzn3kdzw5hydlzf03qdgm2hdq27cqv3agm2awhz5se903vruatfhq77w3ls4evs3ch9zw97j25emudupq63nyw24cg27h2rspzma2k0","description":"1 cup coffee","paymentHash":"0001020304050607080900010203040506070809000102030405060708090102","expiry":60,"amount":250000000,"features":{"activated":{},"unknown":[]},"routingInfo":[]},"paymentPreimage":"0100000000000000000000000000000000000000000000000000000000000000","paymentType":"Standard","createdAt":{"iso":"1970-01-01T00:00:00.042Z","unix":0},"status":{"type":"received","amount":42,"receivedAt":{"iso":"2021-10-05T13:12:23.777Z","unix":1633439543}}} \ No newline at end of file diff --git a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala index d191e462a1..348d4b7d07 100644 --- a/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala +++ b/eclair-node/src/test/scala/fr/acinq/eclair/api/ApiServiceSpec.scala @@ -908,7 +908,7 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM val defaultPayment = IncomingStandardPayment(Bolt11Invoice.fromString(invoice).get, ByteVector32.One, PaymentType.Standard, 42 unixms, IncomingPaymentStatus.Pending) val eclair = mock[Eclair] val received = randomBytes32() - eclair.receivedInfo(received)(any) returns Future.successful(Some(defaultPayment.copy(status = IncomingPaymentStatus.Received(42 msat, 42 msat, TimestampMilli(1633439543777L))))) + eclair.receivedInfo(received)(any) returns Future.successful(Some(defaultPayment.copy(status = IncomingPaymentStatus.Received(42 msat, TimestampMilli(1633439543777L))))) val mockService = new MockService(eclair) Post("/getreceivedinfo", FormData("paymentHash" -> received.toHex).toEntity) ~> @@ -1198,8 +1198,8 @@ class ApiServiceSpec extends AnyFunSuite with ScalatestRouteTest with IdiomaticM system.eventStream.publish(ptrel) wsClient.expectMessage(expectedSerializedPtrel) - val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, 21 msat, ByteVector32.Zeroes, TimestampMilli(1553784963659L)))) - val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"virtualAmount":21,"realAmount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":{"iso":"2019-03-28T14:56:03.659Z","unix":1553784963}}]}""" + val precv = PaymentReceived(ByteVector32.Zeroes, Seq(PaymentReceived.PartialPayment(21 msat, ByteVector32.Zeroes, TimestampMilli(1553784963659L)))) + val expectedSerializedPrecv = """{"type":"payment-received","paymentHash":"0000000000000000000000000000000000000000000000000000000000000000","parts":[{"amount":21,"fromChannelId":"0000000000000000000000000000000000000000000000000000000000000000","timestamp":{"iso":"2019-03-28T14:56:03.659Z","unix":1553784963}}]}""" assert(serialization.write(precv) == expectedSerializedPrecv) system.eventStream.publish(precv) wsClient.expectMessage(expectedSerializedPrecv) From 874a19d4485a829838780ea0b9642d913f2e4fe0 Mon Sep 17 00:00:00 2001 From: t-bast Date: Mon, 17 Feb 2025 15:48:10 +0100 Subject: [PATCH 4/6] Clarify recipient blinded path fee deduction We clarify that the previously named "hiddenFee" is actually the blinded path fee that is being paid by the recipient (by deducing it from the amount it plans on receiving). We thus rename variables and add docs. We also revert the validation flow: instead of generating a payment part for that fee based on the difference between the paid amount and the onion value, we generate that payment part based on the blinded path fee we recorded in the `path_id` and validate that it matches what the payer is sending. We revert some of the changes to the `MultiPartHandler` which aren't necessary and do some clean-up of some unused code paths. --- .../eclair/payment/offer/OfferManager.scala | 78 +++++++++++-------- .../payment/offer/OfferPaymentMetadata.scala | 21 +++-- .../payment/receive/MultiPartHandler.scala | 78 ++++++++----------- .../payment/receive/MultiPartPaymentFSM.scala | 2 +- .../payment/send/PaymentLifecycle.scala | 3 - .../scala/fr/acinq/eclair/router/Router.scala | 16 ++-- 6 files changed, 98 insertions(+), 100 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala index b31b4849a2..6ff2069498 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala @@ -24,17 +24,17 @@ import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.db.{IncomingBlindedPayment, IncomingPaymentStatus, PaymentType} import fr.acinq.eclair.message.{OnionMessages, Postman} -import fr.acinq.eclair.payment.MinimalBolt12Invoice import fr.acinq.eclair.payment.offer.OfferPaymentMetadata.MinimalInvoiceData import fr.acinq.eclair.payment.receive.MultiPartHandler import fr.acinq.eclair.payment.receive.MultiPartHandler.{CreateInvoiceActor, ReceivingRoute} import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.payment.{Bolt12Invoice, MinimalBolt12Invoice} import fr.acinq.eclair.router.BlindedRouteCreation.aggregatePaymentInfo import fr.acinq.eclair.router.Router import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv, Offer} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiryDelta, Logs, MilliSatoshi, NodeParams, TimestampMilli, TimestampSecond, nodeFee, randomBytes32} +import fr.acinq.eclair.{CltvExpiryDelta, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, TimestampSecond, nodeFee, randomBytes32} import scodec.bits.ByteVector import scala.concurrent.duration.FiniteDuration @@ -65,7 +65,7 @@ object OfferManager { case class RequestInvoice(messagePayload: MessageOnion.InvoiceRequestPayload, blindedKey: PrivateKey, postman: ActorRef[Postman.SendMessage]) extends Command - case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded, realAmount: MilliSatoshi) extends Command + case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded) extends Command /** * Offer handlers must be implemented in separate plugins and respond to these two `HandlerCommand`. @@ -117,16 +117,14 @@ object OfferManager { case _ => context.log.debug("offer {} is not registered or invoice request is invalid", messagePayload.invoiceRequest.offer.offerId) } Behaviors.same - case ReceivePayment(replyTo, paymentHash, payload, realAmount) => + case ReceivePayment(replyTo, paymentHash, payload) => MinimalInvoiceData.decode(payload.pathId) match { case Some(signed) => registeredOffers.get(signed.offerId) match { case Some(RegisteredOffer(offer, _, _, handler)) => MinimalInvoiceData.verify(nodeParams.nodeId, signed) match { - case Some(metadata) if realAmount + nodeFee(metadata.hiddenFees, realAmount) < payload.amount => - replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"incorrect amount received for offer ${signed.offerId.toHex}: realAmount=$realAmount, hiddenFees=${metadata.hiddenFees}, virtualAmount=${payload.amount}") case Some(metadata) if Crypto.sha256(metadata.preimage) == paymentHash => - val child = context.spawnAnonymous(PaymentActor(nodeParams, replyTo, offer, metadata, paymentTimeout)) + val child = context.spawnAnonymous(PaymentActor(nodeParams, replyTo, offer, metadata, payload, paymentTimeout)) handler ! HandlePayment(child, signed.offerId, metadata.pluginData_opt) case Some(_) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"preimage does not match payment hash for offer ${signed.offerId.toHex}") case None => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"invalid signature for metadata for offer ${signed.offerId.toHex}") @@ -151,26 +149,43 @@ object OfferManager { * * @param amount Amount for the invoice (must be the same as the invoice request if it contained an amount). * @param routes Routes to use for the payment. - * @param hideFees If true, fees for the blinded route will be hidden to the payer and paid by the recipient. * @param pluginData_opt Some data for the handler by the handler. It will be sent to the handler when a payment is attempted. * @param additionalTlvs additional TLVs to add to the invoice. * @param customTlvs custom TLVs to add to the invoice. */ case class ApproveRequest(amount: MilliSatoshi, routes: Seq[Route], - hideFees: Boolean, pluginData_opt: Option[ByteVector] = None, additionalTlvs: Set[InvoiceTlv] = Set.empty, customTlvs: Set[GenericTlv] = Set.empty) extends Command - case class Route(hops: Seq[Router.ChannelHop], maxFinalExpiryDelta: CltvExpiryDelta, shortChannelIdDir_opt: Option[ShortChannelIdDir] = None) + /** + * @param recipientPaysFees If true, fees for the blinded route will be hidden to the payer and paid by the recipient. + */ + case class Route(hops: Seq[Router.ChannelHop], recipientPaysFees: Boolean, maxFinalExpiryDelta: CltvExpiryDelta, shortChannelIdDir_opt: Option[ShortChannelIdDir] = None) { + def finalize(nodePriv: PrivateKey, preimage: ByteVector32, amount: MilliSatoshi, invoiceRequest: InvoiceRequest, minFinalExpiryDelta: CltvExpiryDelta, pluginData_opt: Option[ByteVector]): ReceivingRoute = { + val (paymentInfo, metadata) = if (recipientPaysFees) { + val realPaymentInfo = aggregatePaymentInfo(amount, hops, minFinalExpiryDelta) + val recipientFees = RelayFees(realPaymentInfo.feeBase, realPaymentInfo.feeProportionalMillionths) + val metadata = MinimalInvoiceData(preimage, invoiceRequest.payerId, TimestampSecond.now(), invoiceRequest.quantity, amount, recipientFees, pluginData_opt) + val paymentInfo = realPaymentInfo.copy(feeBase = 0 msat, feeProportionalMillionths = 0) + (paymentInfo, metadata) + } else { + val paymentInfo = aggregatePaymentInfo(amount, hops, minFinalExpiryDelta) + val metadata = MinimalInvoiceData(preimage, invoiceRequest.payerId, TimestampSecond.now(), invoiceRequest.quantity, amount, RelayFees.zero, pluginData_opt) + (paymentInfo, metadata) + } + val pathId = MinimalInvoiceData.encode(nodePriv, invoiceRequest.offer.offerId, metadata) + ReceivingRoute(hops, pathId, maxFinalExpiryDelta, paymentInfo, shortChannelIdDir_opt) + } + } /** * Sent by the offer handler to reject the request. For instance because stock has been exhausted. */ case class RejectRequest(message: String) extends Command - private case class WrappedInvoiceResponse(response: CreateInvoiceActor.Bolt12InvoiceResponse) extends Command + private case class WrappedInvoiceResponse(invoice: Bolt12Invoice) extends Command private case class WrappedOnionMessageResponse(response: Postman.OnionMessageResponse) extends Command @@ -203,17 +218,10 @@ object OfferManager { context.log.debug("offer handler rejected invoice request: {}", error) postman ! Postman.SendMessage(OfferTypes.BlindedPath(pathToSender), OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.InvoiceError(TlvStream(OfferTypes.Error(error)))), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) waitForSent() - case ApproveRequest(amount, routes, hideFees, pluginData_opt, additionalTlvs, customTlvs) => + case ApproveRequest(amount, routes, pluginData_opt, additionalTlvs, customTlvs) => val preimage = randomBytes32() - val receivingRoutes = routes.map(route => { - val paymentInfo = aggregatePaymentInfo(amount, route.hops, nodeParams.channelConf.minFinalExpiryDelta) - val hiddenFees = if (hideFees) RelayFees(paymentInfo.feeBase, paymentInfo.feeProportionalMillionths) else RelayFees.zero - val metadata = MinimalInvoiceData(preimage, invoiceRequest.payerId, TimestampSecond.now(), invoiceRequest.quantity, amount, hiddenFees, pluginData_opt) - val pathId = MinimalInvoiceData.encode(nodeParams.privateKey, invoiceRequest.offer.offerId, metadata) - val paymentInfo1 = if (hideFees) paymentInfo.copy(feeBase = MilliSatoshi(0), feeProportionalMillionths = 0) else paymentInfo - ReceivingRoute(route.hops, pathId, route.maxFinalExpiryDelta, paymentInfo1, route.shortChannelIdDir_opt) - }) - val receivePayment = MultiPartHandler.ReceiveOfferPayment(context.messageAdapter[CreateInvoiceActor.Bolt12InvoiceResponse](WrappedInvoiceResponse), nodeKey, invoiceRequest, receivingRoutes, preimage, additionalTlvs, customTlvs) + val receivingRoutes = routes.map(_.finalize(nodeParams.privateKey, preimage, amount, invoiceRequest, nodeParams.channelConf.minFinalExpiryDelta, pluginData_opt)) + val receivePayment = MultiPartHandler.ReceiveOfferPayment(context.messageAdapter[Bolt12Invoice](WrappedInvoiceResponse), nodeKey, invoiceRequest, receivingRoutes, preimage, additionalTlvs, customTlvs) val child = context.spawnAnonymous(CreateInvoiceActor(nodeParams)) child ! CreateInvoiceActor.CreateBolt12Invoice(receivePayment) waitForInvoice() @@ -222,16 +230,10 @@ object OfferManager { private def waitForInvoice(): Behavior[Command] = { Behaviors.receiveMessagePartial { - case WrappedInvoiceResponse(invoiceResponse) => - invoiceResponse match { - case CreateInvoiceActor.InvoiceCreated(invoice) => - context.log.debug("invoice created for offerId={} invoice={}", invoice.invoiceRequest.offer.offerId, invoice.toString) - postman ! Postman.SendMessage(OfferTypes.BlindedPath(pathToSender), OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) - waitForSent() - case f: CreateInvoiceActor.InvoiceCreationFailed => - context.log.debug("invoice creation failed: {}", f.message) - Behaviors.stopped - } + case WrappedInvoiceResponse(invoice) => + context.log.debug("invoice created for offerId={} invoice={}", invoice.invoiceRequest.offer.offerId, invoice.toString) + postman ! Postman.SendMessage(OfferTypes.BlindedPath(pathToSender), OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) + waitForSent() } } @@ -248,7 +250,8 @@ object OfferManager { sealed trait Command /** - * Sent by the offer handler. Causes the creation of a dummy invoice that matches as best as possible the actual invoice for this payment (since the actual invoice is not stored) and will be used in the payment handler. + * Sent by the offer handler. Causes the creation of a dummy invoice that matches as best as possible the actual + * invoice for this payment (since the actual invoice is not stored) and will be used in the payment handler. * * @param additionalTlvs additional TLVs to add to the dummy invoice. Should be the same as what was used for the actual invoice. * @param customTlvs custom TLVs to add to the dummy invoice. Should be the same as what was used for the actual invoice. @@ -260,14 +263,21 @@ object OfferManager { */ case class RejectPayment(reason: String) extends Command - def apply(nodeParams: NodeParams, replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], offer: Offer, metadata: MinimalInvoiceData, timeout: FiniteDuration): Behavior[Command] = { + def apply(nodeParams: NodeParams, + replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], + offer: Offer, + metadata: MinimalInvoiceData, + payload: FinalPayload.Blinded, + timeout: FiniteDuration): Behavior[Command] = { Behaviors.setup { context => context.scheduleOnce(timeout, context.self, RejectPayment("plugin timeout")) Behaviors.receiveMessage { case AcceptPayment(additionalTlvs, customTlvs) => val minimalInvoice = MinimalBolt12Invoice(offer, nodeParams.chainHash, metadata.amount, metadata.quantity, Crypto.sha256(metadata.preimage), metadata.payerKey, metadata.createdAt, additionalTlvs, customTlvs) val incomingPayment = IncomingBlindedPayment(minimalInvoice, metadata.preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment) + // We may be deducing some of the blinded path fees from the received amount. + val recipientPathFees = nodeFee(metadata.recipientPathFees, payload.amount) + replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment, recipientPathFees) Behaviors.stopped case RejectPayment(reason) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala index fa60b83899..e8ba0d5caa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferPaymentMetadata.scala @@ -34,23 +34,28 @@ import scodec.bits.ByteVector * We instead include payment metadata in the blinded route's path_id field which lets us generate a minimal invoice * once we receive the payment, that is similar to the one that was actually sent to the payer. It will not be exactly * the same (notably the blinding route will be missing) but it will contain what we need to fulfill the payment. + * + * Since the recipient is selecting the blinded route to themselves, it may be unfair to the payer if the blinded route + * requires a high routing fee. The recipient can instead opt into having some of those routing fees deducted from the + * amount they receive by setting [[MinimalInvoiceData.recipientPathFees]] to a non-zero value. */ object OfferPaymentMetadata { /** - * @param preimage preimage for that payment. - * @param payerKey payer key (from their invoice request). - * @param createdAt creation time of the invoice. - * @param quantity quantity of items requested. - * @param amount amount that must be paid. - * @param pluginData_opt optional data from the offer plugin. + * @param preimage preimage for that payment. + * @param payerKey payer key (from their invoice request). + * @param createdAt creation time of the invoice. + * @param quantity quantity of items requested. + * @param amount amount that must be paid. + * @param recipientPathFees the payment recipient may choose to pay part of the blinded path relay fees themselves. + * @param pluginData_opt optional data from the offer plugin. */ case class MinimalInvoiceData(preimage: ByteVector32, payerKey: PublicKey, createdAt: TimestampSecond, quantity: Long, amount: MilliSatoshi, - hiddenFees: RelayFees, + recipientPathFees: RelayFees, pluginData_opt: Option[ByteVector]) /** @@ -71,7 +76,7 @@ object OfferPaymentMetadata { ("createdAt" | timestampSecond) :: ("quantity" | uint64overflow) :: ("amount" | millisatoshi) :: - ("hiddenFees" | (millisatoshi :: int64).as[RelayFees]) :: + ("recipientPathFees" | (millisatoshi :: int64).as[RelayFees]) :: ("pluginData" | optional(bitsRemaining, bytes))).as[MinimalInvoiceData] private val signedDataCodec: Codec[SignedMinimalInvoiceData] = diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 6ab1b38d98..2a2bca0f10 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -18,13 +18,11 @@ package fr.acinq.eclair.payment.receive import akka.actor.Actor.Receive import akka.actor.typed.Behavior -import akka.actor.typed.scaladsl.AskPattern.Askable import akka.actor.typed.scaladsl.Behaviors -import akka.actor.typed.scaladsl.adapter.{ClassicActorContextOps, ClassicActorRefOps} +import akka.actor.typed.scaladsl.adapter.ClassicActorContextOps import akka.actor.{ActorContext, ActorRef, PoisonPill, typed} import akka.event.{DiagnosticLoggingAdapter, LoggingAdapter} -import akka.util.Timeout -import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} +import fr.acinq.bitcoin.scalacompat.Crypto.PrivateKey import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto} import fr.acinq.eclair.EncodedNodeId.ShortChannelIdDir import fr.acinq.eclair.Logs.LogCategory @@ -35,20 +33,14 @@ import fr.acinq.eclair.payment.Bolt11Invoice.ExtraHop import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager -import fr.acinq.eclair.payment.relay.Relayer.RelayFees -import fr.acinq.eclair.router.BlindedRouteCreation.{aggregatePaymentInfo, createBlindedRouteFromHops} +import fr.acinq.eclair.router.BlindedRouteCreation.createBlindedRouteFromHops import fr.acinq.eclair.router.Router -import fr.acinq.eclair.router.Router.{ChannelHop, HopRelayParams, PaymentRouteResponse} import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, InvoiceTlv} import fr.acinq.eclair.wire.protocol.PaymentOnion.FinalPayload import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{Bolt11Feature, CltvExpiryDelta, FeatureSupport, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, ShortChannelId, TimestampMilli, nodeFee, randomBytes32} +import fr.acinq.eclair.{Bolt11Feature, CltvExpiryDelta, FeatureSupport, Features, Logs, MilliSatoshi, MilliSatoshiLong, NodeParams, TimestampMilli, randomBytes32} import scodec.bits.{ByteVector, HexStringSyntax} -import scala.concurrent.duration.DurationInt -import scala.concurrent.{ExecutionContextExecutor, Future} -import scala.util.{Failure, Success, Try} - /** * Simple payment handler that generates invoices and fulfills incoming htlcs. * @@ -62,21 +54,13 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP private var pendingPayments: Map[ByteVector32, (IncomingPayment, ActorRef)] = Map.empty private def addHtlcPart(ctx: ActorContext, add: UpdateAddHtlc, payload: FinalPayload, payment: IncomingPayment): Unit = { - val handler = pendingPayments.get(add.paymentHash) match { + pendingPayments.get(add.paymentHash) match { case Some((_, handler)) => - handler + handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) case None => val handler = ctx.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, payload.totalAmount, ctx.self)) + handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) pendingPayments = pendingPayments + (add.paymentHash -> (payment, handler)) - handler - } - handler ! MultiPartPaymentFSM.HtlcPart(payload.totalAmount, add) - payload match { - case payload: FinalPayload.Blinded if payload.amount - add.amountMsat > 0.msat => - val hiddenFee = payload.amount - add.amountMsat - handler ! MultiPartPaymentFSM.HiddenFeePart(add.paymentHash, hiddenFee, payload.totalAmount) - case _: FinalPayload.Blinded => () - case _: FinalPayload.Standard => () } } @@ -141,15 +125,20 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP } } - case ProcessBlindedPacket(add, payload, payment) if doHandle(add.paymentHash) => + case ProcessBlindedPacket(add, payload, payment, recipientPathFees) if doHandle(add.paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(add.paymentHash))) { - validateBlindedPayment(nodeParams, add, payload, payment) match { + validateBlindedPayment(nodeParams, add, payload, payment, recipientPathFees) match { case Some(cmdFail) => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, Tags.FailureType(cmdFail)).increment() PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) case None => - log.debug("received payment for virtualAmount={} realAmount={} totalAmount={}", payload.amount, add.amountMsat, payload.totalAmount) + log.debug("received payment for amount={} recipientPathFees={} totalAmount={}", add.amountMsat, recipientPathFees, payload.totalAmount) addHtlcPart(ctx, add, payload, payment) + if (recipientPathFees > 0.msat) { + // We've opted into deducing the blinded paths fees from the amount we receive for this payment. + // We add an artificial payment part for those fees, otherwise we will never reach the total amount. + pendingPayments.get(add.paymentHash).foreach(_._2 ! MultiPartPaymentFSM.RecipientBlindedPathFeePart(add.paymentHash, recipientPathFees, payload.totalAmount)) + } } } @@ -185,7 +174,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP failure match { case Some(failure) => p match { case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, FailureReason.LocalFailure(failure), commit = true)) - case _: MultiPartPaymentFSM.HiddenFeePart => () + case _: MultiPartPaymentFSM.RecipientBlindedPathFeePart => () } case None => p match { // NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more @@ -200,7 +189,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, cmdFail) } }) - case _: MultiPartPaymentFSM.HiddenFeePart => () + case _: MultiPartPaymentFSM.RecipientBlindedPathFeePart => () } } } @@ -210,7 +199,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP log.debug("fulfilling payment for amount={}", parts.map(_.amount).sum) val received = PaymentReceived(paymentHash, parts.flatMap { case p: MultiPartPaymentFSM.HtlcPart => Some(PaymentReceived.PartialPayment(p.amount, p.htlc.channelId)) - case _: MultiPartPaymentFSM.HiddenFeePart => None + case _: MultiPartPaymentFSM.RecipientBlindedPathFeePart => None }) val recordedInDb = payment match { // Incoming offer payments are not stored in the database until they have been paid. @@ -247,9 +236,9 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP object MultiPartHandler { // @formatter:off - case class ProcessPacket(add: UpdateAddHtlc, payload: FinalPayload.Standard, payment_opt: Option[IncomingStandardPayment]) - case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment) - case class RejectPacket(add: UpdateAddHtlc, failure: FailureMessage) + private case class ProcessPacket(add: UpdateAddHtlc, payload: FinalPayload.Standard, payment_opt: Option[IncomingStandardPayment]) + private case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment, recipientPathFees: MilliSatoshi) + private case class RejectPacket(add: UpdateAddHtlc, failure: FailureMessage) case class DoFulfill(payment: IncomingPayment, success: MultiPartPaymentFSM.MultiPartPaymentSucceeded) case object GetPendingPayments @@ -295,7 +284,7 @@ object MultiPartHandler { * @param routes routes that must be blinded and provided in the invoice. * @param paymentPreimage payment preimage. */ - case class ReceiveOfferPayment(replyTo: typed.ActorRef[CreateInvoiceActor.Bolt12InvoiceResponse], + case class ReceiveOfferPayment(replyTo: typed.ActorRef[Bolt12Invoice], nodeKey: PrivateKey, invoiceRequest: InvoiceRequest, routes: Seq[ReceivingRoute], @@ -313,12 +302,6 @@ object MultiPartHandler { sealed trait Command case class CreateBolt11Invoice(receivePayment: ReceiveStandardPayment) extends Command case class CreateBolt12Invoice(receivePayment: ReceiveOfferPayment) extends Command - private case class WrappedInvoiceResult(invoice: Try[Bolt12Invoice]) extends Command - - sealed trait Bolt12InvoiceResponse - case class InvoiceCreated(invoice: Bolt12Invoice) extends Bolt12InvoiceResponse - sealed trait InvoiceCreationFailed extends Bolt12InvoiceResponse { def message: String } - case class BlindedRouteCreationFailed(message: String) extends InvoiceCreationFailed // @formatter:on def apply(nodeParams: NodeParams): Behavior[Command] = { @@ -363,7 +346,7 @@ object MultiPartHandler { val invoiceFeatures = nodeParams.features.bolt12Features() val invoice = Bolt12Invoice(r.invoiceRequest, r.paymentPreimage, r.nodeKey, nodeParams.invoiceExpiry, invoiceFeatures, paths, r.additionalTlvs, r.customTlvs) context.log.debug("generated invoice={} for offer={}", invoice.toString, r.invoiceRequest.offer.toString) - r.replyTo ! InvoiceCreated(invoice) + r.replyTo ! invoice Behaviors.stopped } } @@ -375,7 +358,7 @@ object MultiPartHandler { // @formatter:off sealed trait Command case class GetIncomingPayment(replyTo: ActorRef) extends Command - case class ProcessPayment(payment: IncomingBlindedPayment) extends Command + case class ProcessPayment(payment: IncomingBlindedPayment, recipientPathFees: MilliSatoshi) extends Command case class RejectPayment(reason: String) extends Command // @formatter:on @@ -395,7 +378,7 @@ object MultiPartHandler { } Behaviors.stopped case payload: FinalPayload.Blinded => - offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload, packet.add.amountMsat) + offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload) waitForPayment(context, nodeParams, replyTo, packet.add, payload) } } @@ -405,8 +388,8 @@ object MultiPartHandler { private def waitForPayment(context: typed.scaladsl.ActorContext[Command], nodeParams: NodeParams, replyTo: ActorRef, add: UpdateAddHtlc, payload: FinalPayload.Blinded): Behavior[Command] = { Behaviors.receiveMessagePartial { - case ProcessPayment(payment) => - replyTo ! ProcessBlindedPacket(add, payload, payment) + case ProcessPayment(payment, recipientPathFees) => + replyTo ! ProcessBlindedPacket(add, payload, payment, recipientPathFees) Behaviors.stopped case RejectPayment(reason) => context.log.info("rejecting blinded htlc #{} from channel {}: {}", add.id, add.channelId, reason) @@ -490,11 +473,14 @@ object MultiPartHandler { if (commonOk && secretOk) None else Some(cmdFail) } - private def validateBlindedPayment(nodeParams: NodeParams, add: UpdateAddHtlc, payload: FinalPayload.Blinded, record: IncomingBlindedPayment)(implicit log: LoggingAdapter): Option[CMD_FAIL_HTLC] = { + private def validateBlindedPayment(nodeParams: NodeParams, add: UpdateAddHtlc, payload: FinalPayload.Blinded, record: IncomingBlindedPayment, recipientPathFees: MilliSatoshi)(implicit log: LoggingAdapter): Option[CMD_FAIL_HTLC] = { // We send the same error regardless of the failure to avoid probing attacks. val cmdFail = CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(payload.totalAmount, nodeParams.currentBlockHeight)), commit = true) val commonOk = validateCommon(nodeParams, add, payload, record) - if (commonOk) None else Some(cmdFail) + // The payer isn't aware of the blinded path fees if we decided to hide them. The HTLC amount will thus be smaller + // than the onion amount, but should match when re-adding the blinded path fees. + val pathFeesOk = add.amountMsat + recipientPathFees >= payload.amount + if (commonOk && pathFeesOk) None else Some(cmdFail) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala index dd1110f6f4..6bccb65f41 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartPaymentFSM.scala @@ -139,7 +139,7 @@ object MultiPartPaymentFSM { override def amount: MilliSatoshi = htlc.amountMsat } /** The fee of a blinded route paid by the receiver (us). */ - case class HiddenFeePart(paymentHash: ByteVector32, amount: MilliSatoshi, totalAmount: MilliSatoshi) extends PaymentPart + case class RecipientBlindedPathFeePart(paymentHash: ByteVector32, amount: MilliSatoshi, totalAmount: MilliSatoshi) extends PaymentPart /** We successfully received all parts of the payment. */ case class MultiPartPaymentSucceeded(paymentHash: ByteVector32, parts: Queue[PaymentPart]) /** We aborted the payment because of an inconsistency in the payment set or because we didn't receive the total amount in reasonable time. */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala index e570551bc7..ab37eeaff0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/PaymentLifecycle.scala @@ -338,9 +338,6 @@ class PaymentLifecycle(nodeParams: NodeParams, cfg: SendPaymentConfig, router: A // this is most likely a liquidity issue, we remove this edge for our next payment attempt data.recipient.extraEdges.filterNot(edge => edge.sourceNodeId == nodeId && edge.targetNodeId == hop.nextNodeId) } - case _: HopRelayParams.Dummy => - log.error("received an update for a dummy hop, this should never happen") - data.recipient.extraEdges } case None => log.error(s"couldn't find node=$nodeId in the route, this should never happen") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index d95a1d6d22..dcc9e55114 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -462,7 +462,7 @@ object Router { for { scid <- aliases.remoteAlias_opt update <- remoteUpdate_opt - } yield (Bolt11Invoice.ExtraHop(remoteNodeId, scid, update.feeBaseMsat, update.feeProportionalMillionths, update.cltvExpiryDelta)) + } yield Bolt11Invoice.ExtraHop(remoteNodeId, scid, update.feeBaseMsat, update.feeProportionalMillionths, update.cltvExpiryDelta) } } // @formatter:on @@ -512,11 +512,6 @@ object Router { override val htlcMaximum_opt = extraHop.htlcMaximum_opt } - case class Dummy(relayFees: Relayer.RelayFees, cltvExpiryDelta: CltvExpiryDelta) extends HopRelayParams { - override val htlcMinimum: MilliSatoshi = 1 msat - override val htlcMaximum_opt: Option[MilliSatoshi] = None - } - def areSame(a: HopRelayParams, b: HopRelayParams, ignoreHtlcSize: Boolean = false): Boolean = a.cltvExpiryDelta == b.cltvExpiryDelta && a.relayFees == b.relayFees && @@ -539,8 +534,11 @@ object Router { } object ChannelHop { - def dummy(nodeId: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionths: Long, cltvExpiryDelta: CltvExpiryDelta): ChannelHop = - ChannelHop(ShortChannelId.toSelf, nodeId, nodeId, HopRelayParams.Dummy(Relayer.RelayFees(feeBase, feeProportionalMillionths), cltvExpiryDelta)) + /** Create a dummy channel hop, used for example when padding blinded routes to a fixed length. */ + def dummy(nodeId: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionths: Long, cltvExpiryDelta: CltvExpiryDelta): ChannelHop = { + val dummyEdge = ExtraEdge(nodeId, nodeId, ShortChannelId.toSelf, feeBase, feeProportionalMillionths, cltvExpiryDelta, 1 msat, None) + ChannelHop(ShortChannelId.toSelf, nodeId, nodeId, HopRelayParams.FromHint(dummyEdge)) + } } sealed trait FinalHop extends Hop @@ -678,11 +676,13 @@ object Router { } } + // @formatter:off sealed trait PaymentRouteResponse case class RouteResponse(routes: Seq[Route]) extends PaymentRouteResponse { require(routes.nonEmpty, "routes cannot be empty") } case class PaymentRouteNotFound(error: Throwable) extends PaymentRouteResponse + // @formatter:on // @formatter:off /** A pre-defined route chosen outside of eclair (e.g. manually by a user to do some re-balancing). */ From 59e934be0bc05880cb9b28f0af0d5801dea975f0 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Mon, 24 Feb 2025 15:18:59 +0100 Subject: [PATCH 5/6] Fix computation of recipient path fees --- .../eclair/payment/offer/OfferManager.scala | 12 +-- .../payment/receive/MultiPartHandler.scala | 19 ++-- .../integration/PaymentIntegrationSpec.scala | 30 +++--- .../basic/payment/OfferPaymentSpec.scala | 92 +++++++++---------- .../eclair/payment/MultiPartHandlerSpec.scala | 18 ++-- .../payment/offer/OfferManagerSpec.scala | 46 ++++------ 6 files changed, 104 insertions(+), 113 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala index 6ff2069498..f15c1aa035 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala @@ -65,7 +65,7 @@ object OfferManager { case class RequestInvoice(messagePayload: MessageOnion.InvoiceRequestPayload, blindedKey: PrivateKey, postman: ActorRef[Postman.SendMessage]) extends Command - case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded) extends Command + case class ReceivePayment(replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], paymentHash: ByteVector32, payload: FinalPayload.Blinded, amountReceived: MilliSatoshi) extends Command /** * Offer handlers must be implemented in separate plugins and respond to these two `HandlerCommand`. @@ -117,14 +117,14 @@ object OfferManager { case _ => context.log.debug("offer {} is not registered or invoice request is invalid", messagePayload.invoiceRequest.offer.offerId) } Behaviors.same - case ReceivePayment(replyTo, paymentHash, payload) => + case ReceivePayment(replyTo, paymentHash, payload, amountReceived) => MinimalInvoiceData.decode(payload.pathId) match { case Some(signed) => registeredOffers.get(signed.offerId) match { case Some(RegisteredOffer(offer, _, _, handler)) => MinimalInvoiceData.verify(nodeParams.nodeId, signed) match { case Some(metadata) if Crypto.sha256(metadata.preimage) == paymentHash => - val child = context.spawnAnonymous(PaymentActor(nodeParams, replyTo, offer, metadata, payload, paymentTimeout)) + val child = context.spawnAnonymous(PaymentActor(nodeParams, replyTo, offer, metadata, amountReceived, paymentTimeout)) handler ! HandlePayment(child, signed.offerId, metadata.pluginData_opt) case Some(_) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"preimage does not match payment hash for offer ${signed.offerId.toHex}") case None => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(s"invalid signature for metadata for offer ${signed.offerId.toHex}") @@ -267,7 +267,7 @@ object OfferManager { replyTo: ActorRef[MultiPartHandler.GetIncomingPaymentActor.Command], offer: Offer, metadata: MinimalInvoiceData, - payload: FinalPayload.Blinded, + amount: MilliSatoshi, timeout: FiniteDuration): Behavior[Command] = { Behaviors.setup { context => context.scheduleOnce(timeout, context.self, RejectPayment("plugin timeout")) @@ -276,8 +276,8 @@ object OfferManager { val minimalInvoice = MinimalBolt12Invoice(offer, nodeParams.chainHash, metadata.amount, metadata.quantity, Crypto.sha256(metadata.preimage), metadata.payerKey, metadata.createdAt, additionalTlvs, customTlvs) val incomingPayment = IncomingBlindedPayment(minimalInvoice, metadata.preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) // We may be deducing some of the blinded path fees from the received amount. - val recipientPathFees = nodeFee(metadata.recipientPathFees, payload.amount) - replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment, recipientPathFees) + val maxRecipientPathFees = nodeFee(metadata.recipientPathFees, amount) + replyTo ! MultiPartHandler.GetIncomingPaymentActor.ProcessPayment(incomingPayment, maxRecipientPathFees) Behaviors.stopped case RejectPayment(reason) => replyTo ! MultiPartHandler.GetIncomingPaymentActor.RejectPayment(reason) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 2a2bca0f10..3d7722037b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -125,13 +125,14 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP } } - case ProcessBlindedPacket(add, payload, payment, recipientPathFees) if doHandle(add.paymentHash) => + case ProcessBlindedPacket(add, payload, payment, maxRecipientPathFees) if doHandle(add.paymentHash) => Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(add.paymentHash))) { - validateBlindedPayment(nodeParams, add, payload, payment, recipientPathFees) match { + validateBlindedPayment(nodeParams, add, payload, payment, maxRecipientPathFees) match { case Some(cmdFail) => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, Tags.FailureType(cmdFail)).increment() PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) case None => + val recipientPathFees = payload.amount - add.amountMsat log.debug("received payment for amount={} recipientPathFees={} totalAmount={}", add.amountMsat, recipientPathFees, payload.totalAmount) addHtlcPart(ctx, add, payload, payment) if (recipientPathFees > 0.msat) { @@ -237,7 +238,7 @@ object MultiPartHandler { // @formatter:off private case class ProcessPacket(add: UpdateAddHtlc, payload: FinalPayload.Standard, payment_opt: Option[IncomingStandardPayment]) - private case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment, recipientPathFees: MilliSatoshi) + private case class ProcessBlindedPacket(add: UpdateAddHtlc, payload: FinalPayload.Blinded, payment: IncomingBlindedPayment, maxRecipientPathFees: MilliSatoshi) private case class RejectPacket(add: UpdateAddHtlc, failure: FailureMessage) case class DoFulfill(payment: IncomingPayment, success: MultiPartPaymentFSM.MultiPartPaymentSucceeded) @@ -358,7 +359,7 @@ object MultiPartHandler { // @formatter:off sealed trait Command case class GetIncomingPayment(replyTo: ActorRef) extends Command - case class ProcessPayment(payment: IncomingBlindedPayment, recipientPathFees: MilliSatoshi) extends Command + case class ProcessPayment(payment: IncomingBlindedPayment, maxRecipientPathFees: MilliSatoshi) extends Command case class RejectPayment(reason: String) extends Command // @formatter:on @@ -378,7 +379,7 @@ object MultiPartHandler { } Behaviors.stopped case payload: FinalPayload.Blinded => - offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload) + offerManager ! OfferManager.ReceivePayment(context.self, packet.add.paymentHash, payload, packet.add.amountMsat) waitForPayment(context, nodeParams, replyTo, packet.add, payload) } } @@ -388,8 +389,8 @@ object MultiPartHandler { private def waitForPayment(context: typed.scaladsl.ActorContext[Command], nodeParams: NodeParams, replyTo: ActorRef, add: UpdateAddHtlc, payload: FinalPayload.Blinded): Behavior[Command] = { Behaviors.receiveMessagePartial { - case ProcessPayment(payment, recipientPathFees) => - replyTo ! ProcessBlindedPacket(add, payload, payment, recipientPathFees) + case ProcessPayment(payment, maxRecipientPathFees) => + replyTo ! ProcessBlindedPacket(add, payload, payment, maxRecipientPathFees) Behaviors.stopped case RejectPayment(reason) => context.log.info("rejecting blinded htlc #{} from channel {}: {}", add.id, add.channelId, reason) @@ -473,13 +474,13 @@ object MultiPartHandler { if (commonOk && secretOk) None else Some(cmdFail) } - private def validateBlindedPayment(nodeParams: NodeParams, add: UpdateAddHtlc, payload: FinalPayload.Blinded, record: IncomingBlindedPayment, recipientPathFees: MilliSatoshi)(implicit log: LoggingAdapter): Option[CMD_FAIL_HTLC] = { + private def validateBlindedPayment(nodeParams: NodeParams, add: UpdateAddHtlc, payload: FinalPayload.Blinded, record: IncomingBlindedPayment, maxRecipientPathFees: MilliSatoshi)(implicit log: LoggingAdapter): Option[CMD_FAIL_HTLC] = { // We send the same error regardless of the failure to avoid probing attacks. val cmdFail = CMD_FAIL_HTLC(add.id, FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(payload.totalAmount, nodeParams.currentBlockHeight)), commit = true) val commonOk = validateCommon(nodeParams, add, payload, record) // The payer isn't aware of the blinded path fees if we decided to hide them. The HTLC amount will thus be smaller // than the onion amount, but should match when re-adding the blinded path fees. - val pathFeesOk = add.amountMsat + recipientPathFees >= payload.amount + val pathFeesOk = payload.amount - add.amountMsat <= maxRecipientPathFees if (commonOk && pathFeesOk) None else Some(cmdFail) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala index 51fe1a78c7..436a2d5e75 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/PaymentIntegrationSpec.scala @@ -634,11 +634,11 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - OfferManager.InvoiceRequestActor.Route(route1.hops, CltvExpiryDelta(1000)), - OfferManager.InvoiceRequestActor.Route(route2.hops, CltvExpiryDelta(1000)), - OfferManager.InvoiceRequestActor.Route(route3.hops, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route1.hops, recipientPaysFees = false, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route2.hops, recipientPaysFees = false, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(route3.hops, recipientPaysFees = false, CltvExpiryDelta(1000)), ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"abcd")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"abcd")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -668,10 +668,10 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] // C uses a 0-hop blinded route and signs the invoice with its public nodeId. val receivingRoutes = Seq( - OfferManager.InvoiceRequestActor.Route(Nil, CltvExpiryDelta(1000)), - OfferManager.InvoiceRequestActor.Route(Nil, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(Nil, recipientPaysFees = false, CltvExpiryDelta(1000)), + OfferManager.InvoiceRequestActor.Route(Nil, recipientPaysFees = false, CltvExpiryDelta(1000)), ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"0123")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"0123")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -703,9 +703,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - OfferManager.InvoiceRequestActor.Route(Seq(ChannelHop.dummy(nodes("A").nodeParams.nodeId, 100 msat, 100, CltvExpiryDelta(48)), ChannelHop.dummy(nodes("A").nodeParams.nodeId, 150 msat, 50, CltvExpiryDelta(36))), CltvExpiryDelta(1000)) + OfferManager.InvoiceRequestActor.Route(Seq(ChannelHop.dummy(nodes("A").nodeParams.nodeId, 100 msat, 100, CltvExpiryDelta(48)), ChannelHop.dummy(nodes("A").nodeParams.nodeId, 150 msat, 50, CltvExpiryDelta(36))), recipientPaysFees = false, CltvExpiryDelta(1000)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -740,9 +740,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] val receivingRoutes = Seq( - OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), CltvExpiryDelta(555)) + OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), recipientPaysFees = false, CltvExpiryDelta(555)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"eff0")) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"eff0")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -773,8 +773,8 @@ class PaymentIntegrationSpec extends IntegrationSpec { val route = sender.expectMsgType[Router.RouteResponse].routes.head val handleInvoiceRequest = offerHandler.expectMessageType[HandleInvoiceRequest] - val receivingRoutes = Seq(OfferManager.InvoiceRequestActor.Route(route.hops, CltvExpiryDelta(500))) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false, pluginData_opt = Some(hex"0123")) + val receivingRoutes = Seq(OfferManager.InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, CltvExpiryDelta(500))) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, pluginData_opt = Some(hex"0123")) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) @@ -821,9 +821,9 @@ class PaymentIntegrationSpec extends IntegrationSpec { ShortChannelIdDir(channelBC.nodeId1 == nodes("B").nodeParams.nodeId, channelBC.shortChannelId) } val receivingRoutes = Seq( - OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), CltvExpiryDelta(555), Some(scidDirCB)) + OfferManager.InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(nodes("C").nodeParams.nodeId, 55 msat, 55, CltvExpiryDelta(55)), recipientPaysFees = false, CltvExpiryDelta(555), Some(scidDirCB)) ) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes, hideFees = false) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, receivingRoutes) val handlePayment = offerHandler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index efb37a76f9..0071757513 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -153,10 +153,10 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { } } - def offerHandler(amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], hideFees: Boolean): Behavior[OfferManager.HandlerCommand] = { + def offerHandler(amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route]): Behavior[OfferManager.HandlerCommand] = { Behaviors.receiveMessage { case OfferManager.HandleInvoiceRequest(replyTo, _) => - replyTo ! InvoiceRequestActor.ApproveRequest(amount, routes, hideFees) + replyTo ! InvoiceRequestActor.ApproveRequest(amount, routes) Behaviors.same case OfferManager.HandlePayment(replyTo, _, _) => replyTo ! OfferManager.PaymentActor.AcceptPayment() @@ -164,12 +164,12 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { } } - def sendOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1, hideFees: Boolean = false): (Offer, PaymentEvent) = { + def sendOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1): (Offer, PaymentEvent) = { import f._ val sender = TestProbe("sender") val offer = Offer(None, Some("test"), recipient.nodeId, Features.empty, recipient.nodeParams.chainHash) - val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes, hideFees)) + val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipient.nodeParams.privateKey), None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, payer.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts, payer.routeParams, blocking = true) @@ -177,7 +177,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { (offer, sender.expectMsgType[PaymentEvent]) } - def sendPrivateOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1, hideFees: Boolean = false): (Offer, PaymentEvent) = { + def sendPrivateOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], maxAttempts: Int = 1): (Offer, PaymentEvent) = { import f._ val sender = TestProbe("sender") @@ -188,7 +188,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { buildRoute(randomKey(), intermediateNodes, Recipient(recipient.nodeId, Some(pathId))).route }) val offer = Offer(None, Some("test"), recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) - val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes, hideFees)) + val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipientKey), Some(pathId), handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, payer.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts, payer.routeParams, blocking = true) @@ -196,13 +196,13 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { (offer, sender.expectMsgType[PaymentEvent]) } - def sendOfferPaymentWithInvalidAmount(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, payerAmount: MilliSatoshi, recipientAmount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route], hideFees: Boolean = false): PaymentFailed = { + def sendOfferPaymentWithInvalidAmount(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, payerAmount: MilliSatoshi, recipientAmount: MilliSatoshi, routes: Seq[InvoiceRequestActor.Route]): PaymentFailed = { import f._ val sender = TestProbe("sender") val paymentInterceptor = TestProbe("payment-interceptor") val offer = Offer(None, Some("test"), recipient.nodeId, Features.empty, recipient.nodeParams.chainHash) - val handler = recipient.system.spawnAnonymous(offerHandler(recipientAmount, routes, hideFees)) + val handler = recipient.system.spawnAnonymous(offerHandler(recipientAmount, routes)) recipient.offerManager ! OfferManager.RegisterOffer(offer, Some(recipient.nodeParams.privateKey), None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.router, payer.register, paymentInterceptor.ref)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, payer.routeParams, blocking = true) @@ -231,7 +231,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -246,8 +246,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) - val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = true, maxFinalExpiryDelta)) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) assert(payment.parts.head.amount == amount) @@ -264,8 +264,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head val routes = Seq( - InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), - InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta), ) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) @@ -282,10 +282,10 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head val routes = Seq( - InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), - InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, recipientPaysFees = true, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops, recipientPaysFees = true, maxFinalExpiryDelta), ) - val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3, hideFees = true) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) assert(payment.parts.forall(_.feesPaid == 0.msat)) @@ -299,7 +299,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head // Carol advertises a single blinded path from Bob to herself. - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) // We make a first set of payments to ensure channels have less than 50 000 sat on Bob's side. Seq(50_000_000 msat, 50_000_000 msat).foreach(amount => { @@ -331,7 +331,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(10_000_000 msat, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val amount1 = 150_000_000 msat val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount1, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount1, result) @@ -347,8 +347,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val amount = 125_000_000 msat val routes = Seq( - InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), maxFinalExpiryDelta), - InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), recipientPaysFees = false, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), recipientPaysFees = false, maxFinalExpiryDelta), ) val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) @@ -364,10 +364,10 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val amount = 125_000_000 msat val routes = Seq( - InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), maxFinalExpiryDelta), - InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 150 msat, 0, CltvExpiryDelta(50)), recipientPaysFees = true, maxFinalExpiryDelta), + InvoiceRequestActor.Route(route.hops ++ Seq(ChannelHop.dummy(carol.nodeId, 50 msat, 0, CltvExpiryDelta(20)), ChannelHop.dummy(carol.nodeId, 100 msat, 0, CltvExpiryDelta(30))), recipientPaysFees = true, maxFinalExpiryDelta), ) - val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) assert(payment.parts.forall(_.feesPaid == 0.msat)) @@ -382,7 +382,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes) verifyPaymentSuccess(offer, amount, result) } @@ -396,8 +396,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) - val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes, hideFees = true) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = true, maxFinalExpiryDelta)) + val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.forall(_.feesPaid == 0.msat)) } @@ -406,7 +406,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 75_000_000 msat - val routes = Seq(InvoiceRequestActor.Route(Nil, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Nil, recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -416,7 +416,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 250_000_000 msat - val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -426,8 +426,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { import f._ val amount = 250_000_000 msat - val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), maxFinalExpiryDelta)) - val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes, hideFees = true) + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 10 msat, 25, CltvExpiryDelta(24)), ChannelHop.dummy(bob.nodeId, 5 msat, 10, CltvExpiryDelta(36))), recipientPaysFees = true, maxFinalExpiryDelta)) + val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) assert(payment.parts.forall(_.feesPaid == 0.msat)) @@ -442,7 +442,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, bob, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -460,7 +460,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head // Carol creates a blinded path using that channel. - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) // We make a payment to ensure that the channel contains less than 150 000 sat on Bob's side. assert(sendPayment(bob, carol, 50_000_000 msat).isRight) @@ -485,7 +485,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 25 msat, 250, CltvExpiryDelta(75)), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops :+ ChannelHop.dummy(carol.nodeId, 25 msat, 250, CltvExpiryDelta(75)), recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendOfferPayment(f, bob, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) @@ -511,7 +511,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // Carol receives a first payment through those channels. { - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val amount1 = 100_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount1, routes) val payment = verifyPaymentSuccess(offer, amount1, result) @@ -527,7 +527,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { // Carol receives a second payment that requires using MPP. { - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val amount2 = 200_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount2, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount2, result) @@ -556,7 +556,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head // Carol receives a payment that requires using MPP. - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val amount = 300_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) @@ -584,7 +584,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val route = sender.expectMsgType[Router.RouteResponse].routes.head // Carol receives a payment that requires using MPP. - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val amount = 200_000_000 msat val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) @@ -614,7 +614,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { sender.expectMsgType[PaymentSent] }) // Bob now doesn't have enough funds to relay the payment. - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val (_, result) = sendOfferPayment(f, alice, carol, 75_000_000 msat, routes) verifyBlindedFailure(result, bob.nodeId) } @@ -626,7 +626,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(25_000_000 msat, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, CltvExpiryDelta(-500))) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, CltvExpiryDelta(-500))) val (_, result) = sendOfferPayment(f, alice, carol, 25_000_000 msat, routes) verifyBlindedFailure(result, bob.nodeId) } @@ -641,7 +641,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(recipientAmount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) // The amount is below what Carol expects. val payment = sendOfferPaymentWithInvalidAmount(f, alice, carol, payerAmount, recipientAmount, routes) verifyBlindedFailure(payment, bob.nodeId) @@ -652,7 +652,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 25_000_000 msat val recipientAmount = 50_000_000 msat - val routes = Seq(InvoiceRequestActor.Route(Nil, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Nil, recipientPaysFees = false, maxFinalExpiryDelta)) // The amount is below what Bob expects: since he is both the introduction node and the final recipient, he sends // back a normal error. val payment = sendOfferPaymentWithInvalidAmount(f, alice, bob, payerAmount, recipientAmount, routes) @@ -668,7 +668,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val payerAmount = 25_000_000 msat val recipientAmount = 50_000_000 msat - val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 1 msat, 100, CltvExpiryDelta(48))), maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(Seq(ChannelHop.dummy(bob.nodeId, 1 msat, 100, CltvExpiryDelta(48))), recipientPaysFees = false, maxFinalExpiryDelta)) // The amount is below what Bob expects: since he is both the introduction node and the final recipient, he sends // back a normal error. val payment = sendOfferPaymentWithInvalidAmount(f, alice, bob, payerAmount, recipientAmount, routes) @@ -689,7 +689,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(recipientAmount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val routes = Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)) + val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) // The amount is below what Carol expects. val payment = sendOfferPaymentWithInvalidAmount(f, bob, carol, payerAmount, recipientAmount, routes) assert(payment.failures.head.isInstanceOf[PaymentFailure]) @@ -714,8 +714,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { carol.router ! Router.FinalizeRoute(sender.ref.toTyped, Router.PredefinedNodeRoute(amount, Seq(bob.nodeId, carol.nodeId))) val route = sender.expectMsgType[Router.RouteResponse].routes.head - val receivingRoute = InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta) - val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute), hideFees = false)) + val receivingRoute = InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta) + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute))) carol.offerManager ! OfferManager.RegisterOffer(compactOffer, Some(recipientKey), Some(pathId), handler) val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.router, alice.register, alice.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) @@ -736,7 +736,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val offerPaths = Seq(OnionMessages.buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId)), Recipient(carol.nodeId, Some(pathId))).route) val offer = Offer.withPaths(None, Some("implicit node id"), offerPaths, Features.empty, carol.nodeParams.chainHash) - val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(InvoiceRequestActor.Route(route.hops, maxFinalExpiryDelta)), hideFees = false)) + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)))) carol.offerManager ! OfferManager.RegisterOffer(offer, None, Some(pathId), handler) val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.router, alice.register, alice.paymentInitiator)) val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala index 53d84b359d..00b8f6eea3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/MultiPartHandlerSpec.scala @@ -171,7 +171,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val router = TestProbe() sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) router.expectNoMessage(50 millis) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] // Offer invoices shouldn't be stored in the DB until we receive a payment for it. assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) @@ -181,7 +181,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, 0 msat) assert(register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]].message.id == finalPacket.add.id) val paymentReceived = eventListener.expectMsgType[PaymentReceived] @@ -282,7 +282,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike ReceivingRoute(Nil, randomBytes32(), CltvExpiryDelta(250), PaymentInfo(0 msat, 0, nodeParams.channelConf.minFinalExpiryDelta, 0 msat, amount, Features.empty)), ) sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, privKey, invoiceReq, receivingRoutes, randomBytes32())) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] assert(invoice.amount == amount) assert(invoice.nodeId == privKey.publicKey) assert(invoice.blindedPaths.nonEmpty) @@ -462,7 +462,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, Some("a blinded coffee please"), nodeKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(randomBytes32()), randomBytes32())) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] val add = UpdateAddHtlc(ByteVector32.One, 0, 5000 msat, invoice.paymentHash, defaultExpiry, TestConstants.emptyOnionPacket, None, 1.0, None) sender.send(handlerWithMpp, IncomingPaymentPacket.FinalPacket(add, FinalPayload.Standard.createPayload(add.amountMsat, add.amountMsat, add.cltvExpiry, randomBytes32(), None))) @@ -480,7 +480,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val preimage = randomBytes32() val pathId = randomBytes32() sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) val packet = createBlindedPacket(5000 msat, invoice.paymentHash, defaultExpiry, CltvExpiry(nodeParams.currentBlockHeight), pathId) @@ -489,7 +489,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, 0 msat) register.expectMsgType[Register.Forward[CMD_FULFILL_HTLC]] assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).get.status.isInstanceOf[IncomingPaymentStatus.Received]) } @@ -503,7 +503,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val offer = Offer(None, Some("a blinded coffee please"), nodeKey.publicKey, Features.empty, Block.RegtestGenesisBlock.hash) val invoiceReq = InvoiceRequest(offer, 5000 msat, 1, featuresWithRouteBlinding.bolt12Features(), randomKey(), Block.RegtestGenesisBlock.hash) sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] val packet = createBlindedPacket(5000 msat, invoice.paymentHash, defaultExpiry, CltvExpiry(nodeParams.currentBlockHeight), pathId) sender.send(handlerWithRouteBlinding, packet) @@ -523,7 +523,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val preimage = randomBytes32() val pathId = randomBytes32() sender.send(handlerWithRouteBlinding, ReceiveOfferPayment(sender.ref, nodeKey, invoiceReq, createEmptyReceivingRoute(pathId), preimage)) - val invoice = sender.expectMsgType[CreateInvoiceActor.InvoiceCreated].invoice + val invoice = sender.expectMsgType[Bolt12Invoice] // We test the case where the HTLC's cltv_expiry is lower than expected and doesn't meet the min_final_expiry_delta. val packet = createBlindedPacket(5000 msat, invoice.paymentHash, defaultExpiry - CltvExpiryDelta(1), defaultExpiry, pathId) @@ -532,7 +532,7 @@ class MultiPartHandlerSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike assert(receivePayment.paymentHash == invoice.paymentHash) assert(receivePayment.payload.pathId == pathId.bytes) val payment = IncomingBlindedPayment(MinimalBolt12Invoice(invoice.records), preimage, PaymentType.Blinded, TimestampMilli.now(), IncomingPaymentStatus.Pending) - receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment) + receivePayment.replyTo ! GetIncomingPaymentActor.ProcessPayment(payment, 0 msat) val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_HTLC]].message assert(cmd.reason == FailureReason.LocalFailure(IncorrectOrUnknownPaymentDetails(5000 msat, nodeParams.currentBlockHeight))) assert(nodeParams.db.payments.getIncomingPayment(invoice.paymentHash).isEmpty) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala index 81e1cd90c4..666e76e413 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala @@ -34,7 +34,7 @@ import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, amountAfterFee, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, amountAfterFee, nodeFee, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} @@ -77,7 +77,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val handleInvoiceRequest = handler.expectMessageType[HandleInvoiceRequest] assert(handleInvoiceRequest.invoiceRequest.isValid) assert(handleInvoiceRequest.invoiceRequest.payerId == payerKey.publicKey) - handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, Seq(InvoiceRequestActor.Route(hops, CltvExpiryDelta(1000))), hideFees, pluginData_opt) + handleInvoiceRequest.replyTo ! InvoiceRequestActor.ApproveRequest(amount, Seq(InvoiceRequestActor.Route(hops, hideFees, CltvExpiryDelta(1000))), pluginData_opt) val invoiceMessage = postman.expectMessageType[Postman.SendMessage] val Right(invoice) = Bolt12Invoice.validate(invoiceMessage.message.get[OnionMessagePayloadTlv.Invoice].get.tlvs) assert(invoice.validateFor(handleInvoiceRequest.invoiceRequest, pathNodeId).isRight) @@ -125,7 +125,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(handlePayment.offerId == offer.offerId) assert(handlePayment.pluginData_opt.contains(hex"deadbeef")) handlePayment.replyTo ! PaymentActor.AcceptPayment() - val ProcessPayment(incomingPayment) = paymentHandler.expectMessageType[ProcessPayment] + val ProcessPayment(incomingPayment, _) = paymentHandler.expectMessageType[ProcessPayment] assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) @@ -279,7 +279,7 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(paymentHandler.expectMessageType[RejectPayment].reason == "internal error") } - test("invalid payment (incorrect amount)") { f => + test("pay offer without hidden fee") { f => import f._ val handler = TestProbe[HandlerCommand]() @@ -290,11 +290,17 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val payerKey = randomKey() requestInvoice(payerKey, offer, nodeParams.privateKey, amount, offerManager, postman.ref) val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler) - // Try sending 1 msat less than needed val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount - 1.msat) - paymentHandler.expectMessageType[RejectPayment] - handler.expectNoMessage(50 millis) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amount) + + val handlePayment = handler.expectMessageType[HandlePayment] + assert(handlePayment.offerId == offer.offerId) + handlePayment.replyTo ! PaymentActor.AcceptPayment() + val ProcessPayment(incomingPayment, maxRecipientPathFees) = paymentHandler.expectMessageType[ProcessPayment] + assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) + assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) + assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) + assert(maxRecipientPathFees == 0.msat) } test("pay offer with hidden fees") { f => @@ -310,32 +316,16 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler, hops = List(ChannelHop.dummy(nodeParams.nodeId, 1000 msat, 200, CltvExpiryDelta(144))), hideFees = true) // Sending less than the full amount as fees are paid by the recipient val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amountAfterFee(1000 msat, 200, amount)) + val amountReceived = amountAfterFee(1000 msat, 200, amount) + offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amountReceived) val handlePayment = handler.expectMessageType[HandlePayment] assert(handlePayment.offerId == offer.offerId) handlePayment.replyTo ! PaymentActor.AcceptPayment() - val ProcessPayment(incomingPayment) = paymentHandler.expectMessageType[ProcessPayment] + val ProcessPayment(incomingPayment, maxRecipientPathFees) = paymentHandler.expectMessageType[ProcessPayment] assert(Crypto.sha256(incomingPayment.paymentPreimage) == invoice.paymentHash) assert(incomingPayment.invoice.nodeId == nodeParams.nodeId) assert(incomingPayment.invoice.paymentHash == invoice.paymentHash) - } - - test("invalid payment (incorrect amount with hidden fee)") { f => - import f._ - - val handler = TestProbe[HandlerCommand]() - val amount = 10_000_000 msat - val offer = Offer(Some(amount), Some("offer"), nodeParams.nodeId, Features.empty, nodeParams.chainHash) - offerManager ! RegisterOffer(offer, Some(nodeParams.privateKey), None, handler.ref) - // Request invoice. - val payerKey = randomKey() - requestInvoice(payerKey, offer, nodeParams.privateKey, amount, offerManager, postman.ref) - val invoice = receiveInvoice(f, amount, payerKey, nodeParams.nodeId, handler, hops = List(ChannelHop.dummy(nodeParams.nodeId, 1000 msat, 200, CltvExpiryDelta(144))), hideFees = true) - // Try sending 1 msat less than needed - val paymentPayload = createPaymentPayload(f, invoice) - offerManager ! ReceivePayment(paymentHandler.ref, invoice.paymentHash, paymentPayload, amountAfterFee(1000 msat, 200, amount) - 1.msat) - paymentHandler.expectMessageType[RejectPayment] - handler.expectNoMessage(50 millis) + assert(maxRecipientPathFees == paymentPayload.amount - amountReceived) } } From c2d1546b801f18f4e0e57890a09a41b89c463869 Mon Sep 17 00:00:00 2001 From: Thomas HUET Date: Tue, 25 Feb 2025 16:55:13 +0100 Subject: [PATCH 6/6] tests --- .../basic/payment/OfferPaymentSpec.scala | 8 ++++-- .../payment/offer/OfferManagerSpec.scala | 26 +++++++++++-------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index 0071757513..6d4e211e2f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -235,6 +235,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) + assert(payment.parts.head.feesPaid > 0.msat) } test("send blinded payment a->b->c, hidden fees") { f => @@ -250,7 +251,6 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) - assert(payment.parts.head.amount == amount) assert(payment.parts.head.feesPaid == 0.msat) } @@ -270,6 +270,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes, maxAttempts = 3) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) + assert(payment.parts.forall(_.feesPaid > 0.msat)) } test("send blinded multi-part payment a->b->c, hidden fees") { f => @@ -353,6 +354,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 2) + assert(payment.parts.forall(_.feesPaid > 0.msat)) } test("send blinded payment a->b->c with dummy hops, hidden fees") { f => @@ -384,7 +386,8 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val routes = Seq(InvoiceRequestActor.Route(route.hops, recipientPaysFees = false, maxFinalExpiryDelta)) val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes) - verifyPaymentSuccess(offer, amount, result) + val payment = verifyPaymentSuccess(offer, amount, result) + assert(payment.parts.forall(_.feesPaid > 0.msat)) } test("send blinded payment a->b->c through private channels, hidden fees", Tag(PrivateChannels)) { f => @@ -420,6 +423,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val (offer, result) = sendOfferPayment(f, alice, bob, amount, routes) val payment = verifyPaymentSuccess(offer, amount, result) assert(payment.parts.length == 1) + assert(payment.parts.forall(_.feesPaid > 0.msat)) } test("send blinded payment a->b with dummy hops, hidden fees") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala index 666e76e413..0f0b8596df 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/offer/OfferManagerSpec.scala @@ -28,17 +28,16 @@ import fr.acinq.eclair.payment.Bolt12Invoice import fr.acinq.eclair.payment.offer.OfferManager._ import fr.acinq.eclair.payment.receive.MultiPartHandler import fr.acinq.eclair.payment.receive.MultiPartHandler.GetIncomingPaymentActor.{ProcessPayment, RejectPayment} -import fr.acinq.eclair.payment.receive.MultiPartHandler.ReceivingRoute -import fr.acinq.eclair.payment.relay.Relayer.RelayFees import fr.acinq.eclair.router.Router.ChannelHop import fr.acinq.eclair.wire.protocol.OfferTypes.{InvoiceRequest, Offer} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.RouteBlindingDecryptedData import fr.acinq.eclair.wire.protocol._ -import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, amountAfterFee, nodeFee, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, NodeParams, TestConstants, amountAfterFee, randomBytes32, randomKey} import org.scalatest.funsuite.FixtureAnyFunSuiteLike import org.scalatest.{Outcome, Tag} import scodec.bits.{ByteVector, HexStringSyntax} +import scala.annotation.tailrec import scala.concurrent.duration.DurationInt class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { @@ -86,19 +85,24 @@ class OfferManagerSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app invoice } + /** Decrypt the provided encrypted payloads, assuming we're using only dummy hops for the target node. */ + @tailrec + private def decryptBlindedPayload(nodeKey: PrivateKey, pathKey: PublicKey, encryptedPayloads: Seq[ByteVector]): TlvStream[RouteBlindingEncryptedDataTlv] = { + if (encryptedPayloads.size == 1) { + val Right(RouteBlindingDecryptedData(encryptedDataTlvs, _)) = RouteBlindingEncryptedDataCodecs.decode(nodeKey, pathKey, encryptedPayloads.head) + encryptedDataTlvs + } else { + val Right(RouteBlindingDecryptedData(_, nextPathKey)) = RouteBlindingEncryptedDataCodecs.decode(nodeKey, pathKey, encryptedPayloads.head) + decryptBlindedPayload(nodeKey, nextPathKey, encryptedPayloads.tail) + } + } + def createPaymentPayload(f: FixtureParam, invoice: Bolt12Invoice): PaymentOnion.FinalPayload.Blinded = { import f._ assert(invoice.blindedPaths.length == 1) val blindedPath = invoice.blindedPaths.head.route - val Right(RouteBlindingDecryptedData(tlvs, nextPathKey)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, blindedPath.firstPathKey, blindedPath.encryptedPayloads.head) - var encryptedDataTlvs = tlvs - var pathKey = nextPathKey - for (encryptedPayload <- blindedPath.encryptedPayloads.drop(1)) { - val Right(RouteBlindingDecryptedData(tlvs, nextPathKey)) = RouteBlindingEncryptedDataCodecs.decode(nodeParams.privateKey, pathKey, encryptedPayload) - encryptedDataTlvs = tlvs - pathKey = nextPathKey - } + val encryptedDataTlvs = decryptBlindedPayload(nodeParams.privateKey, blindedPath.firstPathKey, blindedPath.encryptedPayloads) val paymentTlvs = TlvStream[OnionPaymentPayloadTlv]( OnionPaymentPayloadTlv.AmountToForward(invoice.amount), OnionPaymentPayloadTlv.TotalAmount(invoice.amount),