@@ -22,11 +22,11 @@ import (
2222 "github.com/openziti/channel/v4"
2323 "github.com/openziti/sdk-golang/inspect"
2424 "github.com/openziti/sdk-golang/xgress"
25+ cmap "github.com/orcaman/concurrent-map/v2"
2526 "github.com/pkg/errors"
2627 "github.com/sirupsen/logrus"
2728 "math"
2829 "strings"
29- "sync"
3030 "sync/atomic"
3131 "time"
3232)
@@ -47,28 +47,28 @@ type MsgMux interface {
4747 GetNextId () uint32
4848}
4949
50- func NewCowMapMsgMux () MsgMux {
51- result := & CowMapMsgMux {
50+ func NewMapMsgMux () MsgMux {
51+ result := & MsgMuxImpl {
5252 maxId : (math .MaxUint32 / 2 ) - 1 ,
53+ sinks : cmap.NewWithCustomShardingFunction [uint32 , MsgSink ](func (key uint32 ) uint32 {
54+ return key
55+ }),
5356 }
54- result .sinks .Store (map [uint32 ]MsgSink {})
5557 return result
5658}
5759
58- type CowMapMsgMux struct {
59- sync.Mutex
60+ type MsgMuxImpl struct {
6061 closed atomic.Bool
61- sinks atomic. Value
62+ sinks cmap. ConcurrentMap [ uint32 , MsgSink ]
6263 nextId uint32
6364 minId uint32
6465 maxId uint32
6566}
6667
67- func (mux * CowMapMsgMux ) GetNextId () uint32 {
68+ func (mux * MsgMuxImpl ) GetNextId () uint32 {
6869 nextId := atomic .AddUint32 (& mux .nextId , 1 )
69- sinks := mux .getSinks ()
7070 for {
71- if _ , found := sinks [ nextId ] ; found {
71+ if _ , found := mux . sinks . Get ( nextId ) ; found {
7272 // if it's in use, try next one
7373 nextId = atomic .AddUint32 (& mux .nextId , 1 )
7474 } else if nextId < mux .minId || nextId >= mux .maxId {
@@ -82,11 +82,11 @@ func (mux *CowMapMsgMux) GetNextId() uint32 {
8282 }
8383}
8484
85- func (mux * CowMapMsgMux ) ContentType () int32 {
85+ func (mux * MsgMuxImpl ) ContentType () int32 {
8686 return ContentTypeData
8787}
8888
89- func (mux * CowMapMsgMux ) HandleReceive (msg * channel.Message , ch channel.Channel ) {
89+ func (mux * MsgMuxImpl ) HandleReceive (msg * channel.Message , ch channel.Channel ) {
9090 connId , found := msg .GetUint32Header (ConnIdHeader )
9191 if ! found {
9292 if msg .ContentType == ContentTypeInspectRequest {
@@ -97,22 +97,45 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel)
9797 return
9898 }
9999
100- sinks := mux .getSinks ()
101- if sink , found := sinks [connId ]; found {
100+ if sink , found := mux .sinks .Get (connId ); found {
102101 sink .Accept (msg )
103102 } else if msg .ContentType == ContentTypeConnInspectRequest {
104- pfxlog .Logger ().WithField ("connId" , connId ).Trace ("no conn found for connection inspect" )
103+ pfxlog .Logger ().WithField ("connId" , int ( connId ) ).Trace ("no conn found for connection inspect" )
105104 resp := NewConnInspectResponse (connId , ConnTypeInvalid , fmt .Sprintf ("invalid conn id [%v]" , connId ))
106105 if err := resp .ReplyTo (msg ).Send (ch ); err != nil {
107106 logrus .WithFields (GetLoggerFields (msg )).WithError (err ).
108107 Error ("failed to send inspect response" )
109108 }
109+ } else if msg .ContentType == ContentTypeXgPayload {
110+ mux .handlePayloadWithNoSink (msg , ch )
111+ } else if msg .ContentType == ContentTypeStateClosed {
112+ // ignore, as conn is already closed
110113 } else {
111- pfxlog .Logger ().Debugf ("unable to dispatch msg received for unknown edge conn id: %v" , connId )
114+ pfxlog .Logger ().WithField ("connId" , connId ).WithField ("contentType" , msg .ContentType ).
115+ Debug ("unable to dispatch msg received for unknown edge conn id" )
112116 }
113117}
114118
115- func (mux * CowMapMsgMux ) HandleInspect (msg * channel.Message , ch channel.Channel ) {
119+ func (mux * MsgMuxImpl ) handlePayloadWithNoSink (msg * channel.Message , ch channel.Channel ) {
120+ connId , _ := msg .GetUint32Header (ConnIdHeader )
121+ payload , err := xgress .UnmarshallPayload (msg )
122+ if err == nil {
123+ if payload .IsCircuitEndFlagSet () && len (payload .Data ) == 0 {
124+ ack := xgress .NewAcknowledgement (payload .CircuitId , payload .GetOriginator ().Invert ())
125+ ackMsg := ack .Marshall ()
126+ ackMsg .PutUint32Header (ConnIdHeader , connId )
127+ _ , _ = ch .TrySend (msg )
128+ } else {
129+ pfxlog .Logger ().WithField ("connId" , int (connId )).WithField ("circuitId" , payload .CircuitId ).
130+ Debug ("unable to dispatch xg payload received for unknown edge conn id" )
131+ }
132+ } else {
133+ pfxlog .Logger ().WithError (err ).WithField ("connId" , int (connId )).
134+ Debug ("unable to dispatch xg payload received for unknown edge conn id" )
135+ }
136+ }
137+
138+ func (mux * MsgMuxImpl ) HandleInspect (msg * channel.Message , ch channel.Channel ) {
116139 resp := & inspect.SdkInspectResponse {
117140 Success : true ,
118141 Values : make (map [string ]any ),
@@ -132,7 +155,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
132155 Circuits : make (map [string ]* xgress.CircuitDetail ),
133156 }
134157
135- for _ , sink := range mux .getSinks () {
158+ for _ , sink := range mux .sinks . Items () {
136159 if circuitInfoSrc , ok := sink .(interface {
137160 GetCircuitDetail () * xgress.CircuitDetail
138161 }); ok {
@@ -149,7 +172,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
149172 mux .returnInspectResponse (msg , ch , resp )
150173}
151174
152- func (mux * CowMapMsgMux ) returnInspectResponse (msg * channel.Message , ch channel.Channel , resp * inspect.SdkInspectResponse ) {
175+ func (mux * MsgMuxImpl ) returnInspectResponse (msg * channel.Message , ch channel.Channel , resp * inspect.SdkInspectResponse ) {
153176 var sender channel.Sender = ch
154177 if mc , ok := ch .(channel.MultiChannel ); ok {
155178 if sdkChan , ok := mc .GetUnderlayHandler ().(SdkChannel ); ok {
@@ -169,61 +192,35 @@ func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.
169192 }
170193}
171194
172- func (mux * CowMapMsgMux ) HandleClose (channel.Channel ) {
195+ func (mux * MsgMuxImpl ) HandleClose (channel.Channel ) {
173196 mux .Close ()
174197}
175198
176- func (mux * CowMapMsgMux ) AddMsgSink (sink MsgSink ) error {
199+ func (mux * MsgMuxImpl ) AddMsgSink (sink MsgSink ) error {
177200 if mux .closed .Load () {
178201 return errors .Errorf ("mux is closed, can't add sink with id [%v]" , sink .Id ())
179202 }
180203
181- var err error
182- mux .updateSinkMap (func (m map [uint32 ]MsgSink ) {
183- if _ , found := m [sink .Id ()]; found {
184- err = errors .Errorf ("sink id %v already in use" , sink .Id ())
185- } else {
186- m [sink .Id ()] = sink
187- }
188- })
189-
190- // check again, just in case it was closed while we were adding
191- if mux .closed .Load () {
192- return errors .Errorf ("mux is closed, can't add sink with id [%v]" , sink .Id ())
204+ if ! mux .sinks .SetIfAbsent (sink .Id (), sink ) {
205+ return errors .Errorf ("sink id %v already in use" , sink .Id ())
193206 }
194-
195- return err
207+ return nil
196208}
197209
198- func (mux * CowMapMsgMux ) RemoveMsgSink (sink MsgSink ) {
210+ func (mux * MsgMuxImpl ) RemoveMsgSink (sink MsgSink ) {
199211 mux .RemoveMsgSinkById (sink .Id ())
200212}
201213
202- func (mux * CowMapMsgMux ) RemoveMsgSinkById (sinkId uint32 ) {
203- mux .updateSinkMap (func (m map [uint32 ]MsgSink ) {
204- delete (m , sinkId )
205- })
214+ func (mux * MsgMuxImpl ) RemoveMsgSinkById (sinkId uint32 ) {
215+ mux .sinks .Remove (sinkId )
206216}
207217
208- func (mux * CowMapMsgMux ) updateSinkMap (f func (map [uint32 ]MsgSink )) {
209- mux .Lock ()
210- defer mux .Unlock ()
211-
212- current := mux .getSinks ()
213- result := map [uint32 ]MsgSink {}
214- for k , v := range current {
215- result [k ] = v
216- }
217- f (result )
218- mux .sinks .Store (result )
219- }
220-
221- func (mux * CowMapMsgMux ) Close () {
218+ func (mux * MsgMuxImpl ) Close () {
222219 if mux .closed .CompareAndSwap (false , true ) {
223220 // we don't need to lock the mux because due to the atomic bool, only one go-routine will enter this.
224221 // If the sink HandleMuxClose methods do anything with the mux, like remove themselves, they will acquire
225222 // their own locks
226- sinks := mux .getSinks ()
223+ sinks := mux .sinks . Items ()
227224 for _ , val := range sinks {
228225 if err := val .HandleMuxClose (); err != nil {
229226 pfxlog .Logger ().
@@ -234,7 +231,3 @@ func (mux *CowMapMsgMux) Close() {
234231 }
235232 }
236233}
237-
238- func (mux * CowMapMsgMux ) getSinks () map [uint32 ]MsgSink {
239- return mux .sinks .Load ().(map [uint32 ]MsgSink )
240- }
0 commit comments