Skip to content

Commit f3655da

Browse files
Roasbeefguggero
authored andcommitted
channeldb: convert HTLCEntry to use tlv.RecordT
1 parent 258dd5c commit f3655da

File tree

5 files changed

+209
-163
lines changed

5 files changed

+209
-163
lines changed

channeldb/channel_test.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -579,15 +579,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
579579
require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
580580
for i, rHtlc := range r.HTLCEntries {
581581
cHtlc := c.Htlcs[i]
582-
require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
583-
require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
584-
"Amt mismatch")
585-
require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
586-
"RefundTimeout mismatch")
587-
require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
588-
"OutputIndex mismatch")
589-
require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
590-
"Incoming mismatch")
582+
require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash")
583+
require.Equal(
584+
t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt",
585+
)
586+
require.Equal(
587+
t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout,
588+
"RefundTimeout",
589+
)
590+
require.EqualValues(
591+
t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex,
592+
"OutputIndex",
593+
)
594+
require.Equal(
595+
t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming",
596+
)
591597
}
592598
}
593599

channeldb/revocation_log.go

Lines changed: 96 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,74 @@ var (
5454
ErrOutputIndexTooBig = errors.New("output index is over uint16")
5555
)
5656

57+
// SparsePayHash is a type alias for a 32 byte array, which when serialized is
58+
// able to save some space by not including an empty payment hash on disk.
59+
type SparsePayHash [32]byte
60+
61+
// NewSparsePayHash creates a new SparsePayHash from a 32 byte array.
62+
func NewSparsePayHash(rHash [32]byte) SparsePayHash {
63+
return SparsePayHash(rHash)
64+
}
65+
66+
// Record returns a tlv record for the SparsePayHash.
67+
func (s *SparsePayHash) Record() tlv.Record {
68+
// We use a zero for the type here, as this'll be used along with the
69+
// RecordT type.
70+
return tlv.MakeDynamicRecord(
71+
0, s, s.hashLen,
72+
sparseHashEncoder, sparseHashDecoder,
73+
)
74+
}
75+
76+
// hashLen is used by MakeDynamicRecord to return the size of the RHash.
77+
//
78+
// NOTE: for zero hash, we return a length 0.
79+
func (s *SparsePayHash) hashLen() uint64 {
80+
if bytes.Equal(s[:], lntypes.ZeroHash[:]) {
81+
return 0
82+
}
83+
84+
return 32
85+
}
86+
87+
// sparseHashEncoder is the customized encoder which skips encoding the empty
88+
// hash.
89+
func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
90+
v, ok := val.(*SparsePayHash)
91+
if !ok {
92+
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
93+
}
94+
95+
// If the value is an empty hash, we will skip encoding it.
96+
if bytes.Equal(v[:], lntypes.ZeroHash[:]) {
97+
return nil
98+
}
99+
100+
vArray := (*[32]byte)(v)
101+
102+
return tlv.EBytes32(w, vArray, buf)
103+
}
104+
105+
// sparseHashDecoder is the customized decoder which skips decoding the empty
106+
// hash.
107+
func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte,
108+
l uint64) error {
109+
110+
v, ok := val.(*SparsePayHash)
111+
if !ok {
112+
return tlv.NewTypeForEncodingErr(val, "SparsePayHash")
113+
}
114+
115+
// If the length is zero, we will skip encoding the empty hash.
116+
if l == 0 {
117+
return nil
118+
}
119+
120+
vArray := (*[32]byte)(v)
121+
122+
return tlv.DBytes32(r, vArray, buf, 32)
123+
}
124+
57125
// HTLCEntry specifies the minimal info needed to be stored on disk for ALL the
58126
// historical HTLCs, which is useful for constructing RevocationLog when a
59127
// breach is detected.
@@ -72,116 +140,60 @@ var (
72140
// made into tlv records without further conversion.
73141
type HTLCEntry struct {
74142
// RHash is the payment hash of the HTLC.
75-
RHash [32]byte
143+
RHash tlv.RecordT[tlv.TlvType0, SparsePayHash]
76144

77145
// RefundTimeout is the absolute timeout on the HTLC that the sender
78146
// must wait before reclaiming the funds in limbo.
79-
RefundTimeout uint32
147+
RefundTimeout tlv.RecordT[tlv.TlvType1, uint32]
80148

81149
// OutputIndex is the output index for this particular HTLC output
82150
// within the commitment transaction.
83151
//
84152
// NOTE: we use uint16 instead of int32 here to save us 2 bytes, which
85153
// gives us a max number of HTLCs of 65K.
86-
OutputIndex uint16
154+
OutputIndex tlv.RecordT[tlv.TlvType2, uint16]
87155

88156
// Incoming denotes whether we're the receiver or the sender of this
89157
// HTLC.
90158
//
91159
// NOTE: this field is the memory representation of the field
92160
// incomingUint.
93-
Incoming bool
161+
Incoming tlv.RecordT[tlv.TlvType3, bool]
94162

95163
// Amt is the amount of satoshis this HTLC escrows.
96164
//
97165
// NOTE: this field is the memory representation of the field amtUint.
98-
Amt btcutil.Amount
99-
100-
// amtTlv is the uint64 format of Amt. This field is created so we can
101-
// easily make it into a tlv record and save it to disk.
102-
//
103-
// NOTE: we keep this field for accounting purpose only. If the disk
104-
// space becomes an issue, we could delete this field to save us extra
105-
// 8 bytes.
106-
amtTlv uint64
107-
108-
// incomingTlv is the uint8 format of Incoming. This field is created
109-
// so we can easily make it into a tlv record and save it to disk.
110-
incomingTlv uint8
111-
}
112-
113-
// RHashLen is used by MakeDynamicRecord to return the size of the RHash.
114-
//
115-
// NOTE: for zero hash, we return a length 0.
116-
func (h *HTLCEntry) RHashLen() uint64 {
117-
if h.RHash == lntypes.ZeroHash {
118-
return 0
119-
}
120-
return 32
121-
}
122-
123-
// RHashEncoder is the customized encoder which skips encoding the empty hash.
124-
func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
125-
v, ok := val.(*[32]byte)
126-
if !ok {
127-
return tlv.NewTypeForEncodingErr(val, "RHash")
128-
}
129-
130-
// If the value is an empty hash, we will skip encoding it.
131-
if *v == lntypes.ZeroHash {
132-
return nil
133-
}
134-
135-
return tlv.EBytes32(w, v, buf)
136-
}
137-
138-
// RHashDecoder is the customized decoder which skips decoding the empty hash.
139-
func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
140-
v, ok := val.(*[32]byte)
141-
if !ok {
142-
return tlv.NewTypeForEncodingErr(val, "RHash")
143-
}
144-
145-
// If the length is zero, we will skip encoding the empty hash.
146-
if l == 0 {
147-
return nil
148-
}
149-
150-
return tlv.DBytes32(r, v, buf, 32)
166+
Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]]
151167
}
152168

