Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions abbrevTree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package multistream

import (
"crypto/sha256"
)

type nodeProtocol[T StringLike] struct {
protocolID T
tombstoneBit bool
}

type abbrevTree[T StringLike] struct {
root *abbrevNode[T]
}

type abbrevNode[T StringLike] struct {
p *nodeProtocol[T]
children [256]*abbrevNode[T]
}

func (at *abbrevTree[T]) Abbreviate(pid T) []byte {
var result []byte
hash := sha256.Sum256([]byte(pid))

if at.root == nil {
return nil
}

current := at.root
// go furthest in the tree
for _, b := range hash {
if current.children[b] != nil {
result = append(result, b)
current = current.children[b]
}
}

if current.p != nil && current.p.protocolID == pid && !current.p.tombstoneBit {
return result
}
return nil
}

func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) {
if at.root == nil {
return "", ErrUnknownPrefix
}
current := at.root
for _, b := range prefix {
if current.children[b] == nil {
return "", ErrUnknownPrefix
}
current = current.children[b]
}
if current.p == nil {
return "", ErrUnknownPrefix
}
return current.p.protocolID, nil
}

func (at *abbrevTree[T]) AddProtocol(pid T) {
hash := sha256.Sum256([]byte(pid))

if at.root == nil {
at.root = &abbrevNode[T]{}
}

current := at.root
for idx, b := range hash {
if current.children[b] == nil {
current.children[b] = &abbrevNode[T]{
p: &nodeProtocol[T]{
protocolID: pid,
tombstoneBit: false,
},
}
return
}
current = current.children[b]

if current.p != nil {
if current.p.protocolID == pid {
// Resurrect the protocol ID.
current.p.tombstoneBit = false
} else if !current.p.tombstoneBit {
// There is another protocol in this node, so we need to duplicate it down.
h := sha256.Sum256([]byte(current.p.protocolID))

if current.children[h[idx+1]] == nil {
// It should be fine to reference the same nodeProtocol instance.
current.children[h[idx+1]] = &abbrevNode[T]{p: current.p}
}
}
}
}
}

func (at *abbrevTree[T]) RemoveProtocol(pid T) {
hash := sha256.Sum256([]byte(pid))

if at.root == nil {
return
}
current := at.root
for _, b := range hash {
if current.children[b] == nil {
break
}
current = current.children[b]

if current.p.protocolID == pid {
current.p.tombstoneBit = true
}
}
}
155 changes: 155 additions & 0 deletions abbrevTree_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package multistream

import (
"bytes"
"crypto/sha256"
"testing"
)

func TestAbbrevTreeAddProtocol(t *testing.T) {
tree := &abbrevTree[string]{}

proto1 := "protocol1"
hash1 := sha256.Sum256([]byte(proto1))
proto2 := "protocol2"
hash2 := sha256.Sum256([]byte(proto2))
proto3 := "protocol251" // this one has the same first byte as "protocol1"
hash3 := sha256.Sum256([]byte(proto3))

// make sure we don't make mistakes on the hashes
if hash1[0] == hash2[0] {
t.Fatal("the first bytes of hash1 and hash2 should be different")
}
if hash1[0] != hash3[0] {
t.Fatal("the first bytes of hash1 and hash3 should be the same")
}
if hash1[1] == hash3[1] {
t.Fatal("the second bytes of hash1 and hash3 should be different")
}

// add only proto1
tree.AddProtocol(proto1)

if tree.root == nil {
t.Fatal("root should not be nil after adding protocol")
}
if tree.root.children[hash1[0]] == nil || tree.root.children[hash1[0]].p == nil {
t.Fatal("the protocol was not added")
}
if tree.root.children[hash1[0]].p.protocolID != proto1 {
t.Fatal("the protocol ID was wrong")
}
if tree.root.children[hash1[0]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must not be set")
}
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0]}) {
t.Fatal("abbreviation of proto1 is incorrect")
}

// also add proto2
tree.AddProtocol(proto2)

