Skip to content

Commit 6e63dcd

Browse files
authored
Merge pull request #8952 from ProofOfKeags/refactor/lnwallet-channel-typed-list
[MICRO]: lnwallet: refactor channel to use new typed List
2 parents 546c8eb + 04c3734 commit 6e63dcd

File tree

4 files changed

+37
-38
lines changed

4 files changed

+37
-38
lines changed

lnwallet/channel.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,7 +2483,7 @@ type htlcView struct {
24832483
func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *htlcView {
24842484
var ourHTLCs []*PaymentDescriptor
24852485
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
2486-
htlc := e.Value.(*PaymentDescriptor)
2486+
htlc := e.Value
24872487

24882488
// This HTLC is active from this point-of-view iff the log
24892489
// index of the state update is below the specified index in
@@ -2495,7 +2495,7 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, ourLogIndex uint64) *ht
24952495

24962496
var theirHTLCs []*PaymentDescriptor
24972497
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
2498-
htlc := e.Value.(*PaymentDescriptor)
2498+
htlc := e.Value
24992499

25002500
// If this is an incoming HTLC, then it is only active from
25012501
// this point-of-view if the index of the HTLC addition in
@@ -3139,7 +3139,7 @@ func (lc *LightningChannel) createCommitDiff(
31393139
// set of items we need to retransmit if we reconnect and find that
31403140
// they didn't process this new state fully.
31413141
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
3142-
pd := e.Value.(*PaymentDescriptor)
3142+
pd := e.Value
31433143

31443144
// If this entry wasn't committed at the exact height of this
31453145
// remote commitment, then we'll skip it as it was already
@@ -3277,7 +3277,7 @@ func (lc *LightningChannel) getUnsignedAckedUpdates() []channeldb.LogUpdate {
32773277
// remote party expects.
32783278
var logUpdates []channeldb.LogUpdate
32793279
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
3280-
pd := e.Value.(*PaymentDescriptor)
3280+
pd := e.Value
32813281

32823282
// Skip all remote updates that we have already included in our
32833283
// commit chain.
@@ -5225,7 +5225,7 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
52255225

52265226
var addIndex, settleFailIndex uint16
52275227
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
5228-
pd := e.Value.(*PaymentDescriptor)
5228+
pd := e.Value
52295229

52305230
// Fee updates are local to this particular channel, and should
52315231
// never be forwarded.
@@ -5555,7 +5555,7 @@ func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty,
55555555

55565556
// Grab all of our HTLCs and evaluate against the dust limit.
55575557
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
5558-
pd := e.Value.(*PaymentDescriptor)
5558+
pd := e.Value
55595559
if pd.EntryType != Add {
55605560
continue
55615561
}
@@ -5574,7 +5574,7 @@ func (lc *LightningChannel) GetDustSum(whoseCommit lntypes.ChannelParty,
55745574

55755575
// Grab all of their HTLCs and evaluate against the dust limit.
55765576
for e := lc.remoteUpdateLog.Front(); e != nil; e = e.Next() {
5577-
pd := e.Value.(*PaymentDescriptor)
5577+
pd := e.Value
55785578
if pd.EntryType != Add {
55795579
continue
55805580
}
@@ -8579,7 +8579,7 @@ func (lc *LightningChannel) unsignedLocalUpdates(remoteMessageIndex,
85798579

85808580
var localPeerUpdates []channeldb.LogUpdate
85818581
for e := lc.localUpdateLog.Front(); e != nil; e = e.Next() {
8582-
pd := e.Value.(*PaymentDescriptor)
8582+
pd := e.Value
85838583

85848584
// We don't save add updates as they are restored from the
85858585
// remote commitment in restoreStateLogs.

lnwallet/channel_test.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package lnwallet
22

33
import (
44
"bytes"
5-
"container/list"
65
"crypto/sha256"
76
"fmt"
87
"math/rand"
@@ -1906,7 +1905,7 @@ func TestStateUpdatePersistence(t *testing.T) {
19061905

19071906
// Newly generated pkScripts for HTLCs should be the same as in the old channel.
19081907
for _, entry := range aliceChannel.localUpdateLog.htlcIndex {
1909-
htlc := entry.Value.(*PaymentDescriptor)
1908+
htlc := entry.Value
19101909
restoredHtlc := aliceChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex)
19111910
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
19121911
t.Fatalf("alice ourPkScript in ourLog: expected %X, got %X",
@@ -1918,7 +1917,7 @@ func TestStateUpdatePersistence(t *testing.T) {
19181917
}
19191918
}
19201919
for _, entry := range aliceChannel.remoteUpdateLog.htlcIndex {
1921-
htlc := entry.Value.(*PaymentDescriptor)
1920+
htlc := entry.Value
19221921
restoredHtlc := aliceChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex)
19231922
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
19241923
t.Fatalf("alice ourPkScript in theirLog: expected %X, got %X",
@@ -1930,7 +1929,7 @@ func TestStateUpdatePersistence(t *testing.T) {
19301929
}
19311930
}
19321931
for _, entry := range bobChannel.localUpdateLog.htlcIndex {
1933-
htlc := entry.Value.(*PaymentDescriptor)
1932+
htlc := entry.Value
19341933
restoredHtlc := bobChannelNew.localUpdateLog.lookupHtlc(htlc.HtlcIndex)
19351934
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
19361935
t.Fatalf("bob ourPkScript in ourLog: expected %X, got %X",
@@ -1942,7 +1941,7 @@ func TestStateUpdatePersistence(t *testing.T) {
19421941
}
19431942
}
19441943
for _, entry := range bobChannel.remoteUpdateLog.htlcIndex {
1945-
htlc := entry.Value.(*PaymentDescriptor)
1944+
htlc := entry.Value
19461945
restoredHtlc := bobChannelNew.remoteUpdateLog.lookupHtlc(htlc.HtlcIndex)
19471946
if !bytes.Equal(htlc.ourPkScript, restoredHtlc.ourPkScript) {
19481947
t.Fatalf("bob ourPkScript in theirLog: expected %X, got %X",
@@ -4472,7 +4471,7 @@ func TestFeeUpdateOldDiskFormat(t *testing.T) {
44724471
countLog := func(log *updateLog) (int, int) {
44734472
var numUpdates, numFee int
44744473
for e := log.Front(); e != nil; e = e.Next() {
4475-
htlc := e.Value.(*PaymentDescriptor)
4474+
htlc := e.Value
44764475
if htlc.EntryType == FeeUpdate {
44774476
numFee++
44784477
}
@@ -6755,14 +6754,14 @@ func compareHtlcs(htlc1, htlc2 *PaymentDescriptor) error {
67556754
}
67566755

67576756
// compareIndexes is a helper method to compare two index maps.
6758-
func compareIndexes(a, b map[uint64]*list.Element) error {
6757+
func compareIndexes(a, b map[uint64]*fn.Node[*PaymentDescriptor]) error {
67596758
for k1, e1 := range a {
67606759
e2, ok := b[k1]
67616760
if !ok {
67626761
return fmt.Errorf("element with key %d "+
67636762
"not found in b", k1)
67646763
}
6765-
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
6764+
htlc1, htlc2 := e1.Value, e2.Value
67666765
if err := compareHtlcs(htlc1, htlc2); err != nil {
67676766
return err
67686767
}
@@ -6774,7 +6773,7 @@ func compareIndexes(a, b map[uint64]*list.Element) error {
67746773
return fmt.Errorf("element with key %d not "+
67756774
"found in a", k1)
67766775
}
6777-
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
6776+
htlc1, htlc2 := e1.Value, e2.Value
67786777
if err := compareHtlcs(htlc1, htlc2); err != nil {
67796778
return err
67806779
}
@@ -6809,7 +6808,7 @@ func compareLogs(a, b *updateLog) error {
68096808

68106809
e1, e2 := a.Front(), b.Front()
68116810
for ; e1 != nil; e1, e2 = e1.Next(), e2.Next() {
6812-
htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor)
6811+
htlc1, htlc2 := e1.Value, e2.Value
68136812
if err := compareHtlcs(htlc1, htlc2); err != nil {
68146813
return err
68156814
}
@@ -6917,7 +6916,7 @@ func TestChannelRestoreUpdateLogs(t *testing.T) {
69176916
func fetchNumUpdates(t updateType, log *updateLog) int {
69186917
num := 0
69196918
for e := log.Front(); e != nil; e = e.Next() {
6920-
htlc := e.Value.(*PaymentDescriptor)
6919+
htlc := e.Value
69216920
if htlc.EntryType == t {
69226921
num++
69236922
}

lnwallet/commitment_chain.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package lnwallet
22

3-
import "container/list"
3+
import (
4+
"github.com/lightningnetwork/lnd/fn"
5+
)
46

57
// commitmentChain represents a chain of unrevoked commitments. The tail of the
68
// chain is the latest fully signed, yet unrevoked commitment. Two chains are
@@ -15,13 +17,13 @@ type commitmentChain struct {
1517
// commitments are added to the end of the chain with increase height.
1618
// Once a commitment transaction is revoked, the tail is incremented,
1719
// freeing up the revocation window for new commitments.
18-
commitments *list.List
20+
commitments *fn.List[*commitment]
1921
}
2022

2123
// newCommitmentChain creates a new commitment chain.
2224
func newCommitmentChain() *commitmentChain {
2325
return &commitmentChain{
24-
commitments: list.New(),
26+
commitments: fn.NewList[*commitment](),
2527
}
2628
}
2729

@@ -42,14 +44,12 @@ func (s *commitmentChain) advanceTail() {
4244

4345
// tip returns the latest commitment added to the chain.
4446
func (s *commitmentChain) tip() *commitment {
45-
//nolint:forcetypeassert
46-
return s.commitments.Back().Value.(*commitment)
47+
return s.commitments.Back().Value
4748
}
4849

4950
// tail returns the lowest unrevoked commitment transaction in the chain.
5051
func (s *commitmentChain) tail() *commitment {
51-
//nolint:forcetypeassert
52-
return s.commitments.Front().Value.(*commitment)
52+
return s.commitments.Front().Value
5353
}
5454

5555
// hasUnackedCommitment returns true if the commitment chain has more than one

lnwallet/update_log.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package lnwallet
22

3-
import "container/list"
3+
import (
4+
"github.com/lightningnetwork/lnd/fn"
5+
)
46

57
// updateLog is an append-only log that stores updates to a node's commitment
68
// chain. This structure can be seen as the "mempool" within Lightning where
@@ -27,16 +29,16 @@ type updateLog struct {
2729

2830
// List is the updatelog itself, we embed this value so updateLog has
2931
// access to all the method of a list.List.
30-
*list.List
32+
*fn.List[*PaymentDescriptor]
3133

3234
// updateIndex maps a `logIndex` to a particular update entry. It
3335
// deals with the four update types:
3436
// `Fail|MalformedFail|Settle|FeeUpdate`
35-
updateIndex map[uint64]*list.Element
37+
updateIndex map[uint64]*fn.Node[*PaymentDescriptor]
3638

3739
// htlcIndex maps a `htlcCounter` to an offered HTLC entry, hence the
3840
// `Add` update.
39-
htlcIndex map[uint64]*list.Element
41+
htlcIndex map[uint64]*fn.Node[*PaymentDescriptor]
4042

4143
// modifiedHtlcs is a set that keeps track of all the current modified
4244
// htlcs, hence update types `Fail|MalformedFail|Settle`. A modified
@@ -48,9 +50,9 @@ type updateLog struct {
4850
// newUpdateLog creates a new updateLog instance.
4951
func newUpdateLog(logIndex, htlcCounter uint64) *updateLog {
5052
return &updateLog{
51-
List: list.New(),
52-
updateIndex: make(map[uint64]*list.Element),
53-
htlcIndex: make(map[uint64]*list.Element),
53+
List: fn.NewList[*PaymentDescriptor](),
54+
updateIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]),
55+
htlcIndex: make(map[uint64]*fn.Node[*PaymentDescriptor]),
5456
logIndex: logIndex,
5557
htlcCounter: htlcCounter,
5658
modifiedHtlcs: make(map[uint64]struct{}),
@@ -101,8 +103,7 @@ func (u *updateLog) lookupHtlc(i uint64) *PaymentDescriptor {
101103
return nil
102104
}
103105

104-
//nolint:forcetypeassert
105-
return htlc.Value.(*PaymentDescriptor)
106+
return htlc.Value
106107
}
107108

108109
// remove attempts to remove an entry from the update log. If the entry is
@@ -145,15 +146,14 @@ func compactLogs(ourLog, theirLog *updateLog,
145146
localChainTail, remoteChainTail uint64) {
146147

147148
compactLog := func(logA, logB *updateLog) {
148-
var nextA *list.Element
149+
var nextA *fn.Node[*PaymentDescriptor]
149150
for e := logA.Front(); e != nil; e = nextA {
150151
// Assign next iteration element at top of loop because
151152
// we may remove the current element from the list,
152153
// which can change the iterated sequence.
153154
nextA = e.Next()
154155

155-
//nolint:forcetypeassert
156-
htlc := e.Value.(*PaymentDescriptor)
156+
htlc := e.Value
157157

158158
// We skip Adds, as they will be removed along with the
159159
// fail/settles below.

0 commit comments

Comments
 (0)