Skip to content

Commit 31a30d2

Browse files
committed
Add TLV stream to Init message
1 parent 8ae6fd2 commit 31a30d2

File tree

5 files changed

+54
-46
lines changed

5 files changed

+54
-46
lines changed

eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,14 @@ object InitTlvCodecs {
4040

4141
import InitTlv._
4242

43-
// TODO: wire test-cases:
44-
// * Init not containing any tlv stream
45-
// * Init containing tlv stream without networks (other odd records)
46-
// * Init containing tlv stream without networks (other even records)
47-
// * Init containing tlv stream with networks only
48-
// * Init containing tlv stream with networks and other odd records)
49-
// * Init containing tlv stream with networks and other even records)
50-
5143
// TODO:
52-
// * Add to the Init message after flat features merged
5344
// * Send the chainHash from nodeParams when creating Init
5445
// * Add logic to Peer.scala to fail connections to others that don't offer my chainHash
5546

5647
private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks]
5748

58-
val initTlvCodec = discriminated[InitTlv].by(varint)
49+
val initTlvCodec = TlvCodecs.tlvStream(discriminated[InitTlv].by(varint)
5950
.typecase(UInt64(1), networks)
51+
)
6052

6153
}

eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ object LightningMessageCodecs {
3939
},
4040
{ features => (ByteVector.empty, features) })
4141

42-
val initCodec: Codec[Init] = combinedFeaturesCodec.as[Init]
42+
val initCodec: Codec[Init] = (("features" | combinedFeaturesCodec) :: ("tlvStream" | InitTlvCodecs.initTlvCodec)).as[Init]
4343

4444
val errorCodec: Codec[Error] = (
4545
("channelId" | bytes32) ::

eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ sealed trait HasChainHash extends LightningMessage { def chainHash: ByteVector32
4545
sealed trait UpdateMessage extends HtlcMessage // <- not in the spec
4646
// @formatter:on
4747

48-
case class Init(features: ByteVector) extends SetupMessage
48+
case class Init(features: ByteVector, tlvs: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage {
49+
val networks = tlvs.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil)
50+
}
4951

5052
case class Error(channelId: ByteVector32, data: ByteVector) extends SetupMessage with HasChannelId {
5153
def toAscii: String = if (fr.acinq.eclair.isAsciiPrintable(data)) new String(data.toArray, StandardCharsets.US_ASCII) else "n/a"

eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,41 +22,40 @@ import scodec.bits.ByteVector
2222
import scala.reflect.ClassTag
2323

2424
/**
25-
* Created by t-bast on 20/06/2019.
26-
*/
25+
* Created by t-bast on 20/06/2019.
26+
*/
2727

2828
trait Tlv
2929

3030
/**
31-
* Generic tlv type we fallback to if we don't understand the incoming tlv.
32-
*
33-
* @param tag tlv tag.
34-
* @param value tlv value (length is implicit, and encoded as a varint).
35-
*/
31+
* Generic tlv type we fallback to if we don't understand the incoming tlv.
32+
*
33+
* @param tag tlv tag.
34+
* @param value tlv value (length is implicit, and encoded as a varint).
35+
*/
3636
case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv
3737

3838
/**
39-
* A tlv stream is a collection of tlv records.
40-
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
41-
* That namespace is provided by a trait extending the top-level tlv trait.
42-
*
43-
* @param records known tlv records.
44-
* @param unknown unknown tlv records.
45-
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
46-
*/
39+
* A tlv stream is a collection of tlv records.
40+
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
41+
* That namespace is provided by a trait extending the top-level tlv trait.
42+
*
43+
* @param records known tlv records.
44+
* @param unknown unknown tlv records.
45+
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
46+
*/
4747
case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) {
4848
/**
49-
*
50-
* @tparam R input type parameter, must be a subtype of the main TLV type
51-
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
52-
* that TLV records are supposed to be unique)
53-
*/
49+
*
50+
* @tparam R input type parameter, must be a subtype of the main TLV type
51+
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
52+
* that TLV records are supposed to be unique)
53+
*/
5454
def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r }
5555
}
5656

5757
object TlvStream {
58-
def empty[T <: Tlv] = TlvStream[T](Nil, Nil)
58+
def empty[T <: Tlv]: TlvStream[T] = TlvStream[T](Nil, Nil)
5959

6060
def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil)
61-
6261
}

eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,21 +44,36 @@ class LightningMessageCodecsSpec extends FunSuite {
4444
def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey
4545

4646
test("encode/decode init message") {
47+
val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101")
48+
val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202")
4749
val testCases = Seq(
48-
(hex"0000 0000", hex"", hex"0000 0000"), // no features
49-
(hex"0000 0002088a", hex"088a", hex"0000 0002088a"), // no global features
50-
(hex"00020200 0000", hex"0200", hex"0000 00020200"), // no local features
51-
(hex"00020200 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - no conflict - same size
52-
(hex"00020200 0003020002", hex"020202", hex"0000 0003020202"), // local and global - no conflict - different sizes
53-
(hex"00020a02 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - conflict - same size
54-
(hex"00022200 000302aaa2", hex"02aaa2", hex"0000 000302aaa2") // local and global - conflict - different sizes
50+
(hex"0000 0000", hex"", Nil, true, None), // no features
51+
(hex"0000 0002088a", hex"088a", Nil, true, None), // no global features
52+
(hex"00020200 0000", hex"0200", Nil, true, Some(hex"0000 00020200")), // no local features
53+
(hex"00020200 0002088a", hex"0a8a", Nil, true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size
54+
(hex"00020200 0003020002", hex"020202", Nil, true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes
55+
(hex"00020a02 0002088a", hex"0a8a", Nil, true, Some(hex"0000 00020a8a")), // local and global - conflict - same size
56+
(hex"00022200 000302aaa2", hex"02aaa2", Nil, true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes
57+
(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, true, None), // unknown odd records
58+
(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, false, None), // unknown even records
59+
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, false, None), // invalid tlv stream
60+
(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), true, None), // single network
61+
(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), true, None), // multiple networks
62+
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), true, None), // network and unknown odd records
63+
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, false, None) // network and unknown even records
5564
)
5665

57-
for ((bin, features, encoded) <- testCases) {
58-
val init = initCodec.decode(bin.bits).require.value
59-
assert(init.features === features)
60-
assert(initCodec.encode(init).require.bytes === encoded)
61-
assert(initCodec.decode(encoded.bits).require.value === init)
66+
for ((bin, features, networks, valid, encodedOverride) <- testCases) {
67+
if (valid) {
68+
val init = initCodec.decode(bin.bits).require.value
69+
assert(init.features === features)
70+
assert(init.networks === networks)
71+
val encoded = initCodec.encode(init).require
72+
assert(encoded.bytes === encodedOverride.getOrElse(bin))
73+
assert(initCodec.decode(encoded).require.value === init)
74+
} else {
75+
assert(initCodec.decode(bin.bits).isFailure)
76+
}
6277
}
6378
}
6479

0 commit comments

Comments
 (0)