Skip to content

Commit a771847

Browse files
authored
1788: add multistream policies (#102)
* 1788: add multistream policies * fix lint errors
1 parent 4256280 commit a771847

File tree

8 files changed

+500
-5
lines changed

8 files changed

+500
-5
lines changed

message.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ const (
8585
CmdSendcmpct = "sendcmpct"
8686
CmdAuthch = "authch"
8787
CmdAuthresp = "authresp"
88+
CmdCreateStream = "createstrm"
89+
CmdStreamAck = "streamack"
8890
)
8991

9092
// MessageEncoding represents the wire message encoding format to be used.
@@ -216,6 +218,12 @@ func makeEmptyMessage(command string) (Message, error) {
216218
case CmdSendcmpct:
217219
msg = &MsgSendcmpct{}
218220

221+
case CmdCreateStream:
222+
msg = &MsgCreateStream{}
223+
224+
case CmdStreamAck:
225+
msg = &MsgStreamAck{}
226+
219227
default:
220228
return nil, fmt.Errorf("unhandled command [%s]: %#v", command, msg) //nolint:err113 // needs refactoring
221229
}

msg_create_stream.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package wire
2+
3+
import (
4+
"fmt"
5+
"io"
6+
)
7+
8+
// MsgCreateStream implements the Message interface and represents a bitcoin
9+
// createstream message. It is sent as the first message on a new TCP connection
10+
// to associate it with an existing peer connection as an additional stream.
11+
type MsgCreateStream struct {
12+
AssociationID []byte
13+
StreamType StreamType
14+
StreamPolicyName string
15+
}
16+
17+
// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver.
18+
// This is part of the Message interface implementation.
19+
func (msg *MsgCreateStream) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error {
20+
var err error
21+
22+
msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID")
23+
if err != nil {
24+
return err
25+
}
26+
27+
if len(msg.AssociationID) == 0 {
28+
return messageError("MsgCreateStream.Bsvdecode", "association ID must not be empty")
29+
}
30+
31+
var streamType uint8
32+
if err = readElement(r, &streamType); err != nil {
33+
return err
34+
}
35+
36+
msg.StreamType = StreamType(streamType)
37+
38+
msg.StreamPolicyName, err = ReadVarString(r, pver)
39+
if err != nil {
40+
return err
41+
}
42+
43+
return nil
44+
}
45+
46+
// BsvEncode encodes the receiver to w using the bitcoin protocol encoding.
47+
// This is part of the Message interface implementation.
48+
func (msg *MsgCreateStream) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
49+
if len(msg.AssociationID) == 0 {
50+
return messageError("MsgCreateStream.BsvEncode", "association ID must not be empty")
51+
}
52+
53+
if len(msg.AssociationID) > MaxAssociationIDLen {
54+
str := fmt.Sprintf("association ID too long [len %v, max %v]",
55+
len(msg.AssociationID), MaxAssociationIDLen)
56+
return messageError("MsgCreateStream.BsvEncode", str)
57+
}
58+
59+
if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil {
60+
return err
61+
}
62+
63+
if err := writeElement(w, uint8(msg.StreamType)); err != nil {
64+
return err
65+
}
66+
67+
if err := WriteVarString(w, pver, msg.StreamPolicyName); err != nil {
68+
return err
69+
}
70+
71+
return nil
72+
}
73+
74+
// Command returns the protocol command string for the message.
75+
func (msg *MsgCreateStream) Command() string {
76+
return CmdCreateStream
77+
}
78+
79+
// MaxPayloadLength returns the maximum length the payload can be for the receiver.
80+
func (msg *MsgCreateStream) MaxPayloadLength(_ uint32) uint64 {
81+
// varint(association_id_len) + association_id + stream_type(1) + varint(policy_len) + policy_string
82+
return MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen
83+
}
84+
85+
// NewMsgCreateStream returns a new createstream message.
86+
func NewMsgCreateStream(associationID []byte, streamType StreamType, policyName string) *MsgCreateStream {
87+
return &MsgCreateStream{
88+
AssociationID: associationID,
89+
StreamType: streamType,
90+
StreamPolicyName: policyName,
91+
}
92+
}

