@@ -5,6 +5,7 @@ package multistream
55
66import (
77 "bufio"
8+ "encoding/hex"
89 "errors"
910 "fmt"
1011 "io"
@@ -18,6 +19,9 @@ import (
1819// ErrTooLarge is an error to signal that an incoming message was too large
1920var ErrTooLarge = errors .New ("incoming message was too large" )
2021
22+ // ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown
23+ var ErrUnknownPrefix = errors .New ("unknown protocol hash prefix" )
24+
2125// ProtocolID identifies the multistream protocol itself and makes sure
2226// the multistream muxers on both sides of a channel can work with each other.
2327const ProtocolID = "/multistream/1.0.0"
@@ -55,6 +59,7 @@ type Handler[T StringLike] struct {
5559type MultistreamMuxer [T StringLike ] struct {
5660 handlerlock sync.RWMutex
5761 handlers []Handler [T ]
62+ abbrevTree abbrevTree [T ]
5863}
5964
6065// NewMultistreamMuxer creates a muxer.
@@ -137,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo
137142 msm .handlerlock .Lock ()
138143 defer msm .handlerlock .Unlock ()
139144
145+ msm .abbrevTree .AddProtocol (protocol )
140146 msm .removeHandler (protocol )
141147 msm .handlers = append (msm .handlers , Handler [T ]{
142148 MatchFunc : match ,
@@ -150,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) {
150156 msm .handlerlock .Lock ()
151157 defer msm .handlerlock .Unlock ()
152158
159+ msm .abbrevTree .RemoveProtocol (protocol )
153160 msm .removeHandler (protocol )
154161}
155162
@@ -179,6 +186,21 @@ func (msm *MultistreamMuxer[T]) Protocols() []T {
179186// fails because of a ProtocolID mismatch.
180187var ErrIncorrectVersion = errors .New ("client connected with incorrect version" )
181188
189+ func (msm * MultistreamMuxer [T ]) decodeProtocol (s T ) (T , error ) {
190+ bytes , err := hex .DecodeString (string (s ))
191+ // TODO: decide whether to compare strings or use abbrevTree by looking at
192+ // multistream version instead.
193+ if err != nil {
194+ return s , nil
195+ }
196+
197+ proto , err := msm .abbrevTree .GetProtocolID (bytes )
198+ if err != nil {
199+ return "" , err
200+ }
201+ return proto , nil
202+ }
203+
182204func (msm * MultistreamMuxer [T ]) findHandler (proto T ) * Handler [T ] {
183205 msm .handlerlock .RLock ()
184206 defer msm .handlerlock .RUnlock ()
@@ -225,7 +247,12 @@ loop:
225247 return "" , nil , err
226248 }
227249
228- h := msm .findHandler (tok )
250+ p , err := msm .decodeProtocol (tok )
251+ if err != nil {
252+ return "" , nil , err
253+ }
254+
255+ h := msm .findHandler (p )
229256 if h == nil {
230257 if err := delimWriteBuffered (rwc , []byte ("na" )); err != nil {
231258 return "" , nil , err
@@ -236,10 +263,10 @@ loop:
236263 // Ignore the error here. We want the handshake to finish, even if the
237264 // other side has closed this rwc for writing. They may have sent us a
238265 // message and closed. Future writers will get an error anyways.
239- _ = delimWriteBuffered (rwc , []byte (tok ))
266+ _ = delimWriteBuffered (rwc , []byte (p ))
240267
241268 // hand off processing to the sub-protocol handler
242- return tok , h .Handle , nil
269+ return p , h .Handle , nil
243270 }
244271
245272}
0 commit comments