Skip to content

Commit d9709b8

Browse files
Roasbeefguggero
authored andcommitted
channeldb: convert HTLCEntry to use tlv.RecordT
1 parent 59f73c7 commit d9709b8

File tree

5 files changed

+159
-154
lines changed

5 files changed

+159
-154
lines changed

channeldb/channel_test.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -579,15 +579,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
579579
require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
580580
for i, rHtlc := range r.HTLCEntries {
581581
cHtlc := c.Htlcs[i]
582-
require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
583-
require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
584-
"Amt mismatch")
585-
require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
586-
"RefundTimeout mismatch")
587-
require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
588-
"OutputIndex mismatch")
589-
require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
590-
"Incoming mismatch")
582+
require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash")
583+
require.Equal(
584+
t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt",
585+
)
586+
require.Equal(
587+
t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout,
588+
"RefundTimeout",
589+
)
590+
require.EqualValues(
591+
t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex,
592+
"OutputIndex",
593+
)
594+
require.Equal(
595+
t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming",
596+
)
591597
}
592598
}
593599

channeldb/revocation_log.go

Lines changed: 104 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,82 @@ var (
5454
ErrOutputIndexTooBig = errors.New("output index is over uint16")
5555
)
5656

57+
// SparsePayHash is a type alias for a 32 byte array, which when serialized is
58+
// able to save some space by not including an empty payment hash on disk.
59+
type SparsePayHash [32]byte
60+
61+
// NewSparsePayHash creates a new SparsePayHash from a 32 byte array.
62+
func NewSparsePayHash(rHash [32]byte) SparsePayHash {
63+
return SparsePayHash(rHash)
64+
}
65+
66+
// Record returns a tlv record for the SparsePayHash.
67+
func (s *SparsePayHash) Record() tlv.Record {
68+
// We use a zero for the type here, as this'll be used along with the
69+
// RecordT type.
70+
return tlv.MakeDynamicRecord(
71+
0, s, s.hashLen,
72+
sparseHashEncoder, sparseHashDecoder,
73+
)
74+
}
75+
76+
// hashLen is used by MakeDynamicRecord to return the size of the RHash.
77+
//
78+
// NOTE: for zero hash, we return a length 0.
79+
func (s *SparsePayHash) hashLen() uint64 {
80+
if bytes.Equal(s[:], lntypes.ZeroHash[:]) {
81+
return 0
82+
}
83+
84+
return 32
85+
}
86+
87+
// sparseHashEncoder is the customized encoder which skips encoding the empty
88+
// hash.
89+
func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
90+
v, ok := val.(*SparsePayHash)
91+
if !ok {
92+
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
93+
}
94+
95+
// If the value is an empty hash, we will skip encoding it.
96+
if bytes.Equal(v[:], lntypes.ZeroHash[:]) {
97+
return nil
98+
}
99+
100+
vArray := (*[32]byte)(v)
101+
102+
return tlv.EBytes32(w, vArray, buf)
103+
}
104+
105+
// sparseHashDecoder is the customized decoder which skips decoding the empty
106+
// hash.
107+
func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte,
108+
l uint64) error {
109+
110+
v, ok := val.(*SparsePayHash)
111+
if !ok {
112+
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
113+
}
114+
115+
// If the length is zero, we will skip encoding the empty hash.
116+
if l == 0 {
117+
return nil
118+
}
119+
120+
vArray := (*[32]byte)(v)
121+
122+
if err := tlv.DBytes32(r, vArray, buf, 32); err != nil {
123+
return err
124+
}
125+
126+
vHash := SparsePayHash(*vArray)
127+
128+
v = &vHash
129+
130+
return nil
131+
}
132+
57133
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
58134
// historical HTLCs, which is useful for constructing RevocationLog when a
59135
// breach is detected.
@@ -72,116 +148,60 @@ var (
72148
// made into tlv records without further conversion.
73149
type HTLCEntry struct {
74150
// RHash is the payment hash of the HTLC.
75-
RHash [32]byte
151+
RHash tlv.RecordT[tlv.TlvType0, SparsePayHash]
76152

77153
// RefundTimeout is the absolute timeout on the HTLC that the sender
78154
// must wait before reclaiming the funds in limbo.
79-
RefundTimeout uint32
155+
RefundTimeout tlv.RecordT[tlv.TlvType1, uint32]
80156

81157
// OutputIndex is the output index for this particular HTLC output
82158
// within the commitment transaction.
83159
//
84160
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
85161
// gives us a max number of HTLCs of 65K.
86-
OutputIndex uint16
162+
OutputIndex tlv.RecordT[tlv.TlvType2, uint16]
87163

88164
// Incoming denotes whether we're the receiver or the sender of this
89165
// HTLC.
90166
//
91167
// NOTE: this field is the memory representation of the field
92168
// incomingUint.
93-
Incoming bool
169+
Incoming tlv.RecordT[tlv.TlvType3, bool]
94170

95171
// Amt is the amount of satoshis this HTLC escrows.
96172
//
97173
// NOTE: this field is the memory representation of the field amtUint.
98-
Amt btcutil.Amount
99-
100-
// amtTlv is the uint64 format of Amt. This field is created so we can
101-
// easily make it into a tlv record and save it to disk.
102-
//
103-
// NOTE: we keep this field for accounting purpose only. If the disk
104-
// space becomes an issue, we could delete this field to save us extra
105-
// 8 bytes.
106-
amtTlv uint64
107-
108-
// incomingTlv is the uint8 format of Incoming. This field is created
109-
// so we can easily make it into a tlv record and save it to disk.
110-
incomingTlv uint8
111-
}
112-
113-
// RHashLen is used by MakeDynamicRecord to return the size of the RHash.
114-
//
115-
// NOTE: for zero hash, we return a length 0.
116-
func (h *HTLCEntry) RHashLen() uint64 {
117-
if h.RHash == lntypes.ZeroHash {
118-
return 0
119-
}
120-
return 32
121-
}
122-
123-
// RHashEncoder is the customized encoder which skips encoding the empty hash.
124-
func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
125-
v, ok := val.(*[32]byte)
126-
if !ok {
127-
return tlv.NewTypeForEncodingErr(val, "RHash")
128-
}
129-
130-
// If the value is an empty hash, we will skip encoding it.
131-
if *v == lntypes.ZeroHash {
132-
return nil
133-
}
134-
135-
return tlv.EBytes32(w, v, buf)
136-
}
137-
138-
// RHashDecoder is the customized decoder which skips decoding the empty hash.
139-
func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
140-
v, ok := val.(*[32]byte)
141-
if !ok {
142-
return tlv.NewTypeForEncodingErr(val, "RHash")
143-
}
144-
145-
// If the length is zero, we will skip encoding the empty hash.
146-
if l == 0 {
147-
return nil
148-
}
149-
150-
return tlv.DBytes32(r, v, buf, 32)
174+
Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]]
151175
}
152176

