Skip to content

Commit 09e5b3b

Browse files
authored
Merge pull request #266 from blinklabs-io/test/ouroboros-conversation-mock
test: framework for mocking Ouroboros connections
2 parents 6a8e06d + a5455a6 commit 09e5b3b

File tree

5 files changed

+314
-26
lines changed

5 files changed

+314
-26
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright 2023 Blink Labs, LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ouroboros_mock
16+
17+
import (
18+
"bytes"
19+
"fmt"
20+
"net"
21+
"reflect"
22+
"time"
23+
24+
"github.com/blinklabs-io/gouroboros/cbor"
25+
"github.com/blinklabs-io/gouroboros/muxer"
26+
)
27+
28+
// Connection mocks an Ouroboros connection
29+
type Connection struct {
30+
mockConn net.Conn
31+
conn net.Conn
32+
conversation []ConversationEntry
33+
muxer *muxer.Muxer
34+
muxerRecvChan chan *muxer.Segment
35+
}
36+
37+
// NewConnection returns a new Connection with the provided conversation entries
38+
func NewConnection(conversation []ConversationEntry) net.Conn {
39+
c := &Connection{
40+
conversation: conversation,
41+
}
42+
c.conn, c.mockConn = net.Pipe()
43+
// Start a muxer on the mocked side of the connection
44+
c.muxer = muxer.New(c.mockConn)
45+
// We use ProtocolUnknown to catch all inbound messages when no other protocols are registered
46+
_, c.muxerRecvChan, _ = c.muxer.RegisterProtocol(muxer.ProtocolUnknown)
47+
c.muxer.Start()
48+
// Start async muxer error handler
49+
go func() {
50+
err, ok := <-c.muxer.ErrorChan()
51+
if !ok {
52+
return
53+
}
54+
panic(fmt.Sprintf("muxer error: %s", err))
55+
}()
56+
// Start async conversation handler
57+
go c.asyncLoop()
58+
return c
59+
}
60+
61+
// Read provides a proxy to the client-side connection's Read function. This is needed to satisfy the net.Conn interface
62+
func (c *Connection) Read(b []byte) (n int, err error) {
63+
return c.conn.Read(b)
64+
}
65+
66+
// Write provides a proxy to the client-side connection's Write function. This is needed to satisfy the net.Conn interface
67+
func (c *Connection) Write(b []byte) (n int, err error) {
68+
return c.conn.Write(b)
69+
}
70+
71+
// Close closes both sides of the connection. This is needed to satisfy the net.Conn interface
72+
func (c *Connection) Close() error {
73+
if err := c.conn.Close(); err != nil {
74+
return err
75+
}
76+
if err := c.mockConn.Close(); err != nil {
77+
return err
78+
}
79+
return nil
80+
}
81+
82+
// LocalAddr provides a proxy to the client-side connection's LocalAddr function. This is needed to satisfy the net.Conn interface
83+
func (c *Connection) LocalAddr() net.Addr {
84+
return c.conn.LocalAddr()
85+
}
86+
87+
// RemoteAddr provides a proxy to the client-side connection's RemoteAddr function. This is needed to satisfy the net.Conn interface
88+
func (c *Connection) RemoteAddr() net.Addr {
89+
return c.conn.RemoteAddr()
90+
}
91+
92+
// SetDeadline provides a proxy to the client-side connection's SetDeadline function. This is needed to satisfy the net.Conn interface
93+
func (c *Connection) SetDeadline(t time.Time) error {
94+
return c.conn.SetDeadline(t)
95+
}
96+
97+
// SetReadDeadline provides a proxy to the client-side connection's SetReadDeadline function. This is needed to satisfy the net.Conn interface
98+
func (c *Connection) SetReadDeadline(t time.Time) error {
99+
return c.conn.SetReadDeadline(t)
100+
}
101+
102+
// SetWriteDeadline provides a proxy to the client-side connection's SetWriteDeadline function. This is needed to satisfy the net.Conn interface
103+
func (c *Connection) SetWriteDeadline(t time.Time) error {
104+
return c.conn.SetWriteDeadline(t)
105+
}
106+
107+
func (c *Connection) asyncLoop() {
108+
for _, entry := range c.conversation {
109+
switch entry.Type {
110+
case EntryTypeInput:
111+
if err := c.processInputEntry(entry); err != nil {
112+
panic(err.Error())
113+
}
114+
case EntryTypeOutput:
115+
if err := c.processOutputEntry(entry); err != nil {
116+
panic(fmt.Sprintf("output error: %s", err))
117+
}
118+
case EntryTypeClose:
119+
c.Close()
120+
default:
121+
panic(fmt.Sprintf("unknown conversation entry type: %d: %#v", entry.Type, entry))
122+
}
123+
}
124+
}
125+
126+
func (c *Connection) processInputEntry(entry ConversationEntry) error {
127+
// Wait for segment to be received from muxer
128+
segment, ok := <-c.muxerRecvChan
129+
if !ok {
130+
return nil
131+
}
132+
if segment.GetProtocolId() != entry.ProtocolId {
133+
return fmt.Errorf("input message protocol ID did not match expected value: expected %d, got %d", entry.ProtocolId, segment.GetProtocolId())
134+
}
135+
if segment.IsResponse() != entry.IsResponse {
136+
return fmt.Errorf("input message response flag did not match expected value: expected %v, got %v", entry.IsResponse, segment.IsResponse())
137+
}
138+
// Determine message type
139+
msgType, err := cbor.DecodeIdFromList(segment.Payload)
140+
if err != nil {
141+
return fmt.Errorf("decode error: %s", err)
142+
}
143+
if entry.InputMessage != nil {
144+
// Create Message object from CBOR
145+
msg, err := entry.MsgFromCborFunc(uint(msgType), segment.Payload)
146+
if err != nil {
147+
return fmt.Errorf("message from CBOR error: %s", err)
148+
}
149+
if msg == nil {
150+
return fmt.Errorf("received unknown message type: %d", msgType)
151+
}
152+
if !reflect.DeepEqual(msg, entry.InputMessage) {
153+
return fmt.Errorf("parsed message does not match expected value: got %#v, expected %#v", msg, entry.InputMessage)
154+
}
155+
} else {
156+
if entry.InputMessageType == uint(msgType) {
157+
return nil
158+
}
159+
return fmt.Errorf("input message is not of expected type: expected %d, got %d", entry.InputMessageType, msgType)
160+
}
161+
return nil
162+
}
163+
164+
func (c *Connection) processOutputEntry(entry ConversationEntry) error {
165+
payloadBuf := bytes.NewBuffer(nil)
166+
for _, msg := range entry.OutputMessages {
167+
// Get raw CBOR from message
168+
data := msg.Cbor()
169+
// If message has no raw CBOR, encode the message
170+
if data == nil {
171+
var err error
172+
data, err = cbor.Encode(msg)
173+
if err != nil {
174+
return err
175+
}
176+
}
177+
payloadBuf.Write(data)
178+
}
179+
segment := muxer.NewSegment(entry.ProtocolId, payloadBuf.Bytes(), entry.IsResponse)
180+
if err := c.muxer.Send(segment); err != nil {
181+
return err
182+
}
183+
return nil
184+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright 2023 Blink Labs, LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ouroboros_mock
16+
17+
import (
18+
"github.com/blinklabs-io/gouroboros/protocol"
19+
"github.com/blinklabs-io/gouroboros/protocol/handshake"
20+
)
21+
22+
const (
23+
MockNetworkMagic uint32 = 999999
24+
MockProtocolVersionNtC uint16 = 14
25+
)
26+
27+
type EntryType int
28+
29+
const (
30+
EntryTypeNone EntryType = 0
31+
EntryTypeInput EntryType = 1
32+
EntryTypeOutput EntryType = 2
33+
EntryTypeClose EntryType = 3
34+
)
35+
36+
type ConversationEntry struct {
37+
Type EntryType
38+
ProtocolId uint16
39+
IsResponse bool
40+
OutputMessages []protocol.Message
41+
InputMessage protocol.Message
42+
InputMessageType uint
43+
MsgFromCborFunc protocol.MessageFromCborFunc
44+
}
45+
46+
// ConversationEntryHandshakeRequestGeneric is a pre-defined conversation event that matches a generic
47+
// handshake request from a client
48+
var ConversationEntryHandshakeRequestGeneric = ConversationEntry{
49+
Type: EntryTypeInput,
50+
ProtocolId: handshake.ProtocolId,
51+
InputMessageType: handshake.MessageTypeProposeVersions,
52+
}
53+
54+
// ConversationEntryHandshakeResponse is a pre-defined conversation entry for a server NtC handshake response
55+
var ConversationEntryHandshakeResponse = ConversationEntry{
56+
Type: EntryTypeOutput,
57+
ProtocolId: handshake.ProtocolId,
58+
IsResponse: true,
59+
OutputMessages: []protocol.Message{
60+
handshake.NewMsgAcceptVersion(MockProtocolVersionNtC, MockNetworkMagic),
61+
},
62+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright 2023 Blink Labs, LLC.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ouroboros_mock
16+
17+
import (
18+
"testing"
19+
20+
ouroboros "github.com/blinklabs-io/gouroboros"
21+
)
22+
23+
// Basic test of conversation mock functionality
24+
func TestBasic(t *testing.T) {
25+
mockConn := NewConnection(
26+
[]ConversationEntry{
27+
ConversationEntryHandshakeRequestGeneric,
28+
ConversationEntryHandshakeResponse,
29+
},
30+
)
31+
oConn, err := ouroboros.New(
32+
ouroboros.WithConnection(mockConn),
33+
ouroboros.WithNetworkMagic(MockNetworkMagic),
34+
)
35+
if err != nil {
36+
t.Fatalf("unexpected error when creating Ouroboros object: %s", err)
37+
}
38+
// Close Ouroboros connection
39+
if err := oConn.Close(); err != nil {
40+
t.Fatalf("unexpected error when closing Ouroboros object: %s", err)
41+
}
42+
}

ouroboros.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,10 +241,10 @@ func (o *Ouroboros) setupConnection() error {
241241
}
242242
var protoVersions []uint16
243243
if o.useNodeToNodeProto {
244-
protoVersions = getProtocolVersionsNtN()
244+
protoVersions = GetProtocolVersionsNtN()
245245
protoOptions.Mode = protocol.ProtocolModeNodeToNode
246246
} else {
247-
protoVersions = getProtocolVersionsNtC()
247+
protoVersions = GetProtocolVersionsNtC()
248248
protoOptions.Mode = protocol.ProtocolModeNodeToClient
249249
}
250250
if o.server {
@@ -312,7 +312,7 @@ func (o *Ouroboros) setupConnection() error {
312312
}()
313313
// Configure the relevant mini-protocols
314314
if o.useNodeToNodeProto {
315-
versionNtN := getProtocolVersionNtN(handshakeVersion)
315+
versionNtN := GetProtocolVersionNtN(handshakeVersion)
316316
protoOptions.Mode = protocol.ProtocolModeNodeToNode
317317
o.chainSync = chainsync.New(protoOptions, o.chainSyncConfig)
318318
o.blockFetch = blockfetch.New(protoOptions, o.blockFetchConfig)
@@ -324,7 +324,7 @@ func (o *Ouroboros) setupConnection() error {
324324
}
325325
}
326326
} else {
327-
versionNtC := getProtocolVersionNtC(handshakeVersion)
327+
versionNtC := GetProtocolVersionNtC(handshakeVersion)
328328
protoOptions.Mode = protocol.ProtocolModeNodeToClient
329329
o.chainSync = chainsync.New(protoOptions, o.chainSyncConfig)
330330
o.localTxSubmission = localtxsubmission.New(protoOptions, o.localTxSubmissionConfig)

0 commit comments

Comments
 (0)