Skip to content

Commit 197ee3b

Browse files
committed
firewalldb: thread context to PrivMap NewPair
Update the NewPair method of the PrivacyMapTx interface to take a context.
1 parent 7ce36d7 commit 197ee3b

File tree

5 files changed

+67
-55
lines changed

5 files changed

+67
-55
lines changed

firewall/privacy_mapper.go

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ func handleGetInfoResponse(db firewalldb.PrivacyMapDB,
336336
tx firewalldb.PrivacyMapTx) error {
337337

338338
var err error
339-
pseudoPubKey, err = firewalldb.HideString(
340-
tx, r.IdentityPubkey,
339+
pseudoPubKey, err = firewalldb.HideString( //nolint:lll
340+
ctx, tx, r.IdentityPubkey,
341341
)
342342

343343
return err
@@ -397,14 +397,14 @@ func handleFwdHistoryResponse(db firewalldb.PrivacyMapDB,
397397
if !flags.Contains(session.ClearChanIDs) {
398398
// Deterministically hide channel ids.
399399
chanIn, err = firewalldb.HideUint64(
400-
tx, chanIn,
400+
ctx, tx, chanIn,
401401
)
402402
if err != nil {
403403
return err
404404
}
405405

406406
chanOut, err = firewalldb.HideUint64(
407-
tx, chanOut,
407+
ctx, tx, chanOut,
408408
)
409409
if err != nil {
410410
return err
@@ -500,7 +500,7 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB,
500500
chanID := c.ChanId
501501
if !flags.Contains(session.ClearChanIDs) {
502502
chanID, err = firewalldb.HideUint64(
503-
tx, chanID,
503+
ctx, tx, chanID,
504504
)
505505
if err != nil {
506506
return err
@@ -510,7 +510,7 @@ func handleFeeReportResponse(db firewalldb.PrivacyMapDB,
510510
chanPoint := c.ChannelPoint
511511
if !flags.Contains(session.ClearChanIDs) {
512512
chanPoint, err = firewalldb.HideChanPointStr(
513-
tx, chanPoint,
513+
ctx, tx, chanPoint,
514514
)
515515
if err != nil {
516516
return err
@@ -599,7 +599,7 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB,
599599
remotePub := c.RemotePubkey
600600
if hidePubkeys {
601601
remotePub, err = firewalldb.HideString(
602-
tx, c.RemotePubkey,
602+
ctx, tx, c.RemotePubkey,
603603
)
604604
if err != nil {
605605
return err
@@ -610,14 +610,14 @@ func handleListChannelsResponse(db firewalldb.PrivacyMapDB,
610610
chanID := c.ChanId
611611
if hideChanIds {
612612
chanPoint, err = firewalldb.HideChanPointStr(
613-
tx, c.ChannelPoint,
613+
ctx, tx, c.ChannelPoint,
614614
)
615615
if err != nil {
616616
return err
617617
}
618618

619619
chanID, err = firewalldb.HideUint64(
620-
tx, c.ChanId,
620+
ctx, tx, c.ChanId,
621621
)
622622
if err != nil {
623623
return err
@@ -830,7 +830,7 @@ func handleUpdatePolicyResponse(db firewalldb.PrivacyMapDB,
830830
}
831831

832832
txid, index, err := firewalldb.HideChanPoint(
833-
tx, u.Outpoint.TxidStr,
833+
ctx, tx, u.Outpoint.TxidStr,
834834
u.Outpoint.OutputIndex,
835835
)
836836
if err != nil {
@@ -957,7 +957,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
957957
remotePub := c.RemotePubkey
958958
if !flags.Contains(session.ClearPubkeys) {
959959
remotePub, err = firewalldb.HideString(
960-
tx, remotePub,
960+
ctx, tx, remotePub,
961961
)
962962
if err != nil {
963963
return err
@@ -985,7 +985,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
985985
channelPoint := c.ChannelPoint
986986
if !flags.Contains(session.ClearChanIDs) {
987987
channelPoint, err = firewalldb.HideChanPointStr(
988-
tx, c.ChannelPoint,
988+
ctx, tx, c.ChannelPoint,
989989
)
990990
if err != nil {
991991
return err
@@ -995,7 +995,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
995995
chanID := c.ChanId
996996
if !flags.Contains(session.ClearChanIDs) {
997997
chanID, err = firewalldb.HideUint64(
998-
tx, c.ChanId,
998+
ctx, tx, c.ChanId,
999999
)
10001000
if err != nil {
10011001
return err
@@ -1005,7 +1005,7 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
10051005
closingTxid := c.ClosingTxHash
10061006
if !flags.Contains(session.ClearClosingTxIds) {
10071007
closingTxid, err = firewalldb.HideString(
1008-
tx, c.ClosingTxHash,
1008+
ctx, tx, c.ClosingTxHash,
10091009
)
10101010
if err != nil {
10111011
return err
@@ -1052,7 +1052,8 @@ func handleClosedChannelsResponse(db firewalldb.PrivacyMapDB,
10521052

10531053
// obfuscatePendingChannel is a helper to obfuscate the fields of a pending
10541054
// channel.
1055-
func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel,
1055+
func obfuscatePendingChannel(ctx context.Context,
1056+
c *lnrpc.PendingChannelsResponse_PendingChannel,
10561057
tx firewalldb.PrivacyMapTx, randIntn func(int) (int, error),
10571058
flags session.PrivacyFlags) (
10581059
*lnrpc.PendingChannelsResponse_PendingChannel, error) {
@@ -1062,7 +1063,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel,
10621063
remotePub := c.RemoteNodePub
10631064
if !flags.Contains(session.ClearPubkeys) {
10641065
remotePub, err = firewalldb.HideString(
1065-
tx, remotePub,
1066+
ctx, tx, remotePub,
10661067
)
10671068
if err != nil {
10681069
return nil, err
@@ -1099,7 +1100,7 @@ func obfuscatePendingChannel(c *lnrpc.PendingChannelsResponse_PendingChannel,
10991100
chanPoint := c.ChannelPoint
11001101
if !flags.Contains(session.ClearChanIDs) {
11011102
chanPoint, err = firewalldb.HideChanPointStr(
1102-
tx, c.ChannelPoint,
1103+
ctx, tx, c.ChannelPoint,
11031104
)
11041105
if err != nil {
11051106
return nil, err
@@ -1163,7 +1164,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
11631164
var err error
11641165

11651166
pendingChannel, err := obfuscatePendingChannel(
1166-
c.Channel, tx, randIntn, flags,
1167+
ctx, c.Channel, tx, randIntn, flags,
11671168
)
11681169
if err != nil {
11691170
return err
@@ -1187,16 +1188,16 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
11871188
var err error
11881189

11891190
pendingChannel, err := obfuscatePendingChannel(
1190-
c.Channel, tx, randIntn, flags,
1191+
ctx, c.Channel, tx, randIntn, flags,
11911192
)
11921193
if err != nil {
11931194
return err
11941195
}
11951196

11961197
closingTxid := c.ClosingTxid
11971198
if !flags.Contains(session.ClearClosingTxIds) {
1198-
closingTxid, err = firewalldb.HideString(
1199-
tx, c.ClosingTxid,
1199+
closingTxid, err = firewalldb.HideString( //nolint:lll
1200+
ctx, tx, c.ClosingTxid,
12001201
)
12011202
if err != nil {
12021203
return err
@@ -1216,7 +1217,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
12161217
var err error
12171218

12181219
pendingChannel, err := obfuscatePendingChannel(
1219-
c.Channel, tx, randIntn, flags,
1220+
ctx, c.Channel, tx, randIntn, flags,
12201221
)
12211222
if err != nil {
12221223
return err
@@ -1225,7 +1226,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
12251226
closingTxid := c.ClosingTxid
12261227
if !flags.Contains(session.ClearClosingTxIds) {
12271228
closingTxid, err = firewalldb.HideString(
1228-
tx, c.ClosingTxid,
1229+
ctx, tx, c.ClosingTxid,
12291230
)
12301231
if err != nil {
12311232
return err
@@ -1277,7 +1278,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
12771278
var err error
12781279

12791280
pendingChannel, err := obfuscatePendingChannel(
1280-
c.Channel, tx, randIntn, flags,
1281+
ctx, c.Channel, tx, randIntn, flags,
12811282
)
12821283
if err != nil {
12831284
return err
@@ -1297,7 +1298,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
12971298
closingTxid := c.ClosingTxid
12981299
if !flags.Contains(session.ClearClosingTxIds) {
12991300
closingTxid, err = firewalldb.HideString(
1300-
tx, closingTxid,
1301+
ctx, tx, closingTxid,
13011302
)
13021303
if err != nil {
13031304
return err
@@ -1314,7 +1315,7 @@ func handlePendingChannelsResponse(db firewalldb.PrivacyMapDB,
13141315
) {
13151316

13161317
closingTxHex, err = firewalldb.HideString(
1317-
tx, closingTxHex,
1318+
ctx, tx, closingTxHex,
13181319
)
13191320
if err != nil {
13201321
return err
@@ -1454,8 +1455,9 @@ func handleBatchOpenChannelResponse(db firewalldb.PrivacyMapDB,
14541455
return err
14551456
}
14561457

1457-
txID, outIdx, err := firewalldb.HideChanPoint(
1458-
tx, txId.String(), p.OutputIndex,
1458+
txID, outIdx, err := firewalldb.HideChanPoint( //nolint:lll
1459+
ctx, tx, txId.String(),
1460+
p.OutputIndex,
14591461
)
14601462
if err != nil {
14611463
return err
@@ -1600,7 +1602,7 @@ func handleChannelOpenResponse(db firewalldb.PrivacyMapDB,
16001602

16011603
if !flags.Contains(session.ClearChanIDs) {
16021604
txid, index, err = firewalldb.HideChanPoint(
1603-
tx, txid, index,
1605+
ctx, tx, txid, index,
16041606
)
16051607
if err != nil {
16061608
return err

firewall/privacy_mapper_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ func newMockDB(t *testing.T, preloadRealToPseudo map[string]string,
10771077
tx firewalldb.PrivacyMapTx) error {
10781078

10791079
for r, p := range preloadRealToPseudo {
1080-
require.NoError(t, tx.NewPair(r, p))
1080+
require.NoError(t, tx.NewPair(ctx, r, p))
10811081
}
10821082
return nil
10831083
})
@@ -1121,7 +1121,9 @@ func (m *mockPrivacyMapDB) View(ctx context.Context,
11211121
return f(ctx, m)
11221122
}
11231123

1124-
func (m *mockPrivacyMapDB) NewPair(real, pseudo string) error {
1124+
func (m *mockPrivacyMapDB) NewPair(_ context.Context, real,
1125+
pseudo string) error {
1126+
11251127
m.r2p[real] = pseudo
11261128
m.p2r[pseudo] = real
11271129
return nil

firewalldb/privacy_mapper.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ type PrivacyMapDB interface {
7171
// real-pseudo pairs.
7272
type PrivacyMapTx interface {
7373
// NewPair persists a new real-pseudo pair.
74-
NewPair(real, pseudo string) error
74+
NewPair(ctx context.Context, real, pseudo string) error
7575

7676
// PseudoToReal returns the real value associated with the given pseudo
7777
// value. If no such pair is found, then ErrNoSuchKeyFound is returned.
@@ -181,7 +181,7 @@ type privacyMapTx struct {
181181
// NewPair inserts a new real-pseudo pair into the db.
182182
//
183183
// NOTE: this is part of the PrivacyMapTx interface.
184-
func (p *privacyMapTx) NewPair(real, pseudo string) error {
184+
func (p *privacyMapTx) NewPair(_ context.Context, real, pseudo string) error {
185185
privacyBucket, err := getBucket(p.boltTx, privacyBucketKey)
186186
if err != nil {
187187
return err
@@ -314,7 +314,9 @@ func (p *privacyMapTx) FetchAllPairs() (*PrivacyMapPairs, error) {
314314
return NewPrivacyMapPairs(pairs), nil
315315
}
316316

317-
func HideString(tx PrivacyMapTx, real string) (string, error) {
317+
func HideString(ctx context.Context, tx PrivacyMapTx, real string) (string,
318+
error) {
319+
318320
pseudo, err := tx.RealToPseudo(real)
319321
if err != nil && err != ErrNoSuchKeyFound {
320322
return "", err
@@ -328,7 +330,7 @@ func HideString(tx PrivacyMapTx, real string) (string, error) {
328330
return "", err
329331
}
330332

331-
if err = tx.NewPair(real, pseudo); err != nil {
333+
if err = tx.NewPair(ctx, real, pseudo); err != nil {
332334
return "", err
333335
}
334336

@@ -360,7 +362,9 @@ func RevealString(tx PrivacyMapTx, pseudo string) (string, error) {
360362
return tx.PseudoToReal(pseudo)
361363
}
362364

363-
func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) {
365+
func HideUint64(ctx context.Context, tx PrivacyMapTx, real uint64) (uint64,
366+
error) {
367+
364368
str := Uint64ToStr(real)
365369
pseudo, err := tx.RealToPseudo(str)
366370
if err != nil && err != ErrNoSuchKeyFound {
@@ -371,7 +375,7 @@ func HideUint64(tx PrivacyMapTx, real uint64) (uint64, error) {
371375
}
372376

373377
pseudoUint64, pseudoUint64Str := NewPseudoUint64()
374-
if err := tx.NewPair(str, pseudoUint64Str); err != nil {
378+
if err := tx.NewPair(ctx, str, pseudoUint64Str); err != nil {
375379
return 0, err
376380
}
377381

@@ -391,8 +395,8 @@ func RevealUint64(tx PrivacyMapTx, pseudo uint64) (uint64, error) {
391395
return StrToUint64(real)
392396
}
393397

394-
func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string,
395-
uint32, error) {
398+
func HideChanPoint(ctx context.Context, tx PrivacyMapTx, txid string,
399+
index uint32) (string, uint32, error) {
396400

397401
cp := fmt.Sprintf("%s:%d", txid, index)
398402
pseudo, err := tx.RealToPseudo(cp)
@@ -408,7 +412,7 @@ func HideChanPoint(tx PrivacyMapTx, txid string, index uint32) (string,
408412
return "", 0, err
409413
}
410414

411-
if err := tx.NewPair(cp, newCp); err != nil {
415+
if err := tx.NewPair(ctx, cp, newCp); err != nil {
412416
return "", 0, err
413417
}
414418

@@ -444,24 +448,28 @@ func NewPseudoUint32() uint32 {
444448
return binary.BigEndian.Uint32(b)
445449
}
446450

447-
func HideChanPointStr(tx PrivacyMapTx, cp string) (string, error) {
451+
func HideChanPointStr(ctx context.Context, tx PrivacyMapTx, cp string) (string,
452+
error) {
453+
448454
txid, index, err := DecodeChannelPoint(cp)
449455
if err != nil {
450456
return "", err
451457
}
452458

453-
newTxid, newIndex, err := HideChanPoint(tx, txid, index)
459+
newTxid, newIndex, err := HideChanPoint(ctx, tx, txid, index)
454460
if err != nil {
455461
return "", err
456462
}
457463

458464
return fmt.Sprintf("%s:%d", newTxid, newIndex), nil
459465
}
460466

461-
func HideBytes(tx PrivacyMapTx, realBytes []byte) ([]byte, error) {
467+
func HideBytes(ctx context.Context, tx PrivacyMapTx, realBytes []byte) ([]byte,
468+
error) {
469+
462470
real := hex.EncodeToString(realBytes)
463471

464-
pseudo, err := HideString(tx, real)
472+
pseudo, err := HideString(ctx, tx, real)
465473
if err != nil {
466474
return nil, err
467475
}

0 commit comments

Comments
 (0)