Skip to content

Commit e3a46fb

Browse files
committed
Use cmap for message mux. Fixes #761
1 parent a87abae commit e3a46fb

File tree

3 files changed

+56
-63
lines changed

3 files changed

+56
-63
lines changed

ziti/edge/msg_mux.go

Lines changed: 52 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import (
2222
"github.com/openziti/channel/v4"
2323
"github.com/openziti/sdk-golang/inspect"
2424
"github.com/openziti/sdk-golang/xgress"
25+
cmap "github.com/orcaman/concurrent-map/v2"
2526
"github.com/pkg/errors"
2627
"github.com/sirupsen/logrus"
2728
"math"
2829
"strings"
29-
"sync"
3030
"sync/atomic"
3131
"time"
3232
)
@@ -47,28 +47,28 @@ type MsgMux interface {
4747
GetNextId() uint32
4848
}
4949

50-
func NewCowMapMsgMux() MsgMux {
51-
result := &CowMapMsgMux{
50+
func NewMapMsgMux() MsgMux {
51+
result := &MsgMuxImpl{
5252
maxId: (math.MaxUint32 / 2) - 1,
53+
sinks: cmap.NewWithCustomShardingFunction[uint32, MsgSink](func(key uint32) uint32 {
54+
return key
55+
}),
5356
}
54-
result.sinks.Store(map[uint32]MsgSink{})
5557
return result
5658
}
5759

58-
type CowMapMsgMux struct {
59-
sync.Mutex
60+
type MsgMuxImpl struct {
6061
closed atomic.Bool
61-
sinks atomic.Value
62+
sinks cmap.ConcurrentMap[uint32, MsgSink]
6263
nextId uint32
6364
minId uint32
6465
maxId uint32
6566
}
6667

67-
func (mux *CowMapMsgMux) GetNextId() uint32 {
68+
func (mux *MsgMuxImpl) GetNextId() uint32 {
6869
nextId := atomic.AddUint32(&mux.nextId, 1)
69-
sinks := mux.getSinks()
7070
for {
71-
if _, found := sinks[nextId]; found {
71+
if _, found := mux.sinks.Get(nextId); found {
7272
// if it's in use, try next one
7373
nextId = atomic.AddUint32(&mux.nextId, 1)
7474
} else if nextId < mux.minId || nextId >= mux.maxId {
@@ -82,11 +82,11 @@ func (mux *CowMapMsgMux) GetNextId() uint32 {
8282
}
8383
}
8484

85-
func (mux *CowMapMsgMux) ContentType() int32 {
85+
func (mux *MsgMuxImpl) ContentType() int32 {
8686
return ContentTypeData
8787
}
8888

89-
func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel) {
89+
func (mux *MsgMuxImpl) HandleReceive(msg *channel.Message, ch channel.Channel) {
9090
connId, found := msg.GetUint32Header(ConnIdHeader)
9191
if !found {
9292
if msg.ContentType == ContentTypeInspectRequest {
@@ -97,22 +97,45 @@ func (mux *CowMapMsgMux) HandleReceive(msg *channel.Message, ch channel.Channel)
9797
return
9898
}
9999

100-
sinks := mux.getSinks()
101-
if sink, found := sinks[connId]; found {
100+
if sink, found := mux.sinks.Get(connId); found {
102101
sink.Accept(msg)
103102
} else if msg.ContentType == ContentTypeConnInspectRequest {
104-
pfxlog.Logger().WithField("connId", connId).Trace("no conn found for connection inspect")
103+
pfxlog.Logger().WithField("connId", int(connId)).Trace("no conn found for connection inspect")
105104
resp := NewConnInspectResponse(connId, ConnTypeInvalid, fmt.Sprintf("invalid conn id [%v]", connId))
106105
if err := resp.ReplyTo(msg).Send(ch); err != nil {
107106
logrus.WithFields(GetLoggerFields(msg)).WithError(err).
108107
Error("failed to send inspect response")
109108
}
109+
} else if msg.ContentType == ContentTypeXgPayload {
110+
mux.handlePayloadWithNoSink(msg, ch)
111+
} else if msg.ContentType == ContentTypeStateClosed {
112+
// ignore, as conn is already closed
110113
} else {
111-
pfxlog.Logger().Debugf("unable to dispatch msg received for unknown edge conn id: %v", connId)
114+
pfxlog.Logger().WithField("connId", connId).WithField("contentType", msg.ContentType).
115+
Debug("unable to dispatch msg received for unknown edge conn id")
112116
}
113117
}
114118

115-
func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel) {
119+
func (mux *MsgMuxImpl) handlePayloadWithNoSink(msg *channel.Message, ch channel.Channel) {
120+
connId, _ := msg.GetUint32Header(ConnIdHeader)
121+
payload, err := xgress.UnmarshallPayload(msg)
122+
if err == nil {
123+
if payload.IsCircuitEndFlagSet() && len(payload.Data) == 0 {
124+
ack := xgress.NewAcknowledgement(payload.CircuitId, payload.GetOriginator().Invert())
125+
ackMsg := ack.Marshall()
126+
ackMsg.PutUint32Header(ConnIdHeader, connId)
127+
_, _ = ch.TrySend(msg)
128+
} else {
129+
pfxlog.Logger().WithField("connId", int(connId)).WithField("circuitId", payload.CircuitId).
130+
Debug("unable to dispatch xg payload received for unknown edge conn id")
131+
}
132+
} else {
133+
pfxlog.Logger().WithError(err).WithField("connId", int(connId)).
134+
Debug("unable to dispatch xg payload received for unknown edge conn id")
135+
}
136+
}
137+
138+
func (mux *MsgMuxImpl) HandleInspect(msg *channel.Message, ch channel.Channel) {
116139
resp := &inspect.SdkInspectResponse{
117140
Success: true,
118141
Values: make(map[string]any),
@@ -132,7 +155,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
132155
Circuits: make(map[string]*xgress.CircuitDetail),
133156
}
134157

135-
for _, sink := range mux.getSinks() {
158+
for _, sink := range mux.sinks.Items() {
136159
if circuitInfoSrc, ok := sink.(interface {
137160
GetCircuitDetail() *xgress.CircuitDetail
138161
}); ok {
@@ -149,7 +172,7 @@ func (mux *CowMapMsgMux) HandleInspect(msg *channel.Message, ch channel.Channel)
149172
mux.returnInspectResponse(msg, ch, resp)
150173
}
151174

152-
func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) {
175+
func (mux *MsgMuxImpl) returnInspectResponse(msg *channel.Message, ch channel.Channel, resp *inspect.SdkInspectResponse) {
153176
var sender channel.Sender = ch
154177
if mc, ok := ch.(channel.MultiChannel); ok {
155178
if sdkChan, ok := mc.GetUnderlayHandler().(SdkChannel); ok {
@@ -169,61 +192,35 @@ func (mux *CowMapMsgMux) returnInspectResponse(msg *channel.Message, ch channel.
169192
}
170193
}
171194

172-
func (mux *CowMapMsgMux) HandleClose(channel.Channel) {
195+
func (mux *MsgMuxImpl) HandleClose(channel.Channel) {
173196
mux.Close()
174197
}
175198

176-
func (mux *CowMapMsgMux) AddMsgSink(sink MsgSink) error {
199+
func (mux *MsgMuxImpl) AddMsgSink(sink MsgSink) error {
177200
if mux.closed.Load() {
178201
return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id())
179202
}
180203

181-
var err error
182-
mux.updateSinkMap(func(m map[uint32]MsgSink) {
183-
if _, found := m[sink.Id()]; found {
184-
err = errors.Errorf("sink id %v already in use", sink.Id())
185-
} else {
186-
m[sink.Id()] = sink
187-
}
188-
})
189-
190-
// check again, just in case it was closed while we were adding
191-
if mux.closed.Load() {
192-
return errors.Errorf("mux is closed, can't add sink with id [%v]", sink.Id())
204+
if !mux.sinks.SetIfAbsent(sink.Id(), sink) {
205+
return errors.Errorf("sink id %v already in use", sink.Id())
193206
}
194-
195-
return err
207+
return nil
196208
}
197209

198-
func (mux *CowMapMsgMux) RemoveMsgSink(sink MsgSink) {
210+
func (mux *MsgMuxImpl) RemoveMsgSink(sink MsgSink) {
199211
mux.RemoveMsgSinkById(sink.Id())
200212
}
201213

202-
func (mux *CowMapMsgMux) RemoveMsgSinkById(sinkId uint32) {
203-
mux.updateSinkMap(func(m map[uint32]MsgSink) {
204-
delete(m, sinkId)
205-
})
214+
func (mux *MsgMuxImpl) RemoveMsgSinkById(sinkId uint32) {
215+
mux.sinks.Remove(sinkId)
206216
}
207217

208-
func (mux *CowMapMsgMux) updateSinkMap(f func(map[uint32]MsgSink)) {
209-
mux.Lock()
210-
defer mux.Unlock()
211-
212-
current := mux.getSinks()
213-
result := map[uint32]MsgSink{}
214-
for k, v := range current {
215-
result[k] = v
216-
}
217-
f(result)
218-
mux.sinks.Store(result)
219-
}
220-
221-
func (mux *CowMapMsgMux) Close() {
218+
func (mux *MsgMuxImpl) Close() {
222219
if mux.closed.CompareAndSwap(false, true) {
223220
// we don't need to lock the mux because due to the atomic bool, only one go-routine will enter this.
224221
// If the sink HandleMuxClose methods do anything with the mux, like remove themselves, they will acquire
225222
// their own locks
226-
sinks := mux.getSinks()
223+
sinks := mux.sinks.Items()
227224
for _, val := range sinks {
228225
if err := val.HandleMuxClose(); err != nil {
229226
pfxlog.Logger().
@@ -234,7 +231,3 @@ func (mux *CowMapMsgMux) Close() {
234231
}
235232
}
236233
}
237-
238-
func (mux *CowMapMsgMux) getSinks() map[uint32]MsgSink {
239-
return mux.sinks.Load().(map[uint32]MsgSink)
240-
}

ziti/edge/network/conn_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func BenchmarkConnWrite(b *testing.B) {
3232
closeNotify := make(chan struct{})
3333
defer close(closeNotify)
3434

35-
mux := edge.NewCowMapMsgMux()
35+
mux := edge.NewMapMsgMux()
3636
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})
3737
conn := &edgeConn{
3838
MsgChannel: *edge.NewEdgeMsgChannel(testChannel, 1),
@@ -58,7 +58,7 @@ func BenchmarkConnRead(b *testing.B) {
5858
closeNotify := make(chan struct{})
5959
defer close(closeNotify)
6060

61-
mux := edge.NewCowMapMsgMux()
61+
mux := edge.NewMapMsgMux()
6262
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})
6363

6464
readQ := NewNoopSequencer[*channel.Message](closeNotify, 4)
@@ -135,7 +135,7 @@ func TestReadMultipart(t *testing.T) {
135135
closeNotify := make(chan struct{})
136136
defer close(closeNotify)
137137

138-
mux := edge.NewCowMapMsgMux()
138+
mux := edge.NewMapMsgMux()
139139
testChannel := edge.NewSingleSdkChannel(&NoopTestChannel{})
140140

141141
readQ := NewNoopSequencer[*channel.Message](closeNotify, 4)

ziti/edge/network/factory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func NewEdgeConnFactory(routerName, key string, owner RouterConnOwner) edge.Rout
6565
connFactory := &routerConn{
6666
key: key,
6767
routerName: routerName,
68-
msgMux: edge.NewCowMapMsgMux(),
68+
msgMux: edge.NewMapMsgMux(),
6969
owner: owner,
7070
}
7171

0 commit comments

Comments
 (0)