diff --git a/message.go b/message.go index 598aa6d..82667d3 100644 --- a/message.go +++ b/message.go @@ -85,6 +85,8 @@ const ( CmdSendcmpct = "sendcmpct" CmdAuthch = "authch" CmdAuthresp = "authresp" + CmdCreateStream = "createstrm" + CmdStreamAck = "streamack" ) // MessageEncoding represents the wire message encoding format to be used. @@ -216,6 +218,12 @@ func makeEmptyMessage(command string) (Message, error) { case CmdSendcmpct: msg = &MsgSendcmpct{} + case CmdCreateStream: + msg = &MsgCreateStream{} + + case CmdStreamAck: + msg = &MsgStreamAck{} + default: return nil, fmt.Errorf("unhandled command [%s]: %#v", command, msg) //nolint:err113 // needs refactoring } diff --git a/msg_create_stream.go b/msg_create_stream.go new file mode 100644 index 0000000..a03cc8e --- /dev/null +++ b/msg_create_stream.go @@ -0,0 +1,92 @@ +package wire + +import ( + "fmt" + "io" +) + +// MsgCreateStream implements the Message interface and represents a bitcoin +// createstream message. It is sent as the first message on a new TCP connection +// to associate it with an existing peer connection as an additional stream. +type MsgCreateStream struct { + AssociationID []byte + StreamType StreamType + StreamPolicyName string +} + +// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgCreateStream) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error { + var err error + + msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID") + if err != nil { + return err + } + + if len(msg.AssociationID) == 0 { + return messageError("MsgCreateStream.Bsvdecode", "association ID must not be empty") + } + + var streamType uint8 + if err = readElement(r, &streamType); err != nil { + return err + } + + msg.StreamType = StreamType(streamType) + + msg.StreamPolicyName, err = ReadVarString(r, pver) + if err != nil { + return err + } + + return nil +} + +// BsvEncode encodes the receiver to w using the bitcoin protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgCreateStream) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error { + if len(msg.AssociationID) == 0 { + return messageError("MsgCreateStream.BsvEncode", "association ID must not be empty") + } + + if len(msg.AssociationID) > MaxAssociationIDLen { + str := fmt.Sprintf("association ID too long [len %v, max %v]", + len(msg.AssociationID), MaxAssociationIDLen) + return messageError("MsgCreateStream.BsvEncode", str) + } + + if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil { + return err + } + + if err := writeElement(w, uint8(msg.StreamType)); err != nil { + return err + } + + if err := WriteVarString(w, pver, msg.StreamPolicyName); err != nil { + return err + } + + return nil +} + +// Command returns the protocol command string for the message. +func (msg *MsgCreateStream) Command() string { + return CmdCreateStream +} + +// MaxPayloadLength returns the maximum length the payload can be for the receiver. +func (msg *MsgCreateStream) MaxPayloadLength(_ uint32) uint64 { + // varint(association_id_len) + association_id + stream_type(1) + varint(policy_len) + policy_string + return MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen +} + +// NewMsgCreateStream returns a new createstream message. +func NewMsgCreateStream(associationID []byte, streamType StreamType, policyName string) *MsgCreateStream { + return &MsgCreateStream{ + AssociationID: associationID, + StreamType: streamType, + StreamPolicyName: policyName, + } +} diff --git a/msg_create_stream_test.go b/msg_create_stream_test.go new file mode 100644 index 0000000..4fc469f --- /dev/null +++ b/msg_create_stream_test.go @@ -0,0 +1,110 @@ +package wire + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCreateStream(t *testing.T) { + pver := ProtocolVersion + + assocID := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, + } + msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority") + + assert.Equal(t, assocID, msg.AssociationID) + assert.Equal(t, StreamTypeData1, msg.StreamType) + assert.Equal(t, "BlockPriority", msg.StreamPolicyName) + + assertCommand(t, msg, "createstrm") + + wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1 + MaxVarIntPayload + MaxUserAgentLen) + assertMaxPayload(t, msg, pver, wantPayload) + + // Roundtrip + dst := &MsgCreateStream{} + assertWireRoundTrip(t, msg, dst, pver, BaseEncoding) +} + +func TestCreateStreamEncodeDecode(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + assocID := []byte{ + 0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, + } + + msg := NewMsgCreateStream(assocID, StreamTypeData1, "BlockPriority") + + var buf bytes.Buffer + require.NoError(t, msg.BsvEncode(&buf, pver, enc)) + + decoded := &MsgCreateStream{} + require.NoError(t, decoded.Bsvdecode(&buf, pver, enc)) + + assert.Equal(t, msg.AssociationID, decoded.AssociationID) + assert.Equal(t, msg.StreamType, decoded.StreamType) + assert.Equal(t, msg.StreamPolicyName, decoded.StreamPolicyName) +} + +func TestCreateStreamEmptyAssocID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + msg := NewMsgCreateStream(nil, StreamTypeData1, "BlockPriority") + + var buf bytes.Buffer + err := msg.BsvEncode(&buf, pver, enc) + assert.Error(t, err) +} + +func TestCreateStreamLargeAssocID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + largeID := make([]byte, MaxAssociationIDLen+1) + msg := NewMsgCreateStream(largeID, StreamTypeData1, "BlockPriority") + + var buf bytes.Buffer + err := msg.BsvEncode(&buf, pver, enc) + assert.Error(t, err) +} + +func TestCreateStreamWireErrors(t *testing.T) { + pver := ProtocolVersion + + assocID := []byte{0x01, 0x02, 0x03} + msg := NewMsgCreateStream(assocID, StreamTypeGeneral, "Default") + + tests := []struct { + in *MsgCreateStream + buf []byte + pver uint32 + enc MessageEncoding + max int + writeErr error + readErr error + }{ + // Short write/read at association ID varint. + {msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF}, + } + + for _, test := range tests { + assertWireError(t, test.in, &MsgCreateStream{}, test.buf, test.pver, + test.enc, test.max, test.writeErr, test.readErr) + } +} + +func TestCreateStreamMakeEmptyMessage(t *testing.T) { + msg, err := makeEmptyMessage(CmdCreateStream) + require.NoError(t, err) + _, ok := msg.(*MsgCreateStream) + assert.True(t, ok) +} diff --git a/msg_stream_ack.go b/msg_stream_ack.go new file mode 100644 index 0000000..fda457f --- /dev/null +++ b/msg_stream_ack.go @@ -0,0 +1,66 @@ +package wire + +import ( + "io" +) + +// MsgStreamAck implements the Message interface and represents a bitcoin +// streamack message. It is sent in response to a createstream message to +// confirm the new stream has been accepted and associated. +type MsgStreamAck struct { + AssociationID []byte + StreamType StreamType +} + +// Bsvdecode decodes r using the bitcoin protocol encoding into the receiver. +// This is part of the Message interface implementation. +func (msg *MsgStreamAck) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) error { + var err error + + msg.AssociationID, err = ReadVarBytes(r, pver, MaxAssociationIDLen, "AssociationID") + if err != nil { + return err + } + + var streamType uint8 + if err = readElement(r, &streamType); err != nil { + return err + } + + msg.StreamType = StreamType(streamType) + + return nil +} + +// BsvEncode encodes the receiver to w using the bitcoin protocol encoding. +// This is part of the Message interface implementation. +func (msg *MsgStreamAck) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) error { + if err := WriteVarBytes(w, pver, msg.AssociationID); err != nil { + return err + } + + if err := writeElement(w, uint8(msg.StreamType)); err != nil { + return err + } + + return nil +} + +// Command returns the protocol command string for the message. +func (msg *MsgStreamAck) Command() string { + return CmdStreamAck +} + +// MaxPayloadLength returns the maximum length the payload can be for the receiver. +func (msg *MsgStreamAck) MaxPayloadLength(_ uint32) uint64 { + // varint(association_id_len) + association_id + stream_type(1) + return MaxVarIntPayload + MaxAssociationIDLen + 1 +} + +// NewMsgStreamAck returns a new streamack message. +func NewMsgStreamAck(associationID []byte, streamType StreamType) *MsgStreamAck { + return &MsgStreamAck{ + AssociationID: associationID, + StreamType: streamType, + } +} diff --git a/msg_stream_ack_test.go b/msg_stream_ack_test.go new file mode 100644 index 0000000..fb207c5 --- /dev/null +++ b/msg_stream_ack_test.go @@ -0,0 +1,102 @@ +package wire + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamAck(t *testing.T) { + pver := ProtocolVersion + + assocID := []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, + } + msg := NewMsgStreamAck(assocID, StreamTypeData1) + + assert.Equal(t, assocID, msg.AssociationID) + assert.Equal(t, StreamTypeData1, msg.StreamType) + + assertCommand(t, msg, "streamack") + + wantPayload := uint64(MaxVarIntPayload + MaxAssociationIDLen + 1) + assertMaxPayload(t, msg, pver, wantPayload) + + // Roundtrip + dst := &MsgStreamAck{} + assertWireRoundTrip(t, msg, dst, pver, BaseEncoding) +} + +func TestStreamAckEncodeDecode(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + assocID := []byte{ + 0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, + } + + msg := NewMsgStreamAck(assocID, StreamTypeData1) + + var buf bytes.Buffer + require.NoError(t, msg.BsvEncode(&buf, pver, enc)) + + decoded := &MsgStreamAck{} + require.NoError(t, decoded.Bsvdecode(&buf, pver, enc)) + + assert.Equal(t, msg.AssociationID, decoded.AssociationID) + assert.Equal(t, msg.StreamType, decoded.StreamType) +} + +func TestStreamAckEmptyAssocID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + // Empty assoc ID is valid for streamack (the field is still encoded as var_bytes) + msg := NewMsgStreamAck(nil, StreamTypeData1) + + var buf bytes.Buffer + require.NoError(t, msg.BsvEncode(&buf, pver, enc)) + + decoded := &MsgStreamAck{} + require.NoError(t, decoded.Bsvdecode(&buf, pver, enc)) + + assert.Empty(t, decoded.AssociationID) + assert.Equal(t, StreamTypeData1, decoded.StreamType) +} + +func TestStreamAckWireErrors(t *testing.T) { + pver := ProtocolVersion + + assocID := []byte{0x01, 0x02, 0x03} + msg := NewMsgStreamAck(assocID, StreamTypeGeneral) + + tests := []struct { + in *MsgStreamAck + buf []byte + pver uint32 + enc MessageEncoding + max int + writeErr error + readErr error + }{ + // Short write/read at association ID varint. + {msg, []byte{}, pver, BaseEncoding, 0, io.ErrShortWrite, io.EOF}, + } + + for _, test := range tests { + assertWireError(t, test.in, &MsgStreamAck{}, test.buf, test.pver, + test.enc, test.max, test.writeErr, test.readErr) + } +} + +func TestStreamAckMakeEmptyMessage(t *testing.T) { + msg, err := makeEmptyMessage(CmdStreamAck) + require.NoError(t, err) + _, ok := msg.(*MsgStreamAck) + assert.True(t, ok) +} diff --git a/msg_version.go b/msg_version.go index cf32bcc..6f194dc 100644 --- a/msg_version.go +++ b/msg_version.go @@ -54,6 +54,12 @@ type MsgVersion struct { // Don't announce transactions to peer. DisableRelayTx bool + + // AssociationID identifies a multistream association. When present, it + // indicates the peer supports the multistreams protocol. Format is + // [type_byte][uuid_bytes] (typically 17 bytes: 0x01 + 16-byte UUID). + // Empty/nil means no multistream support (legacy single-stream mode). + AssociationID []byte } // HasService returns whether the specified service is supported by the peer @@ -150,6 +156,15 @@ func (msg *MsgVersion) Bsvdecode(r io.Reader, pver uint32, _ MessageEncoding) er msg.DisableRelayTx = !relayTx } + // AssociationID is appended after the relay field by peers that support + // the multistreams protocol. It is optional and backward compatible. + if buf.Len() > 0 { + msg.AssociationID, err = ReadVarBytes(buf, pver, MaxAssociationIDLen, "AssociationID") + if err != nil { + return err + } + } + return nil } @@ -202,6 +217,14 @@ func (msg *MsgVersion) BsvEncode(w io.Writer, pver uint32, _ MessageEncoding) er } } + // Write AssociationID if present (multistreams support). + if len(msg.AssociationID) > 0 { + err = WriteVarBytes(w, pver, msg.AssociationID) + if err != nil { + return err + } + } + return nil } @@ -214,13 +237,12 @@ func (msg *MsgVersion) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgVersion) MaxPayloadLength(pver uint32) uint64 { - // XXX: <= 106 different // Protocol version 4 bytes + services 8 bytes + timestamp 8 bytes + // remote and local net addresses + nonce 8 bytes + length of user // agent (varInt) + max allowed useragent length + last block 4 bytes + - // relay transactions flag 1 byte. + // relay transactions flag 1 byte + varint(assoc_id_len) + max assoc ID. return 33 + (maxNetAddressPayload(pver) * 2) + MaxVarIntPayload + - MaxUserAgentLen + MaxUserAgentLen + MaxVarIntPayload + MaxAssociationIDLen } // NewMsgVersion returns a new bitcoin version message that conforms to the diff --git a/msg_version_test.go b/msg_version_test.go index d76767d..d266496 100644 --- a/msg_version_test.go +++ b/msg_version_test.go @@ -119,8 +119,8 @@ func TestVersion(t *testing.T) { // Protocol version 4 bytes + services 8 bytes + timestamp 8 bytes + // remote and local net addresses + nonce 8 bytes + length of user agent // (varInt) + max allowed user agent length and last block 4 bytes + - // relay transactions flag 1 byte. - wantPayload := uint64(358) + // relay transactions flag 1 byte + varint(assoc_id_len) + max assoc ID. + wantPayload := uint64(358 + MaxVarIntPayload + MaxAssociationIDLen) maxPayload := msg.MaxPayloadLength(pver) if maxPayload != wantPayload { @@ -593,3 +593,73 @@ var baseVersionBIP0037Encoded = []byte{ 0xfa, 0x92, 0x03, 0x00, // Last block 0x01, // Relay tx } + +// TestVersionAssociationID tests that AssociationID is correctly encoded and +// decoded in the version message, and that backward compatibility is maintained +// when AssociationID is not present. +func TestVersionAssociationID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + // Create a version with AssociationID. + tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333} + me := NewNetAddress(tcpAddrMe, SFNodeNetwork) + tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333} + you := NewNetAddress(tcpAddrYou, SFNodeNetwork) + msg := NewMsgVersion(me, you, 123123, 234234) + msg.AssociationID = []byte{ + 0x01, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0x00, + } + + // Encode. + var buf bytes.Buffer + require.NoError(t, msg.BsvEncode(&buf, pver, enc)) + + // Decode. + var decoded MsgVersion + rbuf := bytes.NewBuffer(buf.Bytes()) + require.NoError(t, decoded.Bsvdecode(rbuf, pver, enc)) + + require.Equal(t, msg.AssociationID, decoded.AssociationID) +} + +// TestVersionWithoutAssociationID ensures backward compatibility when +// AssociationID is not present in the version message. +func TestVersionWithoutAssociationID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + tcpAddrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333} + me := NewNetAddress(tcpAddrMe, SFNodeNetwork) + tcpAddrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333} + you := NewNetAddress(tcpAddrYou, SFNodeNetwork) + msg := NewMsgVersion(me, you, 123123, 234234) + // No AssociationID set. + + var buf bytes.Buffer + require.NoError(t, msg.BsvEncode(&buf, pver, enc)) + + var decoded MsgVersion + rbuf := bytes.NewBuffer(buf.Bytes()) + require.NoError(t, decoded.Bsvdecode(rbuf, pver, enc)) + + require.Nil(t, decoded.AssociationID) +} + +// TestVersionDecodeOldFormatWithAssocID ensures that a version message +// encoded in the old format (BIP0037 with relay flag but no AssociationID) +// can still be decoded. +func TestVersionDecodeOldFormatWithAssocID(t *testing.T) { + pver := ProtocolVersion + enc := BaseEncoding + + // Decode a standard BIP0037 encoded version (with relay tx but no assoc ID). + var decoded MsgVersion + rbuf := bytes.NewBuffer(baseVersionBIP0037Encoded) + require.NoError(t, decoded.Bsvdecode(rbuf, pver, enc)) + + require.Nil(t, decoded.AssociationID) + require.Equal(t, baseVersionBIP0037.ProtocolVersion, decoded.ProtocolVersion) + require.Equal(t, baseVersionBIP0037.UserAgent, decoded.UserAgent) +} diff --git a/stream_type.go b/stream_type.go new file mode 100644 index 0000000..4fbb187 --- /dev/null +++ b/stream_type.go @@ -0,0 +1,25 @@ +package wire + +// StreamType represents the type of stream within a multistream association. +type StreamType uint8 + +const ( + // StreamTypeUnknown is an unknown stream type. + StreamTypeUnknown StreamType = 0 + // StreamTypeGeneral is a general stream type. + StreamTypeGeneral StreamType = 1 + // StreamTypeData1 is a data stream type. + StreamTypeData1 StreamType = 2 + // StreamTypeData2 is a data stream type. + StreamTypeData2 StreamType = 3 + // StreamTypeData3 is a data stream type. + StreamTypeData3 StreamType = 4 + // StreamTypeData4 is a data stream type. + StreamTypeData4 StreamType = 5 +) + +// MaxAssociationIDLen is the maximum allowed length for an association ID +// in the version message. Format is [type byte][UUID bytes], with the most +// common format being 1 + 16 = 17 bytes, but we allow up to 129 bytes +// for future extensibility. +const MaxAssociationIDLen = 129