diff --git a/abbrevTree.go b/abbrevTree.go new file mode 100644 index 0000000..21b4de9 --- /dev/null +++ b/abbrevTree.go @@ -0,0 +1,115 @@ +package multistream + +import ( + "crypto/sha256" +) + +type nodeProtocol[T StringLike] struct { + protocolID T + tombstoneBit bool +} + +type abbrevTree[T StringLike] struct { + root *abbrevNode[T] +} + +type abbrevNode[T StringLike] struct { + p *nodeProtocol[T] + children [256]*abbrevNode[T] +} + +func (at *abbrevTree[T]) Abbreviate(pid T) []byte { + var result []byte + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return nil + } + + current := at.root + // go furthest in the tree + for _, b := range hash { + if current.children[b] != nil { + result = append(result, b) + current = current.children[b] + } + } + + if current.p != nil && current.p.protocolID == pid && !current.p.tombstoneBit { + return result + } + return nil +} + +func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) { + if at.root == nil { + return "", ErrUnknownPrefix + } + current := at.root + for _, b := range prefix { + if current.children[b] == nil { + return "", ErrUnknownPrefix + } + current = current.children[b] + } + if current.p == nil { + return "", ErrUnknownPrefix + } + return current.p.protocolID, nil +} + +func (at *abbrevTree[T]) AddProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + at.root = &abbrevNode[T]{} + } + + current := at.root + for idx, b := range hash { + if current.children[b] == nil { + current.children[b] = &abbrevNode[T]{ + p: &nodeProtocol[T]{ + protocolID: pid, + tombstoneBit: false, + }, + } + return + } + current = current.children[b] + + if current.p != nil { + if current.p.protocolID == pid { + // Resurrect the protocol ID. + current.p.tombstoneBit = false + } else if !current.p.tombstoneBit { + // There is another protocol in this node, so we need to duplicate it down. + h := sha256.Sum256([]byte(current.p.protocolID)) + + if current.children[h[idx+1]] == nil { + // It should be fine to reference the same nodeProtocol instance. + current.children[h[idx+1]] = &abbrevNode[T]{p: current.p} + } + } + } + } +} + +func (at *abbrevTree[T]) RemoveProtocol(pid T) { + hash := sha256.Sum256([]byte(pid)) + + if at.root == nil { + return + } + current := at.root + for _, b := range hash { + if current.children[b] == nil { + break + } + current = current.children[b] + + if current.p.protocolID == pid { + current.p.tombstoneBit = true + } + } +} diff --git a/abbrevTree_test.go b/abbrevTree_test.go new file mode 100644 index 0000000..cd435e7 --- /dev/null +++ b/abbrevTree_test.go @@ -0,0 +1,155 @@ +package multistream + +import ( + "bytes" + "crypto/sha256" + "testing" +) + +func TestAbbrevTreeAddProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + hash2 := sha256.Sum256([]byte(proto2)) + proto3 := "protocol251" // this one has the same first byte as "protocol1" + hash3 := sha256.Sum256([]byte(proto3)) + + // make sure we don't make mistakes on the hashes + if hash1[0] == hash2[0] { + t.Fatal("the first bytes of hash1 and hash2 should be different") + } + if hash1[0] != hash3[0] { + t.Fatal("the first bytes of hash1 and hash3 should be the same") + } + if hash1[1] == hash3[1] { + t.Fatal("the second bytes of hash1 and hash3 should be different") + } + + // add only proto1 + tree.AddProtocol(proto1) + + if tree.root == nil { + t.Fatal("root should not be nil after adding protocol") + } + if tree.root.children[hash1[0]] == nil || tree.root.children[hash1[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash1[0]].p.protocolID != proto1 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash1[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + + // also add proto2 + tree.AddProtocol(proto2) + + if tree.root.children[hash2[0]] == nil || tree.root.children[hash2[0]].p == nil { + t.Fatal("the protocol was not added") + } + if tree.root.children[hash2[0]].p.protocolID != proto2 { + t.Fatal("the protocol ID was wrong") + } + if tree.root.children[hash2[0]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto2 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto2), []byte{hash2[0]}) { + t.Fatal("abbreviation of proto2 is incorrect") + } + + // add proto3 which has the same first byte of the hash as proto1 + tree.AddProtocol(proto3) + + n1 := tree.root.children[hash1[0]] + // the node at the first level should still be proto1 + if n1.p.protocolID != proto1 { + t.Fatal("the node in the first level should not be modified") + } + // proto1 should be duplicated down + if n1.children[hash1[1]] == nil || n1.children[hash1[1]].p == nil { + t.Fatal("proto1 was not duplicated") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } + // proto3 should be added in the second level + if n1.children[hash3[1]] == nil || n1.children[hash3[1]].p == nil { + t.Fatal("proto3 was not added") + } + if n1.children[hash3[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto3 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto3), []byte{hash3[0], hash3[1]}) { + t.Fatal("abbreviation of proto3 is incorrect") + } +} + +func TestAbbrevTreeRemoveProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + tree.AddProtocol(proto3) + + // remove only proto1 + tree.RemoveProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if !n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if !n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must be set") + } + if tree.Abbreviate(proto1) != nil { + t.Fatal("abbreviation of proto1 should be nil") + } +} + +func TestAbbrevTreeResurrectProtocol(t *testing.T) { + tree := &abbrevTree[string]{} + + proto1 := "protocol1" + hash1 := sha256.Sum256([]byte(proto1)) + proto2 := "protocol2" + proto3 := "protocol251" // this one has the same first byte as "protocol1" + + tree.AddProtocol(proto1) + tree.AddProtocol(proto2) + tree.AddProtocol(proto3) + tree.RemoveProtocol(proto1) + tree.AddProtocol(proto1) + + n1 := tree.root.children[hash1[0]] + if n1.p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if n1.children[hash1[1]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + + // There should be another leaf node added for proto1 + n2 := n1.children[hash1[1]] + if n2.children[hash1[2]] == nil || n2.children[hash1[2]].p == nil { + t.Fatal("proto1 was not added") + } + if n2.children[hash1[2]].p.tombstoneBit { + t.Fatal("tombstoneBit of proto1 must not be set") + } + if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1], hash1[2]}) { + t.Fatal("abbreviation of proto1 is incorrect") + } +} diff --git a/lazyClient.go b/lazyClient.go index 3ff48f9..7690b0b 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -1,6 +1,7 @@ package multistream import ( + "encoding/hex" "fmt" "io" ) @@ -17,6 +18,23 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn { } } +func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn { + t := &abbrevTree[T]{} + for _, p := range peerProtos { + t.AddProtocol(p) + } + + // TODO: use a proper varint instead of a hex string later + abbrv := T(hex.EncodeToString(t.Abbreviate(proto))) + return &lazyClientConn[T]{ + protos: []T{ProtocolID, abbrv}, + con: c, + + rhandshakeOnce: newOnce(), + whandshakeOnce: newOnce(), + } +} + // NewMultistream returns a multistream for the given protocol. This will not // perform any protocol selection. If you are using a MultistreamMuxer, use // NewMSSelect. diff --git a/multistream.go b/multistream.go index 17e1ef7..2e1e604 100644 --- a/multistream.go +++ b/multistream.go @@ -5,6 +5,7 @@ package multistream import ( "bufio" + "encoding/hex" "errors" "fmt" "io" @@ -18,10 +19,16 @@ import ( // ErrTooLarge is an error to signal that an incoming message was too large var ErrTooLarge = errors.New("incoming message was too large") +// ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown +var ErrUnknownPrefix = errors.New("unknown protocol hash prefix") + // ProtocolID identifies the multistream protocol itself and makes sure // the multistream muxers on both sides of a channel can work with each other. const ProtocolID = "/multistream/1.0.0" +// Multistream-select version that protocol abbreviation is supported +const AbbrevSupportedMSSVersion = 2 + var writerPool = sync.Pool{ New: func() interface{} { return bufio.NewWriter(nil) @@ -52,6 +59,7 @@ type Handler[T StringLike] struct { type MultistreamMuxer[T StringLike] struct { handlerlock sync.RWMutex handlers []Handler[T] + abbrevTree abbrevTree[T] } // NewMultistreamMuxer creates a muxer. @@ -134,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.AddProtocol(protocol) msm.removeHandler(protocol) msm.handlers = append(msm.handlers, Handler[T]{ MatchFunc: match, @@ -147,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) { msm.handlerlock.Lock() defer msm.handlerlock.Unlock() + msm.abbrevTree.RemoveProtocol(protocol) msm.removeHandler(protocol) } @@ -176,6 +186,24 @@ func (msm *MultistreamMuxer[T]) Protocols() []T { // fails because of a ProtocolID mismatch. var ErrIncorrectVersion = errors.New("client connected with incorrect version") +func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) { + msm.handlerlock.RLock() + defer msm.handlerlock.RUnlock() + + bytes, err := hex.DecodeString(string(s)) + // TODO: decide whether to compare strings or use abbrevTree by looking at + // multistream version instead. + if err != nil { + return s, nil + } + + proto, err := msm.abbrevTree.GetProtocolID(bytes) + if err != nil { + return "", err + } + return proto, nil +} + func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] { msm.handlerlock.RLock() defer msm.handlerlock.RUnlock() @@ -222,7 +250,12 @@ loop: return "", nil, err } - h := msm.findHandler(tok) + p, err := msm.decodeProtocol(tok) + if err != nil { + return "", nil, err + } + + h := msm.findHandler(p) if h == nil { if err := delimWriteBuffered(rwc, []byte("na")); err != nil { return "", nil, err @@ -236,7 +269,7 @@ loop: _ = delimWriteBuffered(rwc, []byte(tok)) // hand off processing to the sub-protocol handler - return tok, h.Handle, nil + return p, h.Handle, nil } }