Skip to content

Commit 7ce36d7

Browse files
committed
multi: thread contexts through privacy map interfaces
Update the PrivacyMapDB interface methods to take contexts (both the methods themselves and the call-back params) and then ensure all implementations are updated and all call-sites pass contexts through correctly.
1 parent e49a1c3 commit 7ce36d7

16 files changed

+160
-90
lines changed

firewall/privacy_mapper.go

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,16 @@ func handleGetInfoResponse(db firewalldb.PrivacyMapDB,
325325
flags session.PrivacyFlags) func(ctx context.Context,
326326
r *lnrpc.GetInfoResponse) (proto.Message, error) {
327327

328-
return func(_ context.Context, r *lnrpc.GetInfoResponse) (
328+
return func(ctx context.Context, r *lnrpc.GetInfoResponse) (
329329
proto.Message, error) {
330330

331331
// We hide the pubkey unless it is disabled.
332332
pseudoPubKey := r.IdentityPubkey
333333
if !flags.Contains(session.ClearPubkeys) {
334-
err := db.Update(
335-
func(tx firewalldb.PrivacyMapTx) error {
334+
err := db.Update(ctx,
335+
func(ctx context.Context,
336+
tx firewalldb.PrivacyMapTx) error {
337+
336338
var err error
337339
pseudoPubKey, err = firewalldb.HideString(
338340
tx, r.IdentityPubkey,
@@ -377,14 +379,16 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB,
377379
randIntn func(int) (int, error)) func(ctx context.Context,
378380
r *lnrpc.ForwardingHistoryResponse) (proto.Message, error) {
379381

380-
return func(_ context.Context, r *lnrpc.ForwardingHistoryResponse) (
382+
return func(ctx context.Context, r *lnrpc.ForwardingHistoryResponse) (
381383
proto.Message, error) {
382384

383385
fwdEvents := make(
384386
[]*lnrpc.ForwardingEvent, len(r.ForwardingEvents),
385387
)
386388

387-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
389+
err := db.Update(ctx, func(ctx context.Context,
390+
tx firewalldb.PrivacyMapTx) error {
391+
388392
for i, fe := range r.ForwardingEvents {
389393
var err error
390394

@@ -487,7 +491,9 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB,
487491

488492
chanFees := make([]*lnrpc.ChannelFeeReport, len(r.ChannelFees))
489493

490-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
494+
err := db.Update(ctx, func(ctx context.Context,
495+
tx firewalldb.PrivacyMapTx) error {
496+
491497
var err error
492498

493499
for i, c := range r.ChannelFees {
@@ -550,7 +556,9 @@ func handleListChannelsRequest(db firewalldb.PrivacyMapDB,
550556
return r, nil
551557
}
552558

553-
err := db.View(func(tx firewalldb.PrivacyMapTx) error {
559+
err := db.View(ctx, func(ctx context.Context,
560+
tx firewalldb.PrivacyMapTx) error {
561+
554562
peer, err := firewalldb.RevealBytes(tx, r.Peer)
555563
if err != nil {
556564
return err
@@ -572,15 +580,17 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB,
572580
randIntn func(int) (int, error)) func(ctx context.Context,
573581
r *lnrpc.ListChannelsResponse) (proto.Message, error) {
574582

575-
return func(_ context.Context, r *lnrpc.ListChannelsResponse) (
583+
return func(ctx context.Context, r *lnrpc.ListChannelsResponse) (
576584
proto.Message, error) {
577585

578586
hidePubkeys := !flags.Contains(session.ClearPubkeys)
579587
hideChanIds := !flags.Contains(session.ClearChanIDs)
580588

581589
channels := make([]*lnrpc.Channel, len(r.Channels))
582590

583-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
591+
err := db.Update(ctx, func(ctx context.Context,
592+
tx firewalldb.PrivacyMapTx) error {
593+
584594
for i, c := range r.Channels {
585595
var err error
586596

@@ -745,7 +755,7 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB,
745755
flags session.PrivacyFlags) func(ctx context.Context,
746756
r *lnrpc.PolicyUpdateRequest) (proto.Message, error) {
747757

748-
return func(_ context.Context, r *lnrpc.PolicyUpdateRequest) (
758+
return func(ctx context.Context, r *lnrpc.PolicyUpdateRequest) (
749759
proto.Message, error) {
750760

751761
chanPoint := r.GetChanPoint()
@@ -764,7 +774,9 @@ func handleUpdatePolicyRequest(db firewalldb.PrivacyMapDB,
764774
newTxid := txid.String()
765775
newIndex := chanPoint.GetOutputIndex()
766776
if !flags.Contains(session.ClearChanIDs) {
767-
err = db.View(func(tx firewalldb.PrivacyMapTx) error {
777+
err = db.View(ctx, func(ctx context.Context,
778+
tx firewalldb.PrivacyMapTx) error {
779+
768780
var err error
769781
newTxid, newIndex, err = firewalldb.RevealChanPoint(
770782
tx, newTxid, newIndex,
@@ -793,7 +805,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB,
793805
flags session.PrivacyFlags) func(ctx context.Context,
794806
r *lnrpc.PolicyUpdateResponse) (proto.Message, error) {
795807

796-
return func(_ context.Context, r *lnrpc.PolicyUpdateResponse) (
808+
return func(ctx context.Context, r *lnrpc.PolicyUpdateResponse) (
797809
proto.Message, error) {
798810

799811
if flags.Contains(session.ClearChanIDs) {
@@ -804,7 +816,9 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB,
804816
[]*lnrpc.FailedUpdate, len(r.FailedUpdates),
805817
)
806818

807-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
819+
err := db.Update(ctx, func(ctx context.Context,
820+
tx firewalldb.PrivacyMapTx) error {
821+
808822
for i, u := range r.FailedUpdates {
809823
failedUpdates[i] = &lnrpc.FailedUpdate{
810824
Reason: u.Reason,
@@ -926,15 +940,17 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
926940
randIntn func(int) (int, error)) func(ctx context.Context,
927941
r *lnrpc.ClosedChannelsResponse) (proto.Message, error) {
928942

929-
return func(_ context.Context, r *lnrpc.ClosedChannelsResponse) (
943+
return func(ctx context.Context, r *lnrpc.ClosedChannelsResponse) (
930944
proto.Message, error) {
931945

932946
closedChannels := make(
933947
[]*lnrpc.ChannelCloseSummary,
934948
len(r.Channels),
935949
)
936950

937-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
951+
err := db.Update(ctx, func(ctx context.Context,
952+
tx firewalldb.PrivacyMapTx) error {
953+
938954
for i, c := range r.Channels {
939955
var err error
940956

@@ -1117,7 +1133,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
11171133
randIntn func(int) (int, error)) func(ctx context.Context,
11181134
r *lnrpc.PendingChannelsResponse) (proto.Message, error) {
11191135

1120-
return func(_ context.Context, r *lnrpc.PendingChannelsResponse) (
1136+
return func(ctx context.Context, r *lnrpc.PendingChannelsResponse) (
11211137
proto.Message, error) {
11221138

11231139
pendingOpens := make(
@@ -1140,7 +1156,9 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
11401156
len(r.WaitingCloseChannels),
11411157
)
11421158

1143-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
1159+
err := db.Update(ctx, func(ctx context.Context,
1160+
tx firewalldb.PrivacyMapTx) error {
1161+
11441162
for i, c := range r.PendingOpenChannels {
11451163
var err error
11461164

@@ -1343,12 +1361,14 @@ func handleBatchOpenChannelRequest(db firewalldb.PrivacyMapDB,
13431361
flags session.PrivacyFlags) func(ctx context.Context,
13441362
r *lnrpc.BatchOpenChannelRequest) (proto.Message, error) {
13451363

1346-
return func(_ context.Context, r *lnrpc.BatchOpenChannelRequest) (
1364+
return func(ctx context.Context, r *lnrpc.BatchOpenChannelRequest) (
13471365
proto.Message, error) {
13481366

13491367
var reqs = make([]*lnrpc.BatchOpenChannel, len(r.Channels))
13501368

1351-
err := db.View(func(tx firewalldb.PrivacyMapTx) error {
1369+
err := db.View(ctx, func(ctx context.Context,
1370+
tx firewalldb.PrivacyMapTx) error {
1371+
13521372
for i, c := range r.Channels {
13531373
var err error
13541374

@@ -1414,12 +1434,14 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB,
14141434
flags session.PrivacyFlags) func(ctx context.Context,
14151435
r *lnrpc.BatchOpenChannelResponse) (proto.Message, error) {
14161436

1417-
return func(_ context.Context, r *lnrpc.BatchOpenChannelResponse) (
1437+
return func(ctx context.Context, r *lnrpc.BatchOpenChannelResponse) (
14181438
proto.Message, error) {
14191439

14201440
resps := make([]*lnrpc.PendingUpdate, len(r.PendingChannels))
14211441

1422-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
1442+
err := db.Update(ctx, func(ctx context.Context,
1443+
tx firewalldb.PrivacyMapTx) error {
1444+
14231445
for i, p := range r.PendingChannels {
14241446
var (
14251447
txIdBytes = p.Txid
@@ -1471,14 +1493,15 @@ func handleChannelOpenRequest(db firewalldb.PrivacyMapDB,
14711493
flags session.PrivacyFlags) func(ctx context.Context,
14721494
r *lnrpc.OpenChannelRequest) (proto.Message, error) {
14731495

1474-
return func(_ context.Context, r *lnrpc.OpenChannelRequest) (
1496+
return func(ctx context.Context, r *lnrpc.OpenChannelRequest) (
14751497
proto.Message, error) {
14761498

14771499
var nodePubkey []byte
14781500

1479-
err := db.View(func(tx firewalldb.PrivacyMapTx) error {
1480-
var err error
1501+
err := db.View(ctx, func(ctx context.Context,
1502+
tx firewalldb.PrivacyMapTx) error {
14811503

1504+
var err error
14821505
// We use the byte slice representation of the
14831506
// pubkey and fall back to the hex string if present.
14841507
nodePubkey = r.NodePubkey
@@ -1548,15 +1571,17 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB,
15481571
flags session.PrivacyFlags) func(ctx context.Context,
15491572
r *lnrpc.ChannelPoint) (proto.Message, error) {
15501573

1551-
return func(_ context.Context, r *lnrpc.ChannelPoint) (
1574+
return func(ctx context.Context, r *lnrpc.ChannelPoint) (
15521575
proto.Message, error) {
15531576

15541577
var (
15551578
txid string
15561579
index uint32
15571580
)
15581581

1559-
err := db.Update(func(tx firewalldb.PrivacyMapTx) error {
1582+
err := db.Update(ctx, func(ctx context.Context,
1583+
tx firewalldb.PrivacyMapTx) error {
1584+
15601585
var err error
15611586

15621587
txid = r.GetFundingTxidStr()
@@ -1622,12 +1647,14 @@ func handleConnectPeerRequest(db firewalldb.PrivacyMapDB,
16221647
flags session.PrivacyFlags) func(ctx context.Context,
16231648
r *lnrpc.ConnectPeerRequest) (proto.Message, error) {
16241649

1625-
return func(_ context.Context, r *lnrpc.ConnectPeerRequest) (
1650+
return func(ctx context.Context, r *lnrpc.ConnectPeerRequest) (
16261651
proto.Message, error) {
16271652

16281653
var addr *lnrpc.LightningAddress
16291654

1630-
err := db.View(func(tx firewalldb.PrivacyMapTx) error {
1655+
err := db.View(ctx, func(ctx context.Context,
1656+
tx firewalldb.PrivacyMapTx) error {
1657+
16311658
var err error
16321659

16331660
// Note, this only works if the pubkey alias was

firewall/privacy_mapper_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,9 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string,
10731073
db := mockDB{privDB: make(map[string]*mockPrivacyMapDB)}
10741074
sessDB := db.NewSessionDB(sessID)
10751075

1076-
_ = sessDB.Update(func(tx firewalldb.PrivacyMapTx) error {
1076+
_ = sessDB.Update(context.Background(), func(ctx context.Context,
1077+
tx firewalldb.PrivacyMapTx) error {
1078+
10771079
for r, p := range preloadRealToPseudo {
10781080
require.NoError(t, tx.NewPair(r, p))
10791081
}
@@ -1107,16 +1109,16 @@ type mockPrivacyMapDB struct {
11071109
p2r map[string]string
11081110
}
11091111

1110-
func (m *mockPrivacyMapDB) Update(
1111-
f func(tx firewalldb.PrivacyMapTx) error) error {
1112+
func (m *mockPrivacyMapDB) Update(ctx context.Context,
1113+
f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error {
11121114

1113-
return f(m)
1115+
return f(ctx, m)
11141116
}
11151117

1116-
func (m *mockPrivacyMapDB) View(
1117-
f func(tx firewalldb.PrivacyMapTx) error) error {
1118+
func (m *mockPrivacyMapDB) View(ctx context.Context,
1119+
f func(ctx context.Context, tx firewalldb.PrivacyMapTx) error) error {
11181120

1119-
return f(m)
1121+
return f(ctx, m)
11201122
}
11211123

11221124
func (m *mockPrivacyMapDB) NewPair(real, pseudo string) error {

firewall/rule_enforcer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ func (r *RuleEnforcer) initRule(ctx context.Context, reqID uint64, name string,
395395
privMap := r.newPrivMap(session.GroupID)
396396

397397
ruleValues, err = ruleValues.PseudoToReal(
398-
privMap, session.PrivacyFlags,
398+
ctx, privMap, session.PrivacyFlags,
399399
)
400400
if err != nil {
401401
return nil, fmt.Errorf("could not prepare rule "+

firewalldb/privacy_mapper.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package firewalldb
22

33
import (
4+
"context"
45
"crypto/rand"
56
"encoding/binary"
67
"encoding/hex"
@@ -57,13 +58,13 @@ type PrivacyMapDB interface {
5758
// error, the transaction is rolled back. If the rollback fails, the
5859
// original error returned by f is still returned. If the commit fails,
5960
// the commit error is returned.
60-
Update(f func(tx PrivacyMapTx) error) error
61+
Update(context.Context, func(context.Context, PrivacyMapTx) error) error
6162

6263
// View opens a database read transaction and executes the function f
6364
// with the transaction passed as a parameter. After f exits, the
6465
// transaction is rolled back. If f errors, its error is returned, not a
6566
// rollback error (if any occur).
66-
View(f func(tx PrivacyMapTx) error) error
67+
View(context.Context, func(context.Context, PrivacyMapTx) error) error
6768
}
6869

6970
// PrivacyMapTx represents a db that can be used to create, store and fetch
@@ -112,7 +113,9 @@ func (p *privacyMapDB) beginTx(writable bool) (*privacyMapTx, error) {
112113
// returned.
113114
//
114115
// NOTE: this is part of the PrivacyMapDB interface.
115-
func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error {
116+
func (p *privacyMapDB) Update(ctx context.Context, f func(ctx context.Context,
117+
tx PrivacyMapTx) error) error {
118+
116119
tx, err := p.beginTx(true)
117120
if err != nil {
118121
return err
@@ -125,7 +128,7 @@ func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error {
125128
}
126129
}()
127130

128-
err = f(tx)
131+
err = f(ctx, tx)
129132
if err != nil {
130133
// Want to return the original error, not a rollback error if
131134
// any occur.
@@ -142,7 +145,9 @@ func (p *privacyMapDB) Update(f func(tx PrivacyMapTx) error) error {
142145
// occur).
143146
//
144147
// NOTE: this is part of the PrivacyMapDB interface.
145-
func (p *privacyMapDB) View(f func(tx PrivacyMapTx) error) error {
148+
func (p *privacyMapDB) View(ctx context.Context, f func(ctx context.Context,
149+
tx PrivacyMapTx) error) error {
150+
146151
tx, err := p.beginTx(false)
147152
if err != nil {
148153
return err
@@ -155,7 +160,7 @@ func (p *privacyMapDB) View(f func(tx PrivacyMapTx) error) error {
155160
}
156161
}()
157162

158-
err = f(tx)
163+
err = f(ctx, tx)
159164
rollbackErr := tx.boltTx.Rollback()
160165
if err != nil {
161166
return err

0 commit comments

Comments
 (0)