153177
// toTlvStream converts an HTLCEntry record into a tlv representation.
154178
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
155-
const (
156-
// A set of tlv type definitions used to serialize htlc entries
157-
// to the database. We define it here instead of the head of
158-
// the file to avoid naming conflicts.
159-
//
160-
// NOTE: A migration should be added whenever this list
161-
// changes.
162-
rHashType tlv.Type = 0
163-
refundTimeoutType tlv.Type = 1
164-
outputIndexType tlv.Type = 2
165-
incomingType tlv.Type = 3
166-
amtType tlv.Type = 4
179+
return tlv.NewStream(
180+
h.RHash.Record(),
181+
h.RefundTimeout.Record(),
182+
h.OutputIndex.Record(),
183+
h.Incoming.Record(),
184+
h.Amt.Record(),
167185
)
186+
}
168187

169-
return tlv.NewStream(
170-
tlv.MakeDynamicRecord(
171-
rHashType, &h.RHash, h.RHashLen,
172-
RHashEncoder, RHashDecoder,
188+
// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC.
189+
func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry {
190+
return &HTLCEntry{
191+
RHash: tlv.NewRecordT[tlv.TlvType0](
192+
NewSparsePayHash(htlc.RHash),
173193
),
174-
tlv.MakePrimitiveRecord(
175-
refundTimeoutType, &h.RefundTimeout,
194+
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1](
195+
htlc.RefundTimeout,
176196
),
177-
tlv.MakePrimitiveRecord(
178-
outputIndexType, &h.OutputIndex,
197+
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
198+
uint16(htlc.OutputIndex),
179199
),
180-
tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv),
181-
// We will save 3 bytes if the amount is less or equal to
182-
// 4,294,967,295 msat, or roughly 0.043 bitcoin.
183-
tlv.MakeBigSizeRecord(amtType, &h.amtTlv),
184-
)
200+
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming),
201+
Amt: tlv.NewRecordT[tlv.TlvType4](
202+
tlv.NewBigSizeT(htlc.Amt.ToSatoshis()),
203+
),
204+
}
185205
}
186206

187207
// RevocationLog stores the info needed to construct a breach retribution. Its
@@ -265,13 +285,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
265285
return ErrOutputIndexTooBig
266286
}
267287