msg_create_stream_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package wire
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestCreateStream(t *testing.T) {
13+
pver := ProtocolVersion
14+
15+
assocID := []byte{
16+
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
17+
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
18+
}
19+
msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority")
20+
21+
assert.Equal(t, assocID, msg.AssociationID)
22+
assert.Equal(t, StreamTypeData1, msg.StreamType)
23+
assert.Equal(t, "BlockPriority", msg.StreamPolicyName)
24+
25+
assertCommand(t, msg, "createstrm")
26+
27+
wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen)
28+
assertMaxPayload(t, msg, pver, wantPayload)
29+
30+
// Roundtrip
31+
dst := &MsgCreateStream{}
32+
assertWireRoundTrip(t, msg, dst, pver, BaseEncoding)
33+
}
34+
35+
func TestCreateStreamEncodeDecode(t *testing.T) {
36+
pver := ProtocolVersion
37+
enc := BaseEncoding
38+
39+
assocID := []byte{
40+
0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
41+
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00,
42+
}
43+
44+
msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority")
45+
46+
var buf bytes.Buffer
47+
require.NoError(t, msg.BsvEncode(&buf, pver, enc))
48+
49+
decoded := &MsgCreateStream{}
50+
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))
51+
52+
assert.Equal(t, msg.AssociationID, decoded.AssociationID)
53+
assert.Equal(t, msg.StreamType, decoded.StreamType)
54+
assert.Equal(t, msg.StreamPolicyName, decoded.StreamPolicyName)
55+
}
56+
57+
func TestCreateStreamEmptyAssocID(t *testing.T) {
58+
pver := ProtocolVersion
59+
enc := BaseEncoding
60+
61+
msg := NewMsgCreateStream(nil, StreamTypeData1, "BlockPriority")
62+
63+
var buf bytes.Buffer
64+
err := msg.BsvEncode(&buf, pver, enc)
65+
assert.Error(t, err)
66+
}
67+
68+
func TestCreateStreamLargeAssocID(t *testing.T) {
69+
pver := ProtocolVersion
70+
enc := BaseEncoding
71+
72+
largeID := make([]byte, MaxAssociationIDLen+1)
73+
msg := NewMsgCreateStream(largeID, StreamTypeData1, "BlockPriority")
74+
75+
var buf bytes.Buffer
76+
err := msg.BsvEncode(&buf, pver, enc)
77+
assert.Error(t, err)
78+
}
79+
80+
func TestCreateStreamWireErrors(t *testing.T) {
81+
pver := ProtocolVersion
82+
83+
assocID := []byte{0x01, 0x02, 0x03}
84+
msg := NewMsgCreateStream(assocID, StreamTypeGeneral, "Default")
85+
86+
tests := []struct {
87+
in *MsgCreateStream
88+
buf []byte
89+
pver uint32
90+
enc MessageEncoding
91+
max int
92+
writeErr error
93+
readErr error
94+
}{
95+
// Short write/read at association ID varint.
96+
{msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF},
97+
}
98+
99+
for _, test := range tests {
100+
assertWireError(t, test.in, &MsgCreateStream{}, test.buf, test.pver,
101+
test.enc, test.max, test.writeErr, test.readErr)
102+
}
103+
}
104+
105+
func TestCreateStreamMakeEmptyMessage(t *testing.T) {
106+
msg, err := makeEmptyMessage(CmdCreateStream)
107+
require.NoError(t, err)
108+
_, ok := msg.(*MsgCreateStream)
109+
assert.True(t, ok)
110+
}

