Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ const (
CmdSendcmpct = "sendcmpct"
CmdAuthch = "authch"
CmdAuthresp = "authresp"
CmdCreateStream = "createstrm"
CmdStreamAck = "streamack"
)

// MessageEncoding represents the wire message encoding format to be used.
Expand Down Expand Up @@ -216,6 +218,12 @@ func makeEmptyMessage(command string) (Message, error) {
case CmdSendcmpct:
msg = &MsgSendcmpct{}

case CmdCreateStream:
msg = &MsgCreateStream{}

case CmdStreamAck:
msg = &MsgStreamAck{}

default:
return nil, fmt.Errorf("unhandled command [%s]: %#v", command, msg) //nolint:err113 // needs refactoring
}
Expand Down
92 changes: 92 additions & 0 deletions msg_create_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package wire

import (
"fmt"
"io"
)

// MsgCreateStream implements the Message interface and represents a bitcoin
// createstream message. It is sent as the first message on a new TCP connection
// to associate it with an existing peer connection as an additional stream.
type MsgCreateStream struct {
AssociationID []byte
StreamType StreamType
StreamPolicyName string
}

// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver.
// This is part of the Message interface implementation.
func (msg *MsgCreateStream) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error {
var err error

msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID")
if err != nil {
return err
}

if len(msg.AssociationID) == 0 {
return messageError("MsgCreateStream.Bsvdecode", "association ID must not be empty")
}

var streamType uint8
if err = readElement(r, &streamType); err != nil {
return err
}

msg.StreamType = StreamType(streamType)

msg.StreamPolicyName, err = ReadVarString(r, pver)
if err != nil {
return err
}

return nil
}

// BsvEncode encodes the receiver to w using the bitcoin protocol encoding.
// This is part of the Message interface implementation.
func (msg *MsgCreateStream) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
if len(msg.AssociationID) == 0 {
return messageError("MsgCreateStream.BsvEncode", "association ID must not be empty")
}

if len(msg.AssociationID) > MaxAssociationIDLen {
str := fmt.Sprintf("association ID too long [len %v, max %v]",
len(msg.AssociationID), MaxAssociationIDLen)
return messageError("MsgCreateStream.BsvEncode", str)
}

if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil {
return err
}

if err := writeElement(w, uint8(msg.StreamType)); err != nil {
return err
}

if err := WriteVarString(w, pver, msg.StreamPolicyName); err != nil {
return err
}

return nil
}

// Command returns the protocol command string for the message.
func (msg *MsgCreateStream) Command() string {
return CmdCreateStream
}

// MaxPayloadLength returns the maximum length the payload can be for the receiver.
func (msg *MsgCreateStream) MaxPayloadLength(_ uint32) uint64 {
// varint(association_id_len) + association_id + stream_type(1) + varint(policy_len) + policy_string
return MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen
}

// NewMsgCreateStream returns a new createstream message.
func NewMsgCreateStream(associationID []byte, streamType StreamType, policyName string) *MsgCreateStream {
return &MsgCreateStream{
AssociationID: associationID,
StreamType: streamType,
StreamPolicyName: policyName,
}
}
110 changes: 110 additions & 0 deletions msg_create_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package wire

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCreateStream(t *testing.T) {
pver := ProtocolVersion

assocID := []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
}
msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority")

assert.Equal(t, assocID, msg.AssociationID)
assert.Equal(t, StreamTypeData1, msg.StreamType)
assert.Equal(t, "BlockPriority", msg.StreamPolicyName)

assertCommand(t, msg, "createstrm")

wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen)
assertMaxPayload(t, msg, pver, wantPayload)

// Roundtrip
dst := &MsgCreateStream{}
assertWireRoundTrip(t, msg, dst, pver, BaseEncoding)
}

func TestCreateStreamEncodeDecode(t *testing.T) {
pver := ProtocolVersion
enc := BaseEncoding

assocID := []byte{
0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00,
}

msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority")

var buf bytes.Buffer
require.NoError(t, msg.BsvEncode(&buf, pver, enc))

decoded := &MsgCreateStream{}
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))

assert.Equal(t, msg.AssociationID, decoded.AssociationID)
assert.Equal(t, msg.StreamType, decoded.StreamType)
assert.Equal(t, msg.StreamPolicyName, decoded.StreamPolicyName)
}

func TestCreateStreamEmptyAssocID(t *testing.T) {
pver := ProtocolVersion
enc := BaseEncoding

msg := NewMsgCreateStream(nil, StreamTypeData1, "BlockPriority")

var buf bytes.Buffer
err := msg.BsvEncode(&buf, pver, enc)
assert.Error(t, err)
}

func TestCreateStreamLargeAssocID(t *testing.T) {
pver := ProtocolVersion
enc := BaseEncoding

largeID := make([]byte, MaxAssociationIDLen+1)
msg := NewMsgCreateStream(largeID, StreamTypeData1, "BlockPriority")

var buf bytes.Buffer
err := msg.BsvEncode(&buf, pver, enc)
assert.Error(t, err)
}

