@@ -5,6 +5,7 @@ package multistream
55
66import (
77 "bufio"
8+ "encoding/hex"
89 "errors"
910 "fmt"
1011 "io"
@@ -18,6 +19,9 @@ import (
1819// ErrTooLarge is an error to signal that an incoming message was too large
1920var ErrTooLarge = errors .New ("incoming message was too large" )
2021
22+ // ErrUnknownPrefix is an error to signal that the protocol hash prefix is unknown
23+ var ErrUnknownPrefix = errors .New ("unknown protocol hash prefix" )
24+
2125// ProtocolID identifies the multistream protocol itself and makes sure
2226// the multistream muxers on both sides of a channel can work with each other.
2327const ProtocolID = "/multistream/1.0.0"
@@ -55,6 +59,7 @@ type Handler[T StringLike] struct {
5559type MultistreamMuxer [T StringLike ] struct {
5660 handlerlock sync.RWMutex
5761 handlers []Handler [T ]
62+ abbrevTree abbrevTree [T ]
5863}
5964
6065// NewMultistreamMuxer creates a muxer.
@@ -137,6 +142,7 @@ func (msm *MultistreamMuxer[T]) AddHandlerWithFunc(protocol T, match func(T) boo
137142 msm .handlerlock .Lock ()
138143 defer msm .handlerlock .Unlock ()
139144
145+ msm .abbrevTree .AddProtocol (protocol )
140146 msm .removeHandler (protocol )
141147 msm .handlers = append (msm .handlers , Handler [T ]{
142148 MatchFunc : match ,
@@ -150,6 +156,7 @@ func (msm *MultistreamMuxer[T]) RemoveHandler(protocol T) {
150156 msm .handlerlock .Lock ()
151157 defer msm .handlerlock .Unlock ()
152158
159+ msm .abbrevTree .RemoveProtocol (protocol )
153160 msm .removeHandler (protocol )
154161}
155162
@@ -179,6 +186,24 @@ func (msm *MultistreamMuxer[T]) Protocols() []T {
179186// fails because of a ProtocolID mismatch.
180187var ErrIncorrectVersion = errors .New ("client connected with incorrect version" )
181188
189+ func (msm * MultistreamMuxer [T ]) decodeProtocol (s T ) (T , error ) {
190+ msm .handlerlock .RLock ()
191+ defer msm .handlerlock .RUnlock ()
192+
193+ bytes , err := hex .DecodeString (string (s ))
194+ // TODO: decide whether to compare strings or use abbrevTree by looking at
195+ // multistream version instead.
196+ if err != nil {
197+ return s , nil
198+ }
199+
200+ proto , err := msm .abbrevTree .GetProtocolID (bytes )
201+ if err != nil {
202+ return "" , err
203+ }
204+ return proto , nil
205+ }
206+
182207func (msm * MultistreamMuxer [T ]) findHandler (proto T ) * Handler [T ] {
183208 msm .handlerlock .RLock ()
184209 defer msm .handlerlock .RUnlock ()
@@ -225,7 +250,12 @@ loop:
225250 return "" , nil , err
226251 }
227252
228- h := msm .findHandler (tok )
253+ p , err := msm .decodeProtocol (tok )
254+ if err != nil {
255+ return "" , nil , err
256+ }
257+
258+ h := msm .findHandler (p )
229259 if h == nil {
230260 if err := delimWriteBuffered (rwc , []byte ("na" )); err != nil {
231261 return "" , nil , err
@@ -239,7 +269,7 @@ loop:
239269 _ = delimWriteBuffered (rwc , []byte (tok ))
240270
241271 // hand off processing to the sub-protocol handler
242- return tok , h .Handle , nil
272+ return p , h .Handle , nil
243273 }
244274
245275}
0 commit comments