@@ -12,9 +12,33 @@ import (
1212 "github.com/lightningnetwork/lnd/lnrpc"
1313 "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
1414 "github.com/lightningnetwork/lnd/lntypes"
15+ "github.com/lightningnetwork/lnd/lnwire"
1516 "google.golang.org/grpc"
1617)
1718
19+ type InvoiceHtlcModifyRequest struct {
20+ Invoice * lnrpc.Invoice
21+
22+ CircuitKey invpkg.CircuitKey
23+
24+ ExitHtlcAmt lnwire.MilliSatoshi
25+
26+ ExitHtlcExpiry uint32
27+
28+ CurrentHeight uint32
29+
30+ WireCustomRecords lnwire.CustomRecords
31+ }
32+
33+ type InvoiceHtlcModifyResponse struct {
34+ CircuitKey invpkg.CircuitKey
35+
36+ AmtPaid lnwire.MilliSatoshi
37+ }
38+
39+ type InvoiceHtlcModifyHandler func (context.Context ,
40+ InvoiceHtlcModifyRequest ) (* InvoiceHtlcModifyResponse , error )
41+
1842// InvoicesClient exposes invoice functionality.
1943type InvoicesClient interface {
2044 SubscribeSingleInvoice (ctx context.Context , hash lntypes.Hash ) (
@@ -26,6 +50,14 @@ type InvoicesClient interface {
2650
2751 AddHoldInvoice (ctx context.Context , in * invoicesrpc.AddInvoiceData ) (
2852 string , error )
53+
54+ // HtlcModifier is a bidirectional streaming RPC that allows a client to
55+ // intercept and modify the HTLCs that attempt to settle the given
56+ // invoice. The server will send HTLCs of invoices to the client and the
57+ // client can modify some aspects of the HTLC in order to pass the
58+ // invoice acceptance tests.
59+ HtlcModifier (ctx context.Context ,
60+ handler InvoiceHtlcModifyHandler ) error
2961}
3062
3163// InvoiceUpdate contains a state update for an invoice.
@@ -38,6 +70,8 @@ type invoicesClient struct {
3870 client invoicesrpc.InvoicesClient
3971 invoiceMac serializedMacaroon
4072 timeout time.Duration
73+ quitOnce sync.Once
74+ quit chan struct {}
4175 wg sync.WaitGroup
4276}
4377
@@ -48,10 +82,15 @@ func newInvoicesClient(conn grpc.ClientConnInterface,
4882 client : invoicesrpc .NewInvoicesClient (conn ),
4983 invoiceMac : invoiceMac ,
5084 timeout : timeout ,
85+ quit : make (chan struct {}),
5186 }
5287}
5388
5489func (s * invoicesClient ) WaitForFinished () {
90+ s .quitOnce .Do (func () {
91+ close (s .quit )
92+ })
93+
5594 s .wg .Wait ()
5695}
5796
@@ -184,3 +223,144 @@ func fromRPCInvoiceState(state lnrpc.Invoice_InvoiceState) (
184223
185224 return 0 , errors .New ("unknown state" )
186225}
226+
227+ // HtlcModifier is a bidirectional streaming RPC that allows a client to
228+ // intercept and modify the HTLCs that attempt to settle the given invoice. The
229+ // server will send HTLCs of invoices to the client and the client can modify
230+ // some aspects of the HTLC in order to pass the invoice acceptance tests.
231+ func (s * invoicesClient ) HtlcModifier (ctx context.Context ,
232+ handler InvoiceHtlcModifyHandler ) error {
233+
234+ // Create a child context that will be canceled when this function
235+ // exits. We use this context to be able to cancel goroutines when we
236+ // exit on errors, because the parent context won't be canceled in that
237+ // case.
238+ ctx , cancel := context .WithCancel (ctx )
239+ defer cancel ()
240+
241+ stream , err := s .client .HtlcModifier (
242+ s .invoiceMac .WithMacaroonAuth (ctx ),
243+ )
244+ if err != nil {
245+ return err
246+ }
247+
248+ // Create an error channel that we'll send errors on if any of our
249+ // goroutines fail. We buffer by 1 so that the goroutine doesn't depend
250+ // on the stream being read, and select on context cancellation and
251+ // quit channel so that we do not block in the case where we exit with
252+ // multiple errors.
253+ errChan := make (chan error , 1 )
254+
255+ sendErr := func (err error ) {
256+ select {
257+ case errChan <- err :
258+ case <- ctx .Done ():
259+ case <- s .quit :
260+ }
261+ }
262+
263+ // Start a goroutine that consumes interception requests from lnd and
264+ // sends them into our requests channel for handling. The requests
265+ // channel is not buffered because we expect all requests to be handled
266+ // until this function exits, at which point we expect our context to
267+ // be canceled or quit channel to be closed.
268+ requestChan := make (chan InvoiceHtlcModifyRequest )
269+ s .wg .Add (1 )
270+ go func () {
271+ defer s .wg .Done ()
272+
273+ for {
274+ // Do a quick check whether our client context has been
275+ // canceled so that we can exit sooner if needed.
276+ if ctx .Err () != nil {
277+ return
278+ }
279+
280+ req , err := stream .Recv ()
281+ if err != nil {
282+ sendErr (err )
283+ return
284+ }
285+
286+ wireCustomRecords := req .ExitHtlcWireCustomRecords
287+ interceptReq := InvoiceHtlcModifyRequest {
288+ Invoice : req .Invoice ,
289+ CircuitKey : invpkg.CircuitKey {
290+ ChanID : lnwire .NewShortChanIDFromInt (
291+ req .ExitHtlcCircuitKey .ChanId ,
292+ ),
293+ HtlcID : req .ExitHtlcCircuitKey .HtlcId ,
294+ },
295+ ExitHtlcAmt : lnwire .MilliSatoshi (
296+ req .ExitHtlcAmt ,
297+ ),
298+ ExitHtlcExpiry : req .ExitHtlcExpiry ,
299+ CurrentHeight : req .CurrentHeight ,
300+ WireCustomRecords : wireCustomRecords ,
301+ }
302+
303+ // Try to send our interception request, failing on
304+ // context cancel or router exit.
305+ select {
306+ case requestChan <- interceptReq :
307+
308+ case <- s .quit :
309+ sendErr (ErrRouterShuttingDown )
310+ return
311+
312+ case <- ctx .Done ():
313+ sendErr (ctx .Err ())
314+ return
315+ }
316+ }
317+ }()
318+
319+ for {
320+ select {
321+ case request := <- requestChan :
322+ // Handle requests in a goroutine so that the handler
323+ // provided to this function can be blocking. If we
324+ // get an error, send it into our error channel to
325+ // shut down the interceptor.
326+ s .wg .Add (1 )
327+ go func () {
328+ defer s .wg .Done ()
329+
330+ // Get a response from handler, this may block
331+ // for a while.
332+ resp , err := handler (ctx , request )
333+ if err != nil {
334+ sendErr (err )
335+ return
336+ }
337+
338+ key := resp .CircuitKey
339+ amtPaid := uint64 (resp .AmtPaid )
340+ rpcResp := & invoicesrpc.HtlcModifyResponse {
341+ CircuitKey : & invoicesrpc.CircuitKey {
342+ ChanId : key .ChanID .ToUint64 (),
343+ HtlcId : key .HtlcID ,
344+ },
345+ AmtPaid : & amtPaid ,
346+ }
347+
348+ if err := stream .Send (rpcResp ); err != nil {
349+ sendErr (err )
350+ return
351+ }
352+ }()
353+
354+ // If one of our goroutines fails, exit with the error that
355+ // occurred.
356+ case err := <- errChan :
357+ return err
358+
359+ case <- s .quit :
360+ return ErrRouterShuttingDown
361+
362+ case <- ctx .Done ():
363+ return ctx .Err ()
364+ }
365+ }
366+ }
0 commit comments