268-
entry := &HTLCEntry{
269-
RHash: htlc.RHash,
270-
RefundTimeout: htlc.RefundTimeout,
271-
Incoming: htlc.Incoming,
272-
OutputIndex: uint16(htlc.OutputIndex),
273-
Amt: htlc.Amt.ToSatoshis(),
274-
}
288+
entry := NewHTLCEntryFromHTLC(htlc)
275289
rl.HTLCEntries = append(rl.HTLCEntries, entry)
276290
}
277291

@@ -351,14 +365,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
351365
// format.
352366
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
353367
for _, htlc := range htlcs {
354-
// Patch the incomingTlv field.
355-
if htlc.Incoming {
356-
htlc.incomingTlv = 1
357-
}
358-
359-
// Patch the amtTlv field.
360-
htlc.amtTlv = uint64(htlc.Amt)
361-
362368
// Create the tlv stream.
363369
tlvStream, err := htlc.toTlvStream()
364370
if err != nil {
@@ -447,14 +453,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
447453
return nil, err
448454
}
449455

450-
// Patch the Incoming field.
451-
if htlc.incomingTlv == 1 {
452-
htlc.Incoming = true
453-
}
454-
455-
// Patch the Amt field.
456-
htlc.Amt = btcutil.Amount(htlc.amtTlv)
457-
458456
// Append the entry.
459457
htlcs = append(htlcs, &htlc)
460458
}
@@ -469,6 +467,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
469467
if err := s.Encode(&b); err != nil {
470468
return err
471469
}
470+
472471
// Write the stream's length as a varint.
473472
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
474473
if err != nil {

channeldb/revocation_log_test.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,16 @@ var (
3434
}
3535

3636
testHTLCEntry = HTLCEntry{
37-
RefundTimeout: 740_000,
38-
OutputIndex: 10,
39-
Incoming: true,
40-
Amt: 1000_000,
41-
amtTlv: 1000_000,
42-
incomingTlv: 1,
37+
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32](
38+
740_000,
39+
),
40+
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](
41+
10,
42+
),
43+
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true),
44+
Amt: tlv.NewRecordT[tlv.TlvType4](
45+
tlv.NewBigSizeT(btcutil.Amount(1_000_000)),
46+
),
4347
}
4448
testHTLCEntryBytes = []byte{
4549
// Body length 23.
@@ -68,11 +72,11 @@ var (
6872
CommitTx: channels.TestFundingTx,
6973
CommitSig: bytes.Repeat([]byte{1}, 71),
7074
Htlcs: []HTLC{{
71-
RefundTimeout: testHTLCEntry.RefundTimeout,
72-
OutputIndex: int32(testHTLCEntry.OutputIndex),
73-
Incoming: testHTLCEntry.Incoming,
75+
RefundTimeout: testHTLCEntry.RefundTimeout.Val,
76+
OutputIndex: int32(testHTLCEntry.OutputIndex.Val),
77+
Incoming: testHTLCEntry.Incoming.Val,
7478
Amt: lnwire.NewMSatFromSatoshis(
75-
testHTLCEntry.Amt,
79+
testHTLCEntry.Amt.Val.Int(),
7680
),
7781
}},
7882
}
@@ -193,11 +197,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) {
193197
// Copy the testHTLCEntry.
194198
entry := testHTLCEntry
195199

196-
// Set the internal fields to empty values so we can test the bytes are
197-
// padded.
198-
entry.incomingTlv = 0
199-
entry.amtTlv = 0
200-
201200
// Write the tlv stream.
202201
buf := bytes.NewBuffer([]byte{})
203202
err := serializeHTLCEntries(buf, []*HTLCEntry{&entry})
@@ -215,7 +214,7 @@ func TestSerializeHTLCEntries(t *testing.T) {
215214

216215
// Create a fake rHash.
217216
rHashBytes := bytes.Repeat([]byte{10}, 32)
218-
copy(entry.RHash[:], rHashBytes)
217+
copy(entry.RHash.Val[:], rHashBytes)
219218

220219
// Construct the serialized bytes.
221220
//
@@ -330,7 +329,7 @@ func TestDerializeHTLCEntries(t *testing.T) {
330329

331330
// Create a fake rHash.
332331
rHashBytes := bytes.Repeat([]byte{10}, 32)
333-
copy(entry.RHash[:], rHashBytes)
332+
copy(entry.RHash.Val[:], rHashBytes)
334333

335334
// Construct the serialized bytes.
336335
//

0 commit comments

Comments
 (0)