153169
// toTlvStream converts an HTLCEntry record into a tlv representation.
154170
func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) {
155-
const (
156-
// A set of tlv type definitions used to serialize htlc entries
157-
// to the database. We define it here instead of the head of
158-
// the file to avoid naming conflicts.
159-
//
160-
// NOTE: A migration should be added whenever this list
161-
// changes.
162-
rHashType tlv.Type = 0
163-
refundTimeoutType tlv.Type = 1
164-
outputIndexType tlv.Type = 2
165-
incomingType tlv.Type = 3
166-
amtType tlv.Type = 4
171+
return tlv.NewStream(
172+
h.RHash.Record(),
173+
h.RefundTimeout.Record(),
174+
h.OutputIndex.Record(),
175+
h.Incoming.Record(),
176+
h.Amt.Record(),
167177
)
178+
}
168179

169-
return tlv.NewStream(
170-
tlv.MakeDynamicRecord(
171-
rHashType, &h.RHash, h.RHashLen,
172-
RHashEncoder, RHashDecoder,
180+
// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC.
181+
func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry {
182+
return &HTLCEntry{
183+
RHash: tlv.NewRecordT[tlv.TlvType0](
184+
NewSparsePayHash(htlc.RHash),
173185
),
174-
tlv.MakePrimitiveRecord(
175-
refundTimeoutType, &h.RefundTimeout,
186+
RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1](
187+
htlc.RefundTimeout,
176188
),
177-
tlv.MakePrimitiveRecord(
178-
outputIndexType, &h.OutputIndex,
189+
OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2](
190+
uint16(htlc.OutputIndex),
179191
),
180-
tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv),
181-
// We will save 3 bytes if the amount is less or equal to
182-
// 4,294,967,295 msat, or roughly 0.043 bitcoin.
183-
tlv.MakeBigSizeRecord(amtType, &h.amtTlv),
184-
)
192+
Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming),
193+
Amt: tlv.NewRecordT[tlv.TlvType4](
194+
tlv.NewBigSizeT(htlc.Amt.ToSatoshis()),
195+
),
196+
}
185197
}
186198

187199
// RevocationLog stores the info needed to construct a breach retribution. Its
@@ -265,13 +277,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment,
265277
return ErrOutputIndexTooBig
266278
}
267279

268-
entry := &HTLCEntry{
269-
RHash: htlc.RHash,
270-
RefundTimeout: htlc.RefundTimeout,
271-
Incoming: htlc.Incoming,
272-
OutputIndex: uint16(htlc.OutputIndex),
273-
Amt: htlc.Amt.ToSatoshis(),
274-
}
280+
entry := NewHTLCEntryFromHTLC(htlc)
275281
rl.HTLCEntries = append(rl.HTLCEntries, entry)
276282
}
277283

@@ -351,14 +357,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error {
351357
// format.
352358
func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error {
353359
for _, htlc := range htlcs {
354-
// Patch the incomingTlv field.
355-
if htlc.Incoming {
356-
htlc.incomingTlv = 1
357-
}
358-
359-
// Patch the amtTlv field.
360-
htlc.amtTlv = uint64(htlc.Amt)
361-
362360
// Create the tlv stream.
363361
tlvStream, err := htlc.toTlvStream()
364362
if err != nil {
@@ -447,14 +445,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) {
447445
return nil, err
448446
}
449447

450-
// Patch the Incoming field.
451-
if htlc.incomingTlv == 1 {
452-
htlc.Incoming = true
453-
}
454-
455-
// Patch the Amt field.
456-
htlc.Amt = btcutil.Amount(htlc.amtTlv)
457-
458448
// Append the entry.
459449
htlcs = append(htlcs, &htlc)
460450
}
@@ -469,6 +459,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error {
469459
if err := s.Encode(&b); err != nil {
470460
return err
471461
}
462+
472463
// Write the stream's length as a varint.
473464
err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{})
474465
if err != nil {

0 commit comments

Comments
 (0)