msg_stream_ack.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package wire
2+
3+
import (
4+
"io"
5+
)
6+
7+
// MsgStreamAck implements the Message interface and represents a bitcoin
8+
// streamack message. It is sent in response to a createstream message to
9+
// confirm the new stream has been accepted and associated.
10+
type MsgStreamAck struct {
11+
AssociationID []byte
12+
StreamType StreamType
13+
}
14+
15+
// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver.
16+
// This is part of the Message interface implementation.
17+
func (msg *MsgStreamAck) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error {
18+
var err error
19+
20+
msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID")
21+
if err != nil {
22+
return err
23+
}
24+
25+
var streamType uint8
26+
if err = readElement(r, &streamType); err != nil {
27+
return err
28+
}
29+
30+
msg.StreamType = StreamType(streamType)
31+
32+
return nil
33+
}
34+
35+
// BsvEncode encodes the receiver to w using the bitcoin protocol encoding.
36+
// This is part of the Message interface implementation.
37+
func (msg *MsgStreamAck) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
38+
if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil {
39+
return err
40+
}
41+
42+
if err := writeElement(w, uint8(msg.StreamType)); err != nil {
43+
return err
44+
}
45+
46+
return nil
47+
}
48+
49+
// Command returns the protocol command string for the message.
50+
func (msg *MsgStreamAck) Command() string {
51+
return CmdStreamAck
52+
}
53+
54+
// MaxPayloadLength returns the maximum length the payload can be for the receiver.
55+
func (msg *MsgStreamAck) MaxPayloadLength(_ uint32) uint64 {
56+
// varint(association_id_len) + association_id + stream_type(1)
57+
return MaxVarIntPayload + MaxAssociationIDLen + 1
58+
}
59+
60+
// NewMsgStreamAck returns a new streamack message.
61+
func NewMsgStreamAck(associationID []byte, streamType StreamType) *MsgStreamAck {
62+
return &MsgStreamAck{
63+
AssociationID: associationID,
64+
StreamType: streamType,
65+
}
66+
}

msg_stream_ack_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package wire
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestStreamAck(t *testing.T) {
13+
pver := ProtocolVersion
14+
15+
assocID := []byte{
16+
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
17+
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
18+
}
19+
msg := NewMsgStreamAck(assocID, StreamTypeData1)
20+
21+
assert.Equal(t, assocID, msg.AssociationID)
22+
assert.Equal(t, StreamTypeData1, msg.StreamType)
23+
24+
assertCommand(t, msg, "streamack")
25+
26+
wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1)
27+
assertMaxPayload(t, msg, pver, wantPayload)
28+
29+
// Roundtrip
30+
dst := &MsgStreamAck{}
31+
assertWireRoundTrip(t, msg, dst, pver, BaseEncoding)
32+
}
33+
34+
func TestStreamAckEncodeDecode(t *testing.T) {
35+
pver := ProtocolVersion
36+
enc := BaseEncoding
37+
38+
assocID := []byte{
39+
0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
40+
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00,
41+
}
42+
43+
msg := NewMsgStreamAck(assocID, StreamTypeData1)
44+
45+
var buf bytes.Buffer
46+
require.NoError(t, msg.BsvEncode(&buf, pver, enc))
47+
48+
decoded := &MsgStreamAck{}
49+
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))
50+
51+
assert.Equal(t, msg.AssociationID, decoded.AssociationID)
52+
assert.Equal(t, msg.StreamType, decoded.StreamType)
53+
}
54+
55+
func TestStreamAckEmptyAssocID(t *testing.T) {
56+
pver := ProtocolVersion
57+
enc := BaseEncoding
58+
59+
// Empty assoc ID is valid for streamack (the field is still encoded as var_bytes)
60+
msg := NewMsgStreamAck(nil, StreamTypeData1)
61+
62+
var buf bytes.Buffer
63+
require.NoError(t, msg.BsvEncode(&buf, pver, enc))
64+
65+
decoded := &MsgStreamAck{}
66+
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))
67+
68+
assert.Empty(t, decoded.AssociationID)
69+
assert.Equal(t, StreamTypeData1, decoded.StreamType)
70+
}
71+
72+
func TestStreamAckWireErrors(t *testing.T) {
73+
pver := ProtocolVersion
74+
75+
assocID := []byte{0x01, 0x02, 0x03}
76+
msg := NewMsgStreamAck(assocID, StreamTypeGeneral)
77+
78+
tests := []struct {
79+
in *MsgStreamAck
80+
buf []byte
81+
pver uint32
82+
enc MessageEncoding
83+
max int
84+
writeErr error
85+
readErr error
86+
}{
87+
// Short write/read at association ID varint.
88+
{msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF},
89+
}
90+
91+
for _, test := range tests {
92+
assertWireError(t, test.in, &MsgStreamAck{}, test.buf, test.pver,
93+
test.enc, test.max, test.writeErr, test.readErr)
94+
}
95+
}
96+
97+
func TestStreamAckMakeEmptyMessage(t *testing.T) {
98+
msg, err := makeEmptyMessage(CmdStreamAck)
99+
require.NoError(t, err)
100+
_, ok := msg.(*MsgStreamAck)
101+
assert.True(t, ok)
102+
}

0 commit comments

Comments
 (0)