Skip to content

Commit 628bb13

Browse files
committed
ms-select2: abbreviation on the server side
1 parent 33c7cf4 commit 628bb13

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

abbrevTree.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ func (at *abbrevTree[T]) Abbreviate(pid T) []byte {
4141
return nil
4242
}
4343

44+
func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) {
45+
if at.root == nil {
46+
return "", ErrUnknownPrefix
47+
}
48+
current := at.root
49+
for _, b := range prefix {
50+
if current.children[b] == nil {
51+
return "", ErrUnknownPrefix
52+
}
53+
current = current.children[b]
54+
}
55+
if current.p == nil {
56+
return "", ErrUnknownPrefix
57+
}
58+
return current.p.protocolID, nil
59+
}
60+
4461
func (at *abbrevTree[T]) AddProtocol(pid T) {
4562
hash := sha256.Sum256([]byte(pid))
4663

multistream.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package multistream
55

66
import (
77
"bufio"
8+
"encoding/hex"
89
"errors"
910
"fmt"
1011
"io"
@@ -18,6 +19,9 @@ import (
1819
// ErrTooLarge is an error to signal that an incoming message was too large
1920
var ErrTooLarge = errors.New("incoming message was too large")
2021

22+
// ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown
23+
var ErrUnknownPrefix = errors.New("unknown protocol hash prefix")
24+
2125
// ProtocolID identifies the multistream protocol itself and makes sure
2226
// the multistream muxers on both sides of a channel can work with each other.
2327
const ProtocolID = "/multistream/1.0.0"
@@ -55,6 +59,7 @@ type Handler[T StringLike] struct {
5559
type MultistreamMuxer[T StringLike] struct {
5660
handlerlock sync.RWMutex
5761
handlers []Handler[T]
62+
abbrevTree abbrevTree[T]
5863
}
5964

6065
// NewMultistreamMuxer creates a muxer.
@@ -137,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo
137142
msm.handlerlock.Lock()
138143
defer msm.handlerlock.Unlock()
139144

145+
msm.abbrevTree.AddProtocol(protocol)
140146
msm.removeHandler(protocol)
141147
msm.handlers = append(msm.handlers, Handler[T]{
142148
MatchFunc: match,
@@ -150,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) {
150156
msm.handlerlock.Lock()
151157
defer msm.handlerlock.Unlock()
152158

159+
msm.abbrevTree.RemoveProtocol(protocol)
153160
msm.removeHandler(protocol)
154161
}
155162

@@ -179,6 +186,21 @@ func (msm *MultistreamMuxer[T]) Protocols() []T {
179186
// fails because of a ProtocolID mismatch.
180187
var ErrIncorrectVersion = errors.New("client connected with incorrect version")
181188

189+
func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) {
190+
bytes, err := hex.DecodeString(string(s))
191+
// TODO: decide whether to compare strings or use abbrevTree by looking at
192+
// multistream version instead.
193+
if err != nil {
194+
return s, nil
195+
}
196+
197+
proto, err := msm.abbrevTree.GetProtocolID(bytes)
198+
if err != nil {
199+
return "", err
200+
}
201+
return proto, nil
202+
}
203+
182204
func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] {
183205
msm.handlerlock.RLock()
184206
defer msm.handlerlock.RUnlock()
@@ -225,7 +247,12 @@ loop:
225247
return "", nil, err
226248
}
227249

228-
h := msm.findHandler(tok)
250+
p, err := msm.decodeProtocol(tok)
251+
if err != nil {
252+
return "", nil, err
253+
}
254+
255+
h := msm.findHandler(p)
229256
if h == nil {
230257
if err := delimWriteBuffered(rwc, []byte("na")); err != nil {
231258
return "", nil, err
@@ -236,10 +263,10 @@ loop:
236263
// Ignore the error here. We want the handshake to finish, even if the
237264
// other side has closed this rwc for writing. They may have sent us a
238265
// message and closed. Future writers will get an error anyways.
239-
_ = delimWriteBuffered(rwc, []byte(tok))
266+
_ = delimWriteBuffered(rwc, []byte(p))
240267

241268
// hand off processing to the sub-protocol handler
242-
return tok, h.Handle, nil
269+
return p, h.Handle, nil
243270
}
244271

245272
}

0 commit comments

Comments
 (0)