Skip to content

Commit 3e73687

Browse files
authored
Merge pull request #425 from blinklabs-io/feat/generalize-proto-versions
feat: generalize protocol version handling
2 parents 824e32b + ef3d01c commit 3e73687

File tree

13 files changed

+510
-373
lines changed

13 files changed

+510
-373
lines changed

connection.go

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,13 @@ func (c *Connection) shutdown() {
221221
// setupConnection establishes the muxer, configures and starts the handshake process, and initializes
222222
// the appropriate mini-protocols
223223
func (c *Connection) setupConnection() error {
224+
// Check network magic value
225+
if c.networkMagic == 0 {
226+
return fmt.Errorf(
227+
"invalid network magic value provided: %d\n",
228+
c.networkMagic,
229+
)
230+
}
224231
// Start Goroutine to shutdown when doneChan is closed
225232
go func() {
226233
<-c.doneChan
@@ -255,36 +262,41 @@ func (c *Connection) setupConnection() error {
255262
Muxer: c.muxer,
256263
ErrorChan: c.protoErrorChan,
257264
}
258-
var protoVersions []uint16
259265
if c.useNodeToNodeProto {
260-
protoVersions = GetProtocolVersionsNtN()
261266
protoOptions.Mode = protocol.ProtocolModeNodeToNode
262267
} else {
263-
protoVersions = GetProtocolVersionsNtC()
264268
protoOptions.Mode = protocol.ProtocolModeNodeToClient
265269
}
266270
if c.server {
267271
protoOptions.Role = protocol.ProtocolRoleServer
268272
} else {
269273
protoOptions.Role = protocol.ProtocolRoleClient
270274
}
271-
// Check network magic value
272-
if c.networkMagic == 0 {
273-
return fmt.Errorf(
274-
"invalid network magic value provided: %d\n",
275-
c.networkMagic,
276-
)
275+
// Generate protocol version map for handshake
276+
handshakeDiffusionMode := protocol.DiffusionModeInitiatorOnly
277+
if c.fullDuplex {
278+
handshakeDiffusionMode = protocol.DiffusionModeInitiatorAndResponder
277279
}
280+
protoVersions := protocol.GetProtocolVersionMap(
281+
protoOptions.Mode,
282+
c.networkMagic,
283+
handshakeDiffusionMode,
284+
// TODO: make these configurable
285+
protocol.PeerSharingModeNoPeerSharing,
286+
protocol.QueryModeDisabled,
287+
)
278288
// Perform handshake
279289
var handshakeVersion uint16
280290
var handshakeFullDuplex bool
281291
handshakeConfig := handshake.NewConfig(
282-
handshake.WithProtocolVersions(protoVersions),
283-
handshake.WithNetworkMagic(c.networkMagic),
284-
handshake.WithClientFullDuplex(c.fullDuplex),
285-
handshake.WithFinishedFunc(func(version uint16, fullDuplex bool) error {
292+
handshake.WithProtocolVersionMap(protoVersions),
293+
handshake.WithFinishedFunc(func(version uint16, versionData protocol.VersionData) error {
286294
handshakeVersion = version
287-
handshakeFullDuplex = fullDuplex
295+
if c.useNodeToNodeProto {
296+
if versionData.DiffusionMode() == protocol.DiffusionModeInitiatorAndResponder {
297+
handshakeFullDuplex = true
298+
}
299+
}
288300
close(c.handshakeFinishedChan)
289301
return nil
290302
}),
@@ -307,10 +319,6 @@ func (c *Connection) setupConnection() error {
307319
}
308320
// Provide the negotiated protocol version to the various mini-protocols
309321
protoOptions.Version = handshakeVersion
310-
// Drop bit used to signify NtC protocol versions
311-
if protoOptions.Version > protocolVersionNtCFlag {
312-
protoOptions.Version = protoOptions.Version - protocolVersionNtCFlag
313-
}
314322
// Start Goroutine to pass along errors from the mini-protocols
315323
c.waitGroup.Add(1)
316324
go func() {
@@ -331,7 +339,7 @@ func (c *Connection) setupConnection() error {
331339
}()
332340
// Configure the relevant mini-protocols
333341
if c.useNodeToNodeProto {
334-
versionNtN := GetProtocolVersionNtN(handshakeVersion)
342+
versionNtN := protocol.GetProtocolVersion(handshakeVersion)
335343
protoOptions.Mode = protocol.ProtocolModeNodeToNode
336344
c.chainSync = chainsync.New(protoOptions, c.chainSyncConfig)
337345
c.blockFetch = blockfetch.New(protoOptions, c.blockFetchConfig)
@@ -365,7 +373,7 @@ func (c *Connection) setupConnection() error {
365373
}
366374
}
367375
} else {
368-
versionNtC := GetProtocolVersionNtC(handshakeVersion)
376+
versionNtC := protocol.GetProtocolVersion(handshakeVersion)
369377
protoOptions.Mode = protocol.ProtocolModeNodeToClient
370378
c.chainSync = chainsync.New(protoOptions, c.chainSyncConfig)
371379
c.localTxSubmission = localtxsubmission.New(protoOptions, c.localTxSubmissionConfig)

internal/test/ouroboros_mock/entry.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121

2222
const (
2323
MockNetworkMagic uint32 = 999999
24-
MockProtocolVersionNtC uint16 = 14
24+
MockProtocolVersionNtC uint16 = (14 + protocol.ProtocolVersionNtCOffset)
2525
)
2626

2727
type EntryType int
@@ -57,6 +57,9 @@ var ConversationEntryHandshakeResponse = ConversationEntry{
5757
ProtocolId: handshake.ProtocolId,
5858
IsResponse: true,
5959
OutputMessages: []protocol.Message{
60-
handshake.NewMsgAcceptVersion(MockProtocolVersionNtC, MockNetworkMagic),
60+
handshake.NewMsgAcceptVersion(
61+
MockProtocolVersionNtC,
62+
protocol.VersionDataNtC9to14(MockNetworkMagic),
63+
),
6164
},
6265
}

protocol/handshake/client.go

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -65,39 +65,7 @@ func (c *Client) Start() {
6565
c.onceStart.Do(func() {
6666
c.Protocol.Start()
6767
// Send our ProposeVersions message
68-
versionMap := make(map[uint16]interface{})
69-
diffusionMode := DiffusionModeInitiatorOnly
70-
if c.config.ClientFullDuplex {
71-
diffusionMode = DiffusionModeInitiatorAndResponder
72-
}
73-
for _, version := range c.config.ProtocolVersions {
74-
if c.Mode() == protocol.ProtocolModeNodeToNode {
75-
if version >= 11 {
76-
// TODO: make peer sharing mode configurable once it actually works
77-
versionMap[version] = NtNVersionDataPeerSharingQuery{
78-
NetworkMagic: c.config.NetworkMagic,
79-
InitiatorAndResponderDiffusionMode: diffusionMode,
80-
PeerSharing: PeerSharingModeNoPeerSharing,
81-
Query: QueryModeDisabled,
82-
}
83-
} else {
84-
versionMap[version] = NtNVersionDataLegacy{
85-
NetworkMagic: c.config.NetworkMagic,
86-
InitiatorAndResponderDiffusionMode: diffusionMode,
87-
}
88-
}
89-
} else {
90-
if (version - NodeToClientVersionOffset) >= 15 {
91-
versionMap[version] = NtCVersionData{
92-
NetworkMagic: c.config.NetworkMagic,
93-
Query: QueryModeDisabled,
94-
}
95-
} else {
96-
versionMap[version] = c.config.NetworkMagic
97-
}
98-
}
99-
}
100-
msg := NewMsgProposeVersions(versionMap)
68+
msg := NewMsgProposeVersions(c.config.ProtocolVersionMap)
10169
_ = c.SendMessage(msg)
10270
})
10371
}
@@ -119,24 +87,19 @@ func (c *Client) handleMessage(msg protocol.Message, isResponse bool) error {
11987
return err
12088
}
12189

122-
func (c *Client) handleAcceptVersion(msgGeneric protocol.Message) error {
90+
func (c *Client) handleAcceptVersion(msg protocol.Message) error {
12391
if c.config.FinishedFunc == nil {
12492
return fmt.Errorf(
12593
"received handshake AcceptVersion message but no callback function is defined",
12694
)
12795
}
128-
msg := msgGeneric.(*MsgAcceptVersion)
129-
fullDuplex := false
130-
if c.Mode() == protocol.ProtocolModeNodeToNode {
131-
// TODO: switch to using the VersionData types
132-
// this is more annoying than it would seem until we fix some other things
133-
versionData := msg.VersionData.([]interface{})
134-
//nolint:gosimple
135-
if versionData[1].(bool) == DiffusionModeInitiatorAndResponder {
136-
fullDuplex = true
137-
}
96+
msgAcceptVersion := msg.(*MsgAcceptVersion)
97+
protoVersion := protocol.GetProtocolVersion(msgAcceptVersion.Version)
98+
versionData, err := protoVersion.NewVersionDataFromCborFunc(msgAcceptVersion.VersionData)
99+
if err != nil {
100+
return err
138101
}
139-
return c.config.FinishedFunc(msg.Version, fullDuplex)
102+
return c.config.FinishedFunc(msgAcceptVersion.Version, versionData)
140103
}
141104

142105
func (c *Client) handleRefuse(msgGeneric protocol.Message) error {

protocol/handshake/handshake.go

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,6 @@ const (
2727
ProtocolId = 0
2828
)
2929

30-
// Diffusion modes
31-
const (
32-
DiffusionModeInitiatorOnly = true
33-
DiffusionModeInitiatorAndResponder = false
34-
)
35-
36-
// Peer sharing modes
37-
const (
38-
PeerSharingModeNoPeerSharing = 0
39-
PeerSharingModePeerSharingPrivate = 1
40-
PeerSharingModePeerSharingPublic = 2
41-
)
42-
43-
// Query modes
44-
const (
45-
QueryModeDisabled = false
46-
QueryModeEnabled = true
47-
)
48-
49-
// NtC version numbers have the 15th bit set
50-
const NodeToClientVersionOffset = 0x8000
51-
5230
var (
5331
statePropose = protocol.NewState(1, "Propose")
5432
stateConfirm = protocol.NewState(2, "Confirm")
@@ -92,15 +70,13 @@ type Handshake struct {
9270

9371
// Config is used to configure the Handshake protocol instance
9472
type Config struct {
95-
ProtocolVersions []uint16
96-
NetworkMagic uint32
97-
ClientFullDuplex bool
98-
FinishedFunc FinishedFunc
99-
Timeout time.Duration
73+
ProtocolVersionMap protocol.ProtocolVersionMap
74+
FinishedFunc FinishedFunc
75+
Timeout time.Duration
10076
}
10177

10278
// Callback function types
103-
type FinishedFunc func(uint16, bool) error
79+
type FinishedFunc func(uint16, protocol.VersionData) error
10480

10581
// New returns a new Handshake object
10682
func New(protoOptions protocol.ProtocolOptions, cfg *Config) *Handshake {
@@ -126,24 +102,10 @@ func NewConfig(options ...HandshakeOptionFunc) Config {
126102
return c
127103
}
128104

129-
// WithProtocolVersions specifies the supported protocol versions
130-
func WithProtocolVersions(versions []uint16) HandshakeOptionFunc {
131-
return func(c *Config) {
132-
c.ProtocolVersions = versions
133-
}
134-
}
135-
136-
// WithNetworkMagic specifies the network magic value
137-
func WithNetworkMagic(networkMagic uint32) HandshakeOptionFunc {
138-
return func(c *Config) {
139-
c.NetworkMagic = networkMagic
140-
}
141-
}
142-
143-
// WithClientFullDuplex specifies whether to request full duplex mode when acting as a client
144-
func WithClientFullDuplex(fullDuplex bool) HandshakeOptionFunc {
105+
// WithProtocolVersionMap specifies the supported protocol versions
106+
func WithProtocolVersionMap(versionMap protocol.ProtocolVersionMap) HandshakeOptionFunc {
145107
return func(c *Config) {
146-
c.ClientFullDuplex = fullDuplex
108+
c.ProtocolVersionMap = versionMap
147109
}
148110
}
149111

protocol/handshake/messages.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,37 +58,45 @@ func NewMsgFromCbor(msgType uint, data []byte) (protocol.Message, error) {
5858

5959
type MsgProposeVersions struct {
6060
protocol.MessageBase
61-
VersionMap map[uint16]interface{}
61+
VersionMap map[uint16]cbor.RawMessage
6262
}
6363

6464
func NewMsgProposeVersions(
65-
versionMap map[uint16]interface{},
65+
versionMap protocol.ProtocolVersionMap,
6666
) *MsgProposeVersions {
67+
rawVersionMap := map[uint16]cbor.RawMessage{}
68+
for version, versionData := range versionMap {
69+
// This should never fail with our known VersionData types
70+
cborData, _ := cbor.Encode(&versionData)
71+
rawVersionMap[version] = cbor.RawMessage(cborData)
72+
}
6773
m := &MsgProposeVersions{
6874
MessageBase: protocol.MessageBase{
6975
MessageType: MessageTypeProposeVersions,
7076
},
71-
VersionMap: versionMap,
77+
VersionMap: rawVersionMap,
7278
}
7379
return m
7480
}
7581

7682
type MsgAcceptVersion struct {
7783
protocol.MessageBase
7884
Version uint16
79-
VersionData interface{}
85+
VersionData cbor.RawMessage
8086
}
8187

8288
func NewMsgAcceptVersion(
8389
version uint16,
84-
versionData interface{},
90+
versionData protocol.VersionData,
8591
) *MsgAcceptVersion {
92+
// This should never fail with our known VersionData types
93+
cborData, _ := cbor.Encode(&versionData)
8694
m := &MsgAcceptVersion{
8795
MessageBase: protocol.MessageBase{
8896
MessageType: MessageTypeAcceptVersion,
8997
},
9098
Version: version,
91-
VersionData: versionData,
99+
VersionData: cbor.RawMessage(cborData),
92100
}
93101
return m
94102
}

protocol/handshake/messages_test.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,36 @@ var tests = []testDefinition{
3434
CborHex: "8200a4078202f4088202f4098202f40a8202f4",
3535
MessageType: MessageTypeProposeVersions,
3636
Message: NewMsgProposeVersions(
37-
map[uint16]interface{}{
38-
7: []interface{}{uint64(2), false},
39-
8: []interface{}{uint64(2), false},
40-
9: []interface{}{uint64(2), false},
41-
10: []interface{}{uint64(2), false},
37+
map[uint16]protocol.VersionData{
38+
7: protocol.VersionDataNtN7to10{
39+
CborNetworkMagic: 2,
40+
CborInitiatorAndResponderDiffusionMode: false,
41+
},
42+
8: protocol.VersionDataNtN7to10{
43+
CborNetworkMagic: 2,
44+
CborInitiatorAndResponderDiffusionMode: false,
45+
},
46+
9: protocol.VersionDataNtN7to10{
47+
CborNetworkMagic: 2,
48+
CborInitiatorAndResponderDiffusionMode: false,
49+
},
50+
10: protocol.VersionDataNtN7to10{
51+
CborNetworkMagic: 2,
52+
CborInitiatorAndResponderDiffusionMode: false,
53+
},
4254
},
4355
),
4456
},
4557
{
4658
CborHex: "83010a8202f4",
4759
MessageType: MessageTypeAcceptVersion,
48-
Message: NewMsgAcceptVersion(10, []interface{}{uint64(2), false}),
60+
Message: NewMsgAcceptVersion(
61+
10,
62+
protocol.VersionDataNtN7to10{
63+
CborNetworkMagic: 2,
64+
CborInitiatorAndResponderDiffusionMode: false,
65+
},
66+
),
4967
},
5068
{
5169
CborHex: "82028200840708090a",

protocol/handshake/server.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,17 @@ func (s *Server) handleProposeVersions(msgGeneric protocol.Message) error {
7070
}
7171
msg := msgGeneric.(*MsgProposeVersions)
7272
var highestVersion uint16
73-
var fullDuplex bool
74-
var versionData []interface{}
73+
var versionData protocol.VersionData
7574
for proposedVersion := range msg.VersionMap {
7675
if proposedVersion > highestVersion {
77-
for _, allowedVersion := range s.config.ProtocolVersions {
76+
for allowedVersion := range s.config.ProtocolVersionMap {
7877
if allowedVersion == proposedVersion {
7978
highestVersion = proposedVersion
80-
versionData = msg.VersionMap[proposedVersion].([]interface{})
81-
//nolint:gosimple
82-
if versionData[1].(bool) == DiffusionModeInitiatorAndResponder {
83-
fullDuplex = true
84-
} else {
85-
fullDuplex = false
79+
versionConfig := protocol.GetProtocolVersion(proposedVersion)
80+
tmpVersionData, err := versionConfig.NewVersionDataFromCborFunc(msg.VersionMap[proposedVersion])
81+
versionData = tmpVersionData
82+
if err != nil {
83+
return err
8684
}
8785
break
8886
}
@@ -94,7 +92,7 @@ func (s *Server) handleProposeVersions(msgGeneric protocol.Message) error {
9492
if err := s.SendMessage(resp); err != nil {
9593
return err
9694
}
97-
return s.config.FinishedFunc(highestVersion, fullDuplex)
95+
return s.config.FinishedFunc(highestVersion, versionData)
9896
} else {
9997
// TODO: handle failures
10098
// https://github.com/blinklabs-io/gouroboros/issues/32

0 commit comments

Comments
 (0)