if tree.root.children[hash2[0]] == nil || tree.root.children[hash2[0]].p == nil {
t.Fatal("the protocol was not added")
}
if tree.root.children[hash2[0]].p.protocolID != proto2 {
t.Fatal("the protocol ID was wrong")
}
if tree.root.children[hash2[0]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto2 must not be set")
}
if !bytes.Equal(tree.Abbreviate(proto2), []byte{hash2[0]}) {
t.Fatal("abbreviation of proto2 is incorrect")
}

// add proto3 which has the same first byte of the hash as proto1
tree.AddProtocol(proto3)

n1 := tree.root.children[hash1[0]]
// the node at the first level should still be proto1
if n1.p.protocolID != proto1 {
t.Fatal("the node in the first level should not be modified")
}
// proto1 should be duplicated down
if n1.children[hash1[1]] == nil || n1.children[hash1[1]].p == nil {
t.Fatal("proto1 was not duplicated")
}
if n1.children[hash1[1]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must not be set")
}
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1]}) {
t.Fatal("abbreviation of proto1 is incorrect")
}
// proto3 should be added in the second level
if n1.children[hash3[1]] == nil || n1.children[hash3[1]].p == nil {
t.Fatal("proto3 was not added")
}
if n1.children[hash3[1]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto3 must not be set")
}
if !bytes.Equal(tree.Abbreviate(proto3), []byte{hash3[0], hash3[1]}) {
t.Fatal("abbreviation of proto3 is incorrect")
}
}

func TestAbbrevTreeRemoveProtocol(t *testing.T) {
tree := &abbrevTree[string]{}

proto1 := "protocol1"
hash1 := sha256.Sum256([]byte(proto1))
proto2 := "protocol2"
proto3 := "protocol251" // this one has the same first byte as "protocol1"

tree.AddProtocol(proto1)
tree.AddProtocol(proto2)
tree.AddProtocol(proto3)

// remove only proto1
tree.RemoveProtocol(proto1)

n1 := tree.root.children[hash1[0]]
if !n1.p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must be set")
}
if !n1.children[hash1[1]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must be set")
}
if tree.Abbreviate(proto1) != nil {
t.Fatal("abbreviation of proto1 should be nil")
}
}

func TestAbbrevTreeResurrectProtocol(t *testing.T) {
tree := &abbrevTree[string]{}

proto1 := "protocol1"
hash1 := sha256.Sum256([]byte(proto1))
proto2 := "protocol2"
proto3 := "protocol251" // this one has the same first byte as "protocol1"

tree.AddProtocol(proto1)
tree.AddProtocol(proto2)
tree.AddProtocol(proto3)
tree.RemoveProtocol(proto1)
tree.AddProtocol(proto1)

n1 := tree.root.children[hash1[0]]
if n1.p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must not be set")
}
if n1.children[hash1[1]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must not be set")
}

// There should be another leaf node added for proto1
n2 := n1.children[hash1[1]]
if n2.children[hash1[2]] == nil || n2.children[hash1[2]].p == nil {
t.Fatal("proto1 was not added")
}
if n2.children[hash1[2]].p.tombstoneBit {
t.Fatal("tombstoneBit of proto1 must not be set")
}
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1], hash1[2]}) {
t.Fatal("abbreviation of proto1 is incorrect")
}
}
18 changes: 18 additions & 0 deletions lazyClient.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multistream

import (
"encoding/hex"
"fmt"
"io"
)
Expand All @@ -17,6 +18,23 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
}
}

func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn {
t := &abbrevTree[T]{}
for _, p := range peerProtos {
t.AddProtocol(p)
}

// TODO: use a proper varint instead of a hex string later
abbrv := T(hex.EncodeToString(t.Abbreviate(proto)))
return &lazyClientConn[T]{
protos: []T{ProtocolID, abbrv},
con: c,

rhandshakeOnce: newOnce(),
whandshakeOnce: newOnce(),
}
}

// NewMultistream returns a multistream for the given protocol. This will not
// perform any protocol selection. If you are using a MultistreamMuxer, use
// NewMSSelect.
Expand Down
Loading