Skip to content

Commit f8839ce

Browse files
committed
ms-select2: abbreviation on the server side
1 parent 33c7cf4 commit f8839ce

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

abbrevTree.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,23 @@ func (at *abbrevTree[T]) Abbreviate(pid T) []byte {
4141
return nil
4242
}
4343

44+
func (at *abbrevTree[T]) GetProtocolID(prefix []byte) (T, error) {
45+
if at.root == nil {
46+
return "", ErrUnknownPrefix
47+
}
48+
current := at.root
49+
for _, b := range prefix {
50+
if current.children[b] == nil {
51+
return "", ErrUnknownPrefix
52+
}
53+
current = current.children[b]
54+
}
55+
if current.p == nil {
56+
return "", ErrUnknownPrefix
57+
}
58+
return current.p.protocolID, nil
59+
}
60+
4461
func (at *abbrevTree[T]) AddProtocol(pid T) {
4562
hash := sha256.Sum256([]byte(pid))
4663

multistream.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package multistream
55

66
import (
77
"bufio"
8+
"encoding/hex"
89
"errors"
910
"fmt"
1011
"io"
@@ -18,6 +19,9 @@ import (
1819
// ErrTooLarge is an error to signal that an incoming message was too large
1920
var ErrTooLarge = errors.New("incoming message was too large")
2021

22+
// ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown
23+
var ErrUnknownPrefix = errors.New("unknown protocol hash prefix")
24+
2125
// ProtocolID identifies the multistream protocol itself and makes sure
2226
// the multistream muxers on both sides of a channel can work with each other.
2327
const ProtocolID = "/multistream/1.0.0"
@@ -55,6 +59,7 @@ type Handler[T StringLike] struct {
5559
type MultistreamMuxer[T StringLike] struct {
5660
handlerlock sync.RWMutex
5761
handlers []Handler[T]
62+
abbrevTree abbrevTree[T]
5863
}
5964

6065
// NewMultistreamMuxer creates a muxer.
@@ -137,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo
137142
msm.handlerlock.Lock()
138143
defer msm.handlerlock.Unlock()
139144

145+
msm.abbrevTree.AddProtocol(protocol)
140146
msm.removeHandler(protocol)
141147
msm.handlers = append(msm.handlers, Handler[T]{
142148
MatchFunc: match,
@@ -150,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) {
150156
msm.handlerlock.Lock()
151157
defer msm.handlerlock.Unlock()
152158

159+
msm.abbrevTree.RemoveProtocol(protocol)
153160
msm.removeHandler(protocol)
154161
}
155162

@@ -179,6 +186,24 @@ func (msm *MultistreamMuxer[T]) Protocols() []T {
179186
// fails because of a ProtocolID mismatch.
180187
var ErrIncorrectVersion = errors.New("client connected with incorrect version")
181188

189+
func (msm *MultistreamMuxer[T]) decodeProtocol(s T) (T, error) {
190+
msm.handlerlock.RLock()
191+
defer msm.handlerlock.RUnlock()
192+
193+
bytes, err := hex.DecodeString(string(s))
194+
// TODO: decide whether to compare strings or use abbrevTree by looking at
195+
// multistream version instead.
196+
if err != nil {
197+
return s, nil
198+
}
199+
200+
proto, err := msm.abbrevTree.GetProtocolID(bytes)
201+
if err != nil {
202+
return "", err
203+
}
204+
return proto, nil
205+
}
206+
182207
func (msm *MultistreamMuxer[T]) findHandler(proto T) *Handler[T] {
183208
msm.handlerlock.RLock()
184209
defer msm.handlerlock.RUnlock()
@@ -225,7 +250,12 @@ loop:
225250
return "", nil, err
226251
}
227252

228-
h := msm.findHandler(tok)
253+
p, err := msm.decodeProtocol(tok)
254+
if err != nil {
255+
return "", nil, err
256+
}
257+
258+
h := msm.findHandler(p)
229259
if h == nil {
230260
if err := delimWriteBuffered(rwc, []byte("na")); err != nil {
231261
return "", nil, err
@@ -239,7 +269,7 @@ loop:
239269
_ = delimWriteBuffered(rwc, []byte(tok))
240270

241271
// hand off processing to the sub-protocol handler
242-
return tok, h.Handle, nil
272+
return p, h.Handle, nil
243273
}
244274

245275
}

0 commit comments

Comments
 (0)