@@ -12,9 +12,55 @@ import (
1212 "github.com/lightningnetwork/lnd/lnrpc"
1313 "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
1414 "github.com/lightningnetwork/lnd/lntypes"
15+ "github.com/lightningnetwork/lnd/lnwire"
1516 "google.golang.org/grpc"
1617)
1718
19+ // InvoiceHtlcModifyRequest is a request to modify an HTLC that is attempting to
20+ // settle an invoice.
21+ type InvoiceHtlcModifyRequest struct {
22+ // Invoice is the current state of the invoice, _before_ this HTLC is
23+ // applied. Any HTLC in the invoice is a previously accepted/settled
24+ // one.
25+ Invoice * lnrpc.Invoice
26+
27+ // CircuitKey is the circuit key of the HTLC that is attempting to
28+ // settle the invoice.
29+ CircuitKey invpkg.CircuitKey
30+
31+ // ExitHtlcAmt is the amount of the HTLC that is attempting to settle
32+ // the invoice.
33+ ExitHtlcAmt lnwire.MilliSatoshi
34+
35+ // ExitHtlcExpiry is the expiry of the HTLC that is attempting to settle
36+ // the invoice.
37+ ExitHtlcExpiry uint32
38+
39+ // CurrentHeight is the current block height.
40+ CurrentHeight uint32
41+
42+ // WireCustomRecords is the wire custom records of the HTLC that is
43+ // attempting to settle the invoice.
44+ WireCustomRecords lnwire.CustomRecords
45+ }
46+
47+ // InvoiceHtlcModifyResponse is a response to an HTLC modification request.
48+ type InvoiceHtlcModifyResponse struct {
49+ // CircuitKey is the circuit key the response is for.
50+ CircuitKey invpkg.CircuitKey
51+
52+ // AmtPaid is the amount the HTLC contributes toward settling the
53+ // invoice. This amount can be different from the on-chain amount of the
54+ // HTLC in case of custom channels. To not modify the amount and use the
55+ // on-chain amount, set this to 0.
56+ AmtPaid lnwire.MilliSatoshi
57+ }
58+
59+ // InvoiceHtlcModifyHandler is a function that handles an HTLC modification
60+ // request.
61+ type InvoiceHtlcModifyHandler func (context.Context ,
62+ InvoiceHtlcModifyRequest ) (* InvoiceHtlcModifyResponse , error )
63+
1864// InvoicesClient exposes invoice functionality.
1965type InvoicesClient interface {
2066 SubscribeSingleInvoice (ctx context.Context , hash lntypes.Hash ) (
@@ -26,6 +72,14 @@ type InvoicesClient interface {
2672
2773 AddHoldInvoice (ctx context.Context , in * invoicesrpc.AddInvoiceData ) (
2874 string , error )
75+
76+ // HtlcModifier is a bidirectional streaming RPC that allows a client to
77+ // intercept and modify the HTLCs that attempt to settle the given
78+ // invoice. The server will send HTLCs of invoices to the client and the
79+ // client can modify some aspects of the HTLC in order to pass the
80+ // invoice acceptance tests.
81+ HtlcModifier (ctx context.Context ,
82+ handler InvoiceHtlcModifyHandler ) error
2983}
3084
3185// InvoiceUpdate contains a state update for an invoice.
@@ -38,6 +92,8 @@ type invoicesClient struct {
3892 client invoicesrpc.InvoicesClient
3993 invoiceMac serializedMacaroon
4094 timeout time.Duration
95+ quitOnce sync.Once
96+ quit chan struct {}
4197 wg sync.WaitGroup
4298}
4399
@@ -48,10 +104,15 @@ func newInvoicesClient(conn grpc.ClientConnInterface,
48104 client : invoicesrpc .NewInvoicesClient (conn ),
49105 invoiceMac : invoiceMac ,
50106 timeout : timeout ,
107+ quit : make (chan struct {}),
51108 }
52109}
53110
54111func (s * invoicesClient ) WaitForFinished () {
112+ s .quitOnce .Do (func () {
113+ close (s .quit )
114+ })
115+
55116 s .wg .Wait ()
56117}
57118
@@ -184,3 +245,144 @@ func fromRPCInvoiceState(state lnrpc.Invoice_InvoiceState) (
184245
185246 return 0 , errors .New ("unknown state" )
186247}
248+
249+ // HtlcModifier is a bidirectional streaming RPC that allows a client to
250+ // intercept and modify the HTLCs that attempt to settle the given invoice. The
251+ // server will send HTLCs of invoices to the client and the client can modify
252+ // some aspects of the HTLC in order to pass the invoice acceptance tests.
253+ func (s * invoicesClient ) HtlcModifier (ctx context.Context ,
254+ handler InvoiceHtlcModifyHandler ) error {
255+
256+ // Create a child context that will be canceled when this function
257+ // exits. We use this context to be able to cancel goroutines when we
258+ // exit on errors, because the parent context won't be canceled in that
259+ // case.
260+ ctx , cancel := context .WithCancel (ctx )
261+ defer cancel ()
262+
263+ stream , err := s .client .HtlcModifier (
264+ s .invoiceMac .WithMacaroonAuth (ctx ),
265+ )
266+ if err != nil {
267+ return err
268+ }
269+
270+ // Create an error channel that we'll send errors on if any of our
271+ // goroutines fail. We buffer by 1 so that the goroutine doesn't depend
272+ // on the stream being read, and select on context cancellation and
273+ // quit channel so that we do not block in the case where we exit with
274+ // multiple errors.
275+ errChan := make (chan error , 1 )
276+
277+ sendErr := func (err error ) {
278+ select {
279+ case errChan <- err :
280+ case <- ctx .Done ():
281+ case <- s .quit :
282+ }
283+ }
284+
285+ // Start a goroutine that consumes interception requests from lnd and
286+ // sends them into our requests channel for handling. The requests
287+ // channel is not buffered because we expect all requests to be handled
288+ // until this function exits, at which point we expect our context to
289+ // be canceled or quit channel to be closed.
290+ requestChan := make (chan InvoiceHtlcModifyRequest )
291+ s .wg .Add (1 )
292+ go func () {
293+ defer s .wg .Done ()
294+
295+ for {
296+ // Do a quick check whether our client context has been
297+ // canceled so that we can exit sooner if needed.
298+ if ctx .Err () != nil {
299+ return
300+ }
301+
302+ req , err := stream .Recv ()
303+ if err != nil {
304+ sendErr (err )
305+ return
306+ }
307+
308+ wireCustomRecords := req .ExitHtlcWireCustomRecords
309+ interceptReq := InvoiceHtlcModifyRequest {
310+ Invoice : req .Invoice ,
311+ CircuitKey : invpkg.CircuitKey {
312+ ChanID : lnwire .NewShortChanIDFromInt (
313+ req .ExitHtlcCircuitKey .ChanId ,
314+ ),
315+ HtlcID : req .ExitHtlcCircuitKey .HtlcId ,
316+ },
317+ ExitHtlcAmt : lnwire .MilliSatoshi (
318+ req .ExitHtlcAmt ,
319+ ),
320+ ExitHtlcExpiry : req .ExitHtlcExpiry ,
321+ CurrentHeight : req .CurrentHeight ,
322+ WireCustomRecords : wireCustomRecords ,
323+ }
324+
325+ // Try to send our interception request, failing on
326+ // context cancel or router exit.
327+ select {
328+ case requestChan <- interceptReq :
329+
330+ case <- s .quit :
331+ sendErr (ErrRouterShuttingDown )
332+ return
333+
334+ case <- ctx .Done ():
335+ sendErr (ctx .Err ())
336+ return
337+ }
338+ }
339+ }()
340+
341+ for {
342+ select {
343+ case request := <- requestChan :
344+ // Handle requests in a goroutine so that the handler
345+ // provided to this function can be blocking. If we
346+ // get an error, send it into our error channel to
347+ // shut down the interceptor.
348+ s .wg .Add (1 )
349+ go func () {
350+ defer s .wg .Done ()
351+
352+ // Get a response from handler, this may block
353+ // for a while.
354+ resp , err := handler (ctx , request )
355+ if err != nil {
356+ sendErr (err )
357+ return
358+ }
359+
360+ key := resp .CircuitKey
361+ amtPaid := uint64 (resp .AmtPaid )
362+ rpcResp := & invoicesrpc.HtlcModifyResponse {
363+ CircuitKey : & invoicesrpc.CircuitKey {
364+ ChanId : key .ChanID .ToUint64 (),
365+ HtlcId : key .HtlcID ,
366+ },
367+ AmtPaid : & amtPaid ,
368+ }
369+
370+ if err := stream .Send (rpcResp ); err != nil {
371+ sendErr (err )
372+ return
373+ }
374+ }()
375+
376+ // If one of our goroutines fails, exit with the error that
377+ // occurred.
378+ case err := <- errChan :
379+ return err
380+
381+ case <- s .quit :
382+ return ErrRouterShuttingDown
383+
384+ case <- ctx .Done ():
385+ return ctx .Err ()
386+ }
387+ }
388+ }
0 commit comments