func TestCreateStreamWireErrors(t *testing.T) {
pver := ProtocolVersion

assocID := []byte{0x01, 0x02, 0x03}
msg := NewMsgCreateStream(assocID, StreamTypeGeneral, "Default")

tests := []struct {
in *MsgCreateStream
buf []byte
pver uint32
enc MessageEncoding
max int
writeErr error
readErr error
}{
// Short write/read at association ID varint.
{msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF},
}

for _, test := range tests {
assertWireError(t, test.in, &MsgCreateStream{}, test.buf, test.pver,
test.enc, test.max, test.writeErr, test.readErr)
}
}

func TestCreateStreamMakeEmptyMessage(t *testing.T) {
msg, err := makeEmptyMessage(CmdCreateStream)
require.NoError(t, err)
_, ok := msg.(*MsgCreateStream)
assert.True(t, ok)
}
66 changes: 66 additions & 0 deletions msg_stream_ack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package wire

import (
"io"
)

// MsgStreamAck implements the Message interface and represents a bitcoin
// streamack message. It is sent in response to a createstream message to
// confirm the new stream has been accepted and associated.
type MsgStreamAck struct {
AssociationID []byte
StreamType StreamType
}

// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver.
// This is part of the Message interface implementation.
func (msg *MsgStreamAck) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error {
var err error

msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID")
if err != nil {
return err
}

var streamType uint8
if err = readElement(r, &streamType); err != nil {
return err
}

msg.StreamType = StreamType(streamType)

return nil
}

// BsvEncode encodes the receiver to w using the bitcoin protocol encoding.
// This is part of the Message interface implementation.
func (msg *MsgStreamAck) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error {
if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil {
return err
}

if err := writeElement(w, uint8(msg.StreamType)); err != nil {
return err
}

return nil
}

// Command returns the protocol command string for the message.
func (msg *MsgStreamAck) Command() string {
return CmdStreamAck
}

// MaxPayloadLength returns the maximum length the payload can be for the receiver.
func (msg *MsgStreamAck) MaxPayloadLength(_ uint32) uint64 {
// varint(association_id_len) + association_id + stream_type(1)
return MaxVarIntPayload + MaxAssociationIDLen + 1
}

// NewMsgStreamAck returns a new streamack message.
func NewMsgStreamAck(associationID []byte, streamType StreamType) *MsgStreamAck {
return &MsgStreamAck{
AssociationID: associationID,
StreamType: streamType,
}
}
102 changes: 102 additions & 0 deletions msg_stream_ack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package wire

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestStreamAck(t *testing.T) {
pver := ProtocolVersion

assocID := []byte{
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11,
}
msg := NewMsgStreamAck(assocID, StreamTypeData1)

assert.Equal(t, assocID, msg.AssociationID)
assert.Equal(t, StreamTypeData1, msg.StreamType)

assertCommand(t, msg, "streamack")

wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1)
assertMaxPayload(t, msg, pver, wantPayload)

// Roundtrip
dst := &MsgStreamAck{}
assertWireRoundTrip(t, msg, dst, pver, BaseEncoding)
}

func TestStreamAckEncodeDecode(t *testing.T) {
pver := ProtocolVersion
enc := BaseEncoding

assocID := []byte{
0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00,
}

msg := NewMsgStreamAck(assocID, StreamTypeData1)

var buf bytes.Buffer
require.NoError(t, msg.BsvEncode(&buf, pver, enc))

decoded := &MsgStreamAck{}
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))

assert.Equal(t, msg.AssociationID, decoded.AssociationID)
assert.Equal(t, msg.StreamType, decoded.StreamType)
}

func TestStreamAckEmptyAssocID(t *testing.T) {
pver := ProtocolVersion
enc := BaseEncoding

// Empty assoc ID is valid for streamack (the field is still encoded as var_bytes)
msg := NewMsgStreamAck(nil, StreamTypeData1)

var buf bytes.Buffer
require.NoError(t, msg.BsvEncode(&buf, pver, enc))

decoded := &MsgStreamAck{}
require.NoError(t, decoded.Bsvdecode(&buf, pver, enc))

assert.Empty(t, decoded.AssociationID)
assert.Equal(t, StreamTypeData1, decoded.StreamType)
}

func TestStreamAckWireErrors(t *testing.T) {
pver := ProtocolVersion

assocID := []byte{0x01, 0x02, 0x03}
msg := NewMsgStreamAck(assocID, StreamTypeGeneral)

tests := []struct {
in *MsgStreamAck
buf []byte
pver uint32
enc MessageEncoding
max int
writeErr error
readErr error
}{
// Short write/read at association ID varint.
{msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF},
}

for _, test := range tests {
assertWireError(t, test.in, &MsgStreamAck{}, test.buf, test.pver,
test.enc, test.max, test.writeErr, test.readErr)
}
}

func TestStreamAckMakeEmptyMessage(t *testing.T) {
msg, err := makeEmptyMessage(CmdStreamAck)
require.NoError(t, err)
_, ok := msg.(*MsgStreamAck)
assert.True(t, ok)
}
Loading