Skip to content

Commit 33c7cf4

Browse files
committed
ms-select2: abbreviation tree and NewMSSelect2
1 parent 82393f9 commit 33c7cf4

File tree

4 files changed

+265
-2
lines changed

4 files changed

+265
-2
lines changed

abbrevTree.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package multistream
2+
3+
import (
4+
"crypto/sha256"
5+
)
6+
7+
type nodeProtocol[T StringLike] struct {
8+
protocolID T
9+
tombstoneBit bool
10+
}
11+
12+
type abbrevTree[T StringLike] struct {
13+
root *abbrevNode[T]
14+
}
15+
16+
type abbrevNode[T StringLike] struct {
17+
p *nodeProtocol[T]
18+
children [256]*abbrevNode[T]
19+
}
20+
21+
func (at *abbrevTree[T]) Abbreviate(pid T) []byte {
22+
var result []byte
23+
hash := sha256.Sum256([]byte(pid))
24+
25+
if at.root == nil {
26+
return nil
27+
}
28+
29+
current := at.root
30+
// go furthest in the tree
31+
for _, b := range hash {
32+
if current.children[b] != nil {
33+
result = append(result, b)
34+
current = current.children[b]
35+
}
36+
}
37+
38+
if current.p != nil && current.p.protocolID == pid && !current.p.tombstoneBit {
39+
return result
40+
}
41+
return nil
42+
}
43+
44+
func (at *abbrevTree[T]) AddProtocol(pid T) {
45+
hash := sha256.Sum256([]byte(pid))
46+
47+
if at.root == nil {
48+
at.root = &abbrevNode[T]{}
49+
}
50+
51+
current := at.root
52+
for idx, b := range hash {
53+
if current.children[b] == nil {
54+
current.children[b] = &abbrevNode[T]{
55+
p: &nodeProtocol[T]{
56+
protocolID: pid,
57+
tombstoneBit: false,
58+
},
59+
}
60+
return
61+
}
62+
current = current.children[b]
63+
64+
if current.p != nil {
65+
if current.p.protocolID == pid {
66+
// Resurrect the protocol ID.
67+
current.p.tombstoneBit = false
68+
} else if !current.p.tombstoneBit {
69+
// There is another protocol in this node, so we need to duplicate it down.
70+
h := sha256.Sum256([]byte(current.p.protocolID))
71+
72+
if current.children[h[idx+1]] == nil {
73+
// It should be fine to reference the same nodeProtocol instance.
74+
current.children[h[idx+1]] = &abbrevNode[T]{p: current.p}
75+
}
76+
}
77+
}
78+
}
79+
}
80+
81+
func (at *abbrevTree[T]) RemoveProtocol(pid T) {
82+
hash := sha256.Sum256([]byte(pid))
83+
84+
if at.root == nil {
85+
return
86+
}
87+
current := at.root
88+
for _, b := range hash {
89+
if current.children[b] == nil {
90+
break
91+
}
92+
current = current.children[b]
93+
94+
if current.p.protocolID == pid {
95+
current.p.tombstoneBit = true
96+
}
97+
}
98+
}

abbrevTree_test.go

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package multistream
2+
3+
import (
4+
"bytes"
5+
"crypto/sha256"
6+
"testing"
7+
)
8+
9+
func TestAbbrevTreeAddProtocol(t *testing.T) {
10+
tree := &abbrevTree[string]{}
11+
12+
proto1 := "protocol1"
13+
hash1 := sha256.Sum256([]byte(proto1))
14+
proto2 := "protocol2"
15+
hash2 := sha256.Sum256([]byte(proto2))
16+
proto3 := "protocol251" // this one has the same first byte as "protocol1"
17+
hash3 := sha256.Sum256([]byte(proto3))
18+
19+
// make sure we don't make mistakes on the hashes
20+
if hash1[0] == hash2[0] {
21+
t.Fatal("the first bytes of hash1 and hash2 should be different")
22+
}
23+
if hash1[0] != hash3[0] {
24+
t.Fatal("the first bytes of hash1 and hash3 should be the same")
25+
}
26+
if hash1[1] == hash3[1] {
27+
t.Fatal("the second bytes of hash1 and hash3 should be different")
28+
}
29+
30+
// add only proto1
31+
tree.AddProtocol(proto1)
32+
33+
if tree.root == nil {
34+
t.Fatal("root should not be nil after adding protocol")
35+
}
36+
if tree.root.children[hash1[0]] == nil || tree.root.children[hash1[0]].p == nil {
37+
t.Fatal("the protocol was not added")
38+
}
39+
if tree.root.children[hash1[0]].p.protocolID != proto1 {
40+
t.Fatal("the protocol ID was wrong")
41+
}
42+
if tree.root.children[hash1[0]].p.tombstoneBit {
43+
t.Fatal("tombstoneBit of proto1 must not be set")
44+
}
45+
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0]}) {
46+
t.Fatal("abbreviation of proto1 is incorrect")
47+
}
48+
49+
// also add proto2
50+
tree.AddProtocol(proto2)
51+
52+
if tree.root.children[hash2[0]] == nil || tree.root.children[hash2[0]].p == nil {
53+
t.Fatal("the protocol was not added")
54+
}
55+
if tree.root.children[hash2[0]].p.protocolID != proto2 {
56+
t.Fatal("the protocol ID was wrong")
57+
}
58+
if tree.root.children[hash2[0]].p.tombstoneBit {
59+
t.Fatal("tombstoneBit of proto2 must not be set")
60+
}
61+
if !bytes.Equal(tree.Abbreviate(proto2), []byte{hash2[0]}) {
62+
t.Fatal("abbreviation of proto2 is incorrect")
63+
}
64+
65+
// add proto3 which has the same first byte of the hash as proto1
66+
tree.AddProtocol(proto3)
67+
68+
n1 := tree.root.children[hash1[0]]
69+
// the node at the first level should still be proto1
70+
if n1.p.protocolID != proto1 {
71+
t.Fatal("the node in the first level should not be modified")
72+
}
73+
// proto1 should be duplicated down
74+
if n1.children[hash1[1]] == nil || n1.children[hash1[1]].p == nil {
75+
t.Fatal("proto1 was not duplicated")
76+
}
77+
if n1.children[hash1[1]].p.tombstoneBit {
78+
t.Fatal("tombstoneBit of proto1 must not be set")
79+
}
80+
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1]}) {
81+
t.Fatal("abbreviation of proto1 is incorrect")
82+
}
83+
// proto3 should be added in the second level
84+
if n1.children[hash3[1]] == nil || n1.children[hash3[1]].p == nil {
85+
t.Fatal("proto3 was not added")
86+
}
87+
if n1.children[hash3[1]].p.tombstoneBit {
88+
t.Fatal("tombstoneBit of proto3 must not be set")
89+
}
90+
if !bytes.Equal(tree.Abbreviate(proto3), []byte{hash3[0], hash3[1]}) {
91+
t.Fatal("abbreviation of proto3 is incorrect")
92+
}
93+
}
94+
95+
func TestAbbrevTreeRemoveProtocol(t *testing.T) {
96+
tree := &abbrevTree[string]{}
97+
98+
proto1 := "protocol1"
99+
hash1 := sha256.Sum256([]byte(proto1))
100+
proto2 := "protocol2"
101+
proto3 := "protocol251" // this one has the same first byte as "protocol1"
102+
103+
tree.AddProtocol(proto1)
104+
tree.AddProtocol(proto2)
105+
tree.AddProtocol(proto3)
106+
107+
// remove only proto1
108+
tree.RemoveProtocol(proto1)
109+
110+
n1 := tree.root.children[hash1[0]]
111+
if !n1.p.tombstoneBit {
112+
t.Fatal("tombstoneBit of proto1 must be set")
113+
}
114+
if !n1.children[hash1[1]].p.tombstoneBit {
115+
t.Fatal("tombstoneBit of proto1 must be set")
116+
}
117+
if tree.Abbreviate(proto1) != nil {
118+
t.Fatal("abbreviation of proto1 should be nil")
119+
}
120+
}
121+
122+
func TestAbbrevTreeResurrectProtocol(t *testing.T) {
123+
tree := &abbrevTree[string]{}
124+
125+
proto1 := "protocol1"
126+
hash1 := sha256.Sum256([]byte(proto1))
127+
proto2 := "protocol2"
128+
proto3 := "protocol251" // this one has the same first byte as "protocol1"
129+
130+
tree.AddProtocol(proto1)
131+
tree.AddProtocol(proto2)
132+
tree.AddProtocol(proto3)
133+
tree.RemoveProtocol(proto1)
134+
tree.AddProtocol(proto1)
135+
136+
n1 := tree.root.children[hash1[0]]
137+
if n1.p.tombstoneBit {
138+
t.Fatal("tombstoneBit of proto1 must not be set")
139+
}
140+
if n1.children[hash1[1]].p.tombstoneBit {
141+
t.Fatal("tombstoneBit of proto1 must not be set")
142+
}
143+
144+
// There should be another leaf node added for proto1
145+
n2 := n1.children[hash1[1]]
146+
if n2.children[hash1[2]] == nil || n2.children[hash1[2]].p == nil {
147+
t.Fatal("proto1 was not added")
148+
}
149+
if n2.children[hash1[2]].p.tombstoneBit {
150+
t.Fatal("tombstoneBit of proto1 must not be set")
151+
}
152+
if !bytes.Equal(tree.Abbreviate(proto1), []byte{hash1[0], hash1[1], hash1[2]}) {
153+
t.Fatal("abbreviation of proto1 is incorrect")
154+
}
155+
}

lazyClient.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package multistream
22

33
import (
4+
"encoding/hex"
45
"fmt"
56
"io"
67
)
@@ -18,9 +19,15 @@ func NewMSSelect[T StringLike](c io.ReadWriteCloser, proto T) LazyConn {
1819
}
1920

2021
func NewMSSelect2[T StringLike](c io.ReadWriteCloser, proto T, peerProtos []T) LazyConn {
21-
// TODO: put peerProtos into lazyClientConn so that it knows what protocols the other peer supports
22+
t := &abbrevTree[T]{}
23+
for _, p := range peerProtos {
24+
t.AddProtocol(p)
25+
}
26+
27+
// TODO: use a proper varint instead of a hex string later
28+
abbrv := T(hex.EncodeToString(t.Abbreviate(proto)))
2229
return &lazyClientConn[T]{
23-
protos: []T{ProtocolID, proto},
30+
protos: []T{ProtocolID, abbrv},
2431
con: c,
2532

2633
rhandshakeOnce: newOnce(),

multistream.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ var ErrTooLarge = errors.New("incoming message was too large")
2222
// the multistream muxers on both sides of a channel can work with each other.
2323
const ProtocolID = "/multistream/1.0.0"
2424

25+
// Multistream-select version that protocol abbreviation is supported
26+
const AbbrevSupportedMSSVersion = 2
27+
2528
var writerPool = sync.Pool{
2629
New: func() interface{} {
2730
return bufio.NewWriter(nil)

0 commit comments

Comments
 (0)