Skip to content

Commit 74fb777

Browse files
committed
feat: handshake server-side refusal
Fixes #32 and #39
1 parent 3e73687 commit 74fb777

File tree

4 files changed

+215
-28
lines changed

4 files changed

+215
-28
lines changed

internal/test/ouroboros_mock/connection.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ func (c *Connection) processInputEntry(entry ConversationEntry) error {
185185
if msg == nil {
186186
return fmt.Errorf("received unknown message type: %d", msgType)
187187
}
188+
// Set CBOR for expected message to match received to make comparison easier
189+
entry.InputMessage.SetCbor(msg.Cbor())
190+
// Compare received message to expected message
188191
if !reflect.DeepEqual(msg, entry.InputMessage) {
189192
return fmt.Errorf(
190193
"parsed message does not match expected value: got %#v, expected %#v",

protocol/handshake/client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"github.com/blinklabs-io/gouroboros/internal/test/ouroboros_mock"
2323
)
2424

25-
func TestBasicHandshake(t *testing.T) {
25+
func TestClientBasicHandshake(t *testing.T) {
2626
mockConn := ouroboros_mock.NewConnection(
2727
ouroboros_mock.ProtocolRoleClient,
2828
[]ouroboros_mock.ConversationEntry{
@@ -52,7 +52,7 @@ func TestBasicHandshake(t *testing.T) {
5252
}
5353
}
5454

55-
func TestDoubleStart(t *testing.T) {
55+
func TestClientDoubleStart(t *testing.T) {
5656
mockConn := ouroboros_mock.NewConnection(
5757
ouroboros_mock.ProtocolRoleClient,
5858
[]ouroboros_mock.ConversationEntry{

protocol/handshake/server.go

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,40 +62,89 @@ func (s *Server) handleMessage(msg protocol.Message, isResponse bool) error {
6262
return err
6363
}
6464

65-
func (s *Server) handleProposeVersions(msgGeneric protocol.Message) error {
65+
func (s *Server) handleProposeVersions(msg protocol.Message) error {
6666
if s.config.FinishedFunc == nil {
6767
return fmt.Errorf(
6868
"received handshake ProposeVersions message but no callback function is defined",
6969
)
7070
}
71-
msg := msgGeneric.(*MsgProposeVersions)
72-
var highestVersion uint16
73-
var versionData protocol.VersionData
74-
for proposedVersion := range msg.VersionMap {
75-
if proposedVersion > highestVersion {
76-
for allowedVersion := range s.config.ProtocolVersionMap {
77-
if allowedVersion == proposedVersion {
78-
highestVersion = proposedVersion
79-
versionConfig := protocol.GetProtocolVersion(proposedVersion)
80-
tmpVersionData, err := versionConfig.NewVersionDataFromCborFunc(msg.VersionMap[proposedVersion])
81-
versionData = tmpVersionData
82-
if err != nil {
83-
return err
84-
}
85-
break
86-
}
87-
}
71+
msgProposeVersions := msg.(*MsgProposeVersions)
72+
// Compute intersection of supported and proposed protocol versions
73+
var versionIntersect []uint16
74+
for proposedVersion := range msgProposeVersions.VersionMap {
75+
if _, ok := s.config.ProtocolVersionMap[proposedVersion]; ok {
76+
versionIntersect = append(versionIntersect, proposedVersion)
8877
}
8978
}
90-
if highestVersion > 0 {
91-
resp := NewMsgAcceptVersion(highestVersion, versionData)
92-
if err := s.SendMessage(resp); err != nil {
79+
// Send refusal if there are no matching versions
80+
if len(versionIntersect) == 0 {
81+
var supportedVersions []uint16
82+
for supportedVersion := range s.config.ProtocolVersionMap {
83+
supportedVersions = append(supportedVersions, supportedVersion)
84+
}
85+
msgRefuse := NewMsgRefuse(
86+
[]any{
87+
RefuseReasonVersionMismatch,
88+
supportedVersions,
89+
},
90+
)
91+
if err := s.SendMessage(msgRefuse); err != nil {
9392
return err
9493
}
95-
return s.config.FinishedFunc(highestVersion, versionData)
96-
} else {
97-
// TODO: handle failures
98-
// https://github.com/blinklabs-io/gouroboros/issues/32
99-
return fmt.Errorf("handshake failed, but we don't yet support this")
94+
return fmt.Errorf("handshake failed: refused due to version mismatch")
95+
}
96+
// Compute highest version from intersection
97+
var proposedVersion uint16
98+
for _, version := range versionIntersect {
99+
if version > proposedVersion {
100+
proposedVersion = version
101+
}
102+
}
103+
// Decode protocol parameters for selected version
104+
versionInfo := protocol.GetProtocolVersion(proposedVersion)
105+
versionData := s.config.ProtocolVersionMap[proposedVersion]
106+
proposedVersionData, err := versionInfo.NewVersionDataFromCborFunc(
107+
msgProposeVersions.VersionMap[proposedVersion],
108+
)
109+
if err != nil {
110+
msgRefuse := NewMsgRefuse(
111+
[]any{
112+
RefuseReasonDecodeError,
113+
proposedVersion,
114+
err.Error(),
115+
},
116+
)
117+
if err := s.SendMessage(msgRefuse); err != nil {
118+
return err
119+
}
120+
return fmt.Errorf(
121+
"handshake failed: refused due to protocol parameters decode failure: %s",
122+
err,
123+
)
124+
}
125+
// Check network magic
126+
if proposedVersionData.NetworkMagic() != versionData.NetworkMagic() {
127+
errMsg := fmt.Sprintf("network magic mismatch: %#v /= %#v", versionData, proposedVersionData)
128+
msgRefuse := NewMsgRefuse(
129+
[]any{
130+
RefuseReasonRefused,
131+
proposedVersion,
132+
errMsg,
133+
},
134+
)
135+
if err := s.SendMessage(msgRefuse); err != nil {
136+
return err
137+
}
138+
return fmt.Errorf(
139+
"handshake failed: refused due to protocol parameters mismatch: %s",
140+
errMsg,
141+
)
142+
}
143+
// Accept the proposed version
144+
// We send our version data in the response and the proposed version data in the callback
145+
msgAcceptVersion := NewMsgAcceptVersion(proposedVersion, versionData)
146+
if err := s.SendMessage(msgAcceptVersion); err != nil {
147+
return err
100148
}
149+
return s.config.FinishedFunc(proposedVersion, proposedVersionData)
101150
}

protocol/handshake/server_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Copyright 2023 Blink Labs Software
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package handshake_test
16+
17+
import (
18+
"fmt"
19+
"testing"
20+
21+
ouroboros "github.com/blinklabs-io/gouroboros"
22+
"github.com/blinklabs-io/gouroboros/internal/test/ouroboros_mock"
23+
"github.com/blinklabs-io/gouroboros/protocol"
24+
"github.com/blinklabs-io/gouroboros/protocol/handshake"
25+
)
26+
27+
func TestServerBasicHandshake(t *testing.T) {
28+
mockConn := ouroboros_mock.NewConnection(
29+
ouroboros_mock.ProtocolRoleServer,
30+
[]ouroboros_mock.ConversationEntry{
31+
// MsgProposeVersions from mock client
32+
{
33+
Type: ouroboros_mock.EntryTypeOutput,
34+
ProtocolId: handshake.ProtocolId,
35+
OutputMessages: []protocol.Message{
36+
handshake.NewMsgProposeVersions(
37+
protocol.ProtocolVersionMap{
38+
(10 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
39+
(11 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
40+
(12 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
41+
},
42+
),
43+
},
44+
},
45+
// MsgAcceptVersion from server
46+
{
47+
Type: ouroboros_mock.EntryTypeInput,
48+
IsResponse: true,
49+
ProtocolId: handshake.ProtocolId,
50+
MsgFromCborFunc: handshake.NewMsgFromCbor,
51+
InputMessageType: handshake.MessageTypeAcceptVersion,
52+
InputMessage: handshake.NewMsgAcceptVersion(
53+
(12 + protocol.ProtocolVersionNtCOffset),
54+
protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
55+
),
56+
},
57+
},
58+
)
59+
oConn, err := ouroboros.New(
60+
ouroboros.WithConnection(mockConn),
61+
ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic),
62+
ouroboros.WithServer(true),
63+
)
64+
if err != nil {
65+
t.Fatalf("unexpected error when creating Ouroboros object: %s", err)
66+
}
67+
// Async error handler
68+
go func() {
69+
err, ok := <-oConn.ErrorChan()
70+
if !ok {
71+
return
72+
}
73+
// We can't call t.Fatalf() from a different Goroutine, so we panic instead
74+
panic(fmt.Sprintf("unexpected Ouroboros error: %s", err))
75+
}()
76+
// Close Ouroboros connection
77+
if err := oConn.Close(); err != nil {
78+
t.Fatalf("unexpected error when closing Ouroboros object: %s", err)
79+
}
80+
}
81+
82+
func TestServerHandshakeRefuseVersionMismatch(t *testing.T) {
83+
expectedErr := fmt.Errorf("handshake failed: refused due to version mismatch")
84+
mockConn := ouroboros_mock.NewConnection(
85+
ouroboros_mock.ProtocolRoleServer,
86+
[]ouroboros_mock.ConversationEntry{
87+
// MsgProposeVersions from mock client
88+
{
89+
Type: ouroboros_mock.EntryTypeOutput,
90+
ProtocolId: handshake.ProtocolId,
91+
OutputMessages: []protocol.Message{
92+
handshake.NewMsgProposeVersions(
93+
protocol.ProtocolVersionMap{
94+
(100 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
95+
(101 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
96+
(102 + protocol.ProtocolVersionNtCOffset): protocol.VersionDataNtC9to14(ouroboros_mock.MockNetworkMagic),
97+
},
98+
),
99+
},
100+
},
101+
// MsgRefuse from server
102+
{
103+
Type: ouroboros_mock.EntryTypeInput,
104+
IsResponse: true,
105+
ProtocolId: handshake.ProtocolId,
106+
MsgFromCborFunc: handshake.NewMsgFromCbor,
107+
InputMessageType: handshake.MessageTypeRefuse,
108+
InputMessage: handshake.NewMsgRefuse(
109+
[]any{
110+
handshake.RefuseReasonVersionMismatch,
111+
protocol.GetProtocolVersionsNtC(),
112+
},
113+
),
114+
},
115+
},
116+
)
117+
oConn, err := ouroboros.New(
118+
ouroboros.WithConnection(mockConn),
119+
ouroboros.WithNetworkMagic(ouroboros_mock.MockNetworkMagic),
120+
ouroboros.WithServer(true),
121+
)
122+
if err != nil {
123+
if err.Error() != expectedErr.Error() {
124+
t.Fatalf("unexpected error when creating Ouroboros object: %s", err)
125+
}
126+
}
127+
// Async error handler
128+
go func() {
129+
err, ok := <-oConn.ErrorChan()
130+
if !ok {
131+
return
132+
}
133+
panic(fmt.Sprintf("unexpected Ouroboros error: %s", err))
134+
}()
135+
}

0 commit comments

Comments
 (0)