Skip to content

Commit 46fe7ab

Browse files
authored
Merge pull request #84 from gojek/feature/otel-add-client-id
feat: add client id attribute on otel metrics
2 parents 8b1755c + ff96567 commit 46fe7ab

File tree

13 files changed

+293
-81
lines changed

13 files changed

+293
-81
lines changed

client.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func NewClient(opts ...ClientOption) (*Client, error) {
9696
func (c *Client) IsConnected() bool {
9797
val := &atomic.Bool{}
9898

99-
return c.execute(func(cc mqtt.Client) error {
99+
return c.execute(context.Background(), func(cc mqtt.Client) error {
100100
if cc.IsConnectionOpen() {
101101
val.CompareAndSwap(false, true)
102102
}
@@ -136,7 +136,7 @@ func (c *Client) Run(ctx context.Context) error {
136136
}
137137

138138
func (c *Client) stop() error {
139-
err := c.execute(func(cc mqtt.Client) error {
139+
err := c.execute(context.Background(), func(cc mqtt.Client) error {
140140
cc.Disconnect(uint(c.options.gracefulShutdownPeriod / time.Millisecond))
141141

142142
return nil
@@ -215,7 +215,7 @@ func (c *Client) runResolver() error {
215215
}
216216

217217
func (c *Client) runConnect() error {
218-
err := c.execute(func(cc mqtt.Client) error {
218+
err := c.execute(context.Background(), func(cc mqtt.Client) error {
219219
t := cc.Connect()
220220
if !t.WaitTimeout(c.options.connectTimeout) {
221221
return ErrConnectTimeout

client_publish.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func publishHandler(c *Client) Publisher {
3737

3838
o := composeOptions(opts)
3939

40-
return c.execute(func(cc mqtt.Client) error {
40+
return c.execute(ctx, func(cc mqtt.Client) error {
4141
return c.handleToken(ctx, cc.Publish(topic, o.qos, o.retained, buf.Bytes()), ErrPublishTimeout)
4242
}, execOneRoundRobin)
4343
})

client_subscribe.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func subscriberFuncs(c *Client) Subscriber {
8181
eo = subscribeOnlyOnce(topic)
8282
}
8383

84-
return c.execute(func(cc mqtt.Client) error {
84+
return c.execute(ctx, func(cc mqtt.Client) error {
8585
return c.handleToken(ctx, cc.Subscribe(topic, o.qos, callbackWrapper(c, callback)), ErrSubscribeTimeout)
8686
}, eo)
8787
},
@@ -92,7 +92,7 @@ func subscriberFuncs(c *Client) Subscriber {
9292

9393
if len(sharedSubs) > 0 {
9494
execs = append(execs, func(ctx context.Context) error {
95-
return c.execute(func(cc mqtt.Client) error {
95+
return c.execute(ctx, func(cc mqtt.Client) error {
9696
return c.handleToken(ctx, cc.SubscribeMultiple(
9797
sharedSubs,
9898
callbackWrapper(c, callback),
@@ -103,7 +103,7 @@ func subscriberFuncs(c *Client) Subscriber {
103103

104104
if len(normalSubs) > 0 {
105105
execs = append(execs, func(ctx context.Context) error {
106-
return c.execute(func(cc mqtt.Client) error {
106+
return c.execute(ctx, func(cc mqtt.Client) error {
107107
return c.handleToken(ctx, cc.SubscribeMultiple(
108108
normalSubs,
109109
callbackWrapper(c, callback),
@@ -120,8 +120,9 @@ func subscriberFuncs(c *Client) Subscriber {
120120
}
121121

122122
func callbackWrapper(c *Client, callback MessageHandler) mqtt.MessageHandler {
123-
return func(_ mqtt.Client, m mqtt.Message) {
123+
return func(cc mqtt.Client, m mqtt.Message) {
124124
ctx := context.Background()
125+
ctx = withClientID(ctx, clientIDMapper(cc))
125126

126127
msg := NewMessageWithDecoder(
127128
c.options.newDecoder(ctx, bytes.NewReader(m.Payload())),

client_telemetry.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package courier
22

33
import (
4+
"context"
45
"net/url"
56
"sort"
67
"strconv"
@@ -96,6 +97,7 @@ func (c *Client) multiClientInfo() []MQTTClientInfo {
9697
bCh := make(chan MQTTClientInfo, len(cls))
9798

9899
_ = c.execute(
100+
context.Background(),
99101
func(cc mqtt.Client) error { return nil },
100102
execOptWithState(func(f func(mqtt.Client) error, is *internalState) error {
101103
is.mu.Lock()
@@ -133,7 +135,7 @@ func (c *Client) singleClientInfo() []MQTTClientInfo {
133135

134136
var bi MQTTClientInfo
135137

136-
_ = c.execute(func(cc mqtt.Client) error {
138+
_ = c.execute(context.Background(), func(cc mqtt.Client) error {
137139
bi = transformClientInfo(cc)
138140

139141
return nil

client_unsubscribe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (c *Client) UseUnsubscriberMiddleware(mwf ...UnsubscriberMiddlewareFunc) {
3838

3939
func unsubscriberHandler(c *Client) Unsubscriber {
4040
return UnsubscriberFunc(func(ctx context.Context, topics ...string) error {
41-
return c.execute(func(cc mqtt.Client) error {
41+
return c.execute(ctx, func(cc mqtt.Client) error {
4242
return c.handleToken(ctx, cc.Unsubscribe(topics...), ErrUnsubscribeTimeout)
4343
}, removeSubsFromState(topics...))
4444
})

context.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package courier
2+
3+
import "context"
4+
5+
type clientIDCallbackKey struct{}
6+
type clientIDSubscribedKey struct{}
7+
8+
// ClientIDCallback is a function that receives the client ID used for an operation.
9+
type ClientIDCallback func(clientID string)
10+
11+
// WithClientIDCallback returns a context with the callback that will receive
12+
// the client ID of the underlying MQTT client used for an operation.
13+
func WithClientIDCallback(ctx context.Context, cb ClientIDCallback) context.Context {
14+
return context.WithValue(ctx, clientIDCallbackKey{}, cb)
15+
}
16+
17+
func invokeClientIDCallback(ctx context.Context, clientID string) {
18+
if cb, ok := ctx.Value(clientIDCallbackKey{}).(ClientIDCallback); ok && cb != nil {
19+
cb(clientID)
20+
}
21+
}
22+
23+
func withClientID(ctx context.Context, clientID string) context.Context {
24+
return context.WithValue(ctx, clientIDSubscribedKey{}, clientID)
25+
}
26+
27+
// ClientIDFromContext returns the client ID from the context.
28+
// This is available in subscribe callbacks to identify which MQTT connection received the message.
29+
func ClientIDFromContext(ctx context.Context) string {
30+
if id, ok := ctx.Value(clientIDSubscribedKey{}).(string); ok {
31+
return id
32+
}
33+
34+
return ""
35+
}

context_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package courier
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestWithClientIDCallback(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
invokeID string
14+
wantID string
15+
callback bool
16+
}{
17+
{
18+
name: "Invokes Callback",
19+
invokeID: "test-client",
20+
wantID: "test-client",
21+
callback: true,
22+
},
23+
{
24+
name: "Returns Empty String on nil Callback fn",
25+
invokeID: "",
26+
wantID: "",
27+
callback: false,
28+
},
29+
}
30+
31+
for _, tt := range tests {
32+
t.Run(tt.name, func(t *testing.T) {
33+
var capturedID string
34+
var ctx context.Context
35+
if tt.callback {
36+
ctx = WithClientIDCallback(context.Background(), func(clientID string) {
37+
capturedID = clientID
38+
})
39+
} else {
40+
ctx = WithClientIDCallback(context.Background(), nil)
41+
}
42+
43+
invokeClientIDCallback(ctx, tt.invokeID)
44+
assert.Equal(t, tt.wantID, capturedID)
45+
})
46+
}
47+
}
48+
49+
func TestClientIDFromContext(t *testing.T) {
50+
tests := []struct {
51+
name string
52+
ctx context.Context
53+
wantID string
54+
}{
55+
{
56+
name: "Returns Client ID",
57+
ctx: withClientID(context.Background(), "test-subscribe"),
58+
wantID: "test-subscribe",
59+
},
60+
{
61+
name: "Returns Empty String on Missing Client ID",
62+
ctx: context.Background(),
63+
wantID: "",
64+
},
65+
{
66+
name: "Returns Empty String on Nil Client ID",
67+
ctx: context.WithValue(context.Background(), clientIDSubscribedKey{}, nil),
68+
wantID: "",
69+
},
70+
}
71+
72+
for _, tt := range tests {
73+
t.Run(tt.name, func(t *testing.T) {
74+
assert.Equal(t, tt.wantID, ClientIDFromContext(tt.ctx))
75+
})
76+
}
77+
}

docs/docs/sdk/SDK.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ Package courier contains the client that can be used to interact with the courie
1111
## Index
1212

1313
- [Variables](#variables)
14+
- [func ClientIDFromContext\(ctx context.Context\) string](#ClientIDFromContext)
1415
- [func ExponentialStartStrategy\(ctx context.Context, c interface\{ Start\(\) error \}, opts ...StartOption\)](#ExponentialStartStrategy)
1516
- [func Version\(\) string](#Version)
1617
- [func WaitForConnection\(c ConnectionInformer, waitFor time.Duration, tick time.Duration\) bool](#WaitForConnection)
18+
- [func WithClientIDCallback\(ctx context.Context, cb ClientIDCallback\) context.Context](#WithClientIDCallback)
1719
- [type Client](#Client)
1820
- [func NewClient\(opts ...ClientOption\) \(\*Client, error\)](#NewClient)
1921
- [func \(c \*Client\) AckTimeout\(\) time.Duration](#Client.AckTimeout)
@@ -34,6 +36,7 @@ Package courier contains the client that can be used to interact with the courie
3436
- [func \(c \*Client\) UseSubscriberMiddleware\(mwf ...SubscriberMiddlewareFunc\)](#Client.UseSubscriberMiddleware)
3537
- [func \(c \*Client\) UseUnsubscriberMiddleware\(mwf ...UnsubscriberMiddlewareFunc\)](#Client.UseUnsubscriberMiddleware)
3638
- [func \(c \*Client\) WriteTimeout\(\) time.Duration](#Client.WriteTimeout)
39+
- [type ClientIDCallback](#ClientIDCallback)
3740
- [type ClientInfoEmitter](#ClientInfoEmitter)
3841
- [type ClientInfoEmitterConfig](#ClientInfoEmitterConfig)
3942
- [type ClientMeta](#ClientMeta)
@@ -164,6 +167,15 @@ This is useful when working with shared subscriptions and multiple connections c
164167
var UseMultiConnectionMode = multiConnMode{}
165168
```
166169

170+
<a name="ClientIDFromContext"></a>
171+
## func [ClientIDFromContext](https://github.com/gojek/courier-go/blob/main/context.go#L29)
172+
173+
```go
174+
func ClientIDFromContext(ctx context.Context) string
175+
```
176+
177+
ClientIDFromContext returns the client ID from the context. This is available in subscribe callbacks to identify which MQTT connection received the message.
178+
167179
<a name="ExponentialStartStrategy"></a>
168180
## func [ExponentialStartStrategy](https://github.com/gojek/courier-go/blob/main/exp_starter.go#L32)
169181

@@ -191,6 +203,15 @@ func WaitForConnection(c ConnectionInformer, waitFor time.Duration, tick time.Du
191203

192204
WaitForConnection checks if the Client is connected, it calls ConnectionInformer.IsConnected after every tick and waitFor is the maximum duration it can block. Returns true only when ConnectionInformer.IsConnected returns true
193205

206+
<a name="WithClientIDCallback"></a>
207+
## func [WithClientIDCallback](https://github.com/gojek/courier-go/blob/main/context.go#L13)
208+
209+
```go
210+
func WithClientIDCallback(ctx context.Context, cb ClientIDCallback) context.Context
211+
```
212+
213+
WithClientIDCallback returns a context with the callback that will receive the client ID of the underlying MQTT client used for an operation.
214+
194215
<a name="Client"></a>
195216
## type [Client](https://github.com/gojek/courier-go/blob/main/client.go#L22-L46)
196217

@@ -423,6 +444,15 @@ func (c *Client) WriteTimeout() time.Duration
423444

424445
WriteTimeout returns the write timeout duration configured for the client
425446

447+
<a name="ClientIDCallback"></a>
448+
## type [ClientIDCallback](https://github.com/gojek/courier-go/blob/main/context.go#L9)
449+
450+
ClientIDCallback is a function that receives the client ID used for an operation.
451+
452+
```go
453+
type ClientIDCallback func(clientID string)
454+
```
455+
426456
<a name="ClientInfoEmitter"></a>
427457
## type [ClientInfoEmitter](https://github.com/gojek/courier-go/blob/main/metrics.go#L17-L19)
428458

@@ -880,7 +910,7 @@ type Logger interface {
880910
```
881911

882912
<a name="MQTTClientInfo"></a>
883-
## type [MQTTClientInfo](https://github.com/gojek/courier-go/blob/main/client_telemetry.go#L14-L25)
913+
## type [MQTTClientInfo](https://github.com/gojek/courier-go/blob/main/client_telemetry.go#L15-L26)
884914

885915
MQTTClientInfo contains information about the internal MQTT client
886916

exec.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package courier
22

33
import (
4+
"context"
45
"errors"
56
"math/rand"
67
"sync"
@@ -48,7 +49,7 @@ func (ac *atomicCounter) next() uint64 {
4849
return current
4950
}
5051

51-
func (c *Client) execute(f func(mqtt.Client) error, eo execOpt) error {
52+
func (c *Client) execute(ctx context.Context, f func(mqtt.Client) error, eo execOpt) error {
5253
c.clientMu.RLock()
5354
defer c.clientMu.RUnlock()
5455

@@ -57,13 +58,15 @@ func (c *Client) execute(f func(mqtt.Client) error, eo execOpt) error {
5758
}
5859

5960
if c.options.poolEnabled || c.options.multiConnectionMode {
60-
return c.execMultiConn(f, eo)
61+
return c.execMultiConn(ctx, f, eo)
6162
}
6263

64+
invokeClientIDCallback(ctx, clientIDMapper(c.mqttClient))
65+
6366
return f(c.mqttClient)
6467
}
6568

66-
func (c *Client) execMultiConn(f func(mqtt.Client) error, eo execOpt) error {
69+
func (c *Client) execMultiConn(ctx context.Context, f func(mqtt.Client) error, eo execOpt) error {
6770
var ccs []*internalState
6871

6972
if eo == execOneRoundRobin {
@@ -79,11 +82,17 @@ func (c *Client) execMultiConn(f func(mqtt.Client) error, eo execOpt) error {
7982
p := c.rndPool.Get().(*rand.Rand)
8083
defer c.rndPool.Put(p)
8184

82-
return f(ccs[p.Intn(len(ccs))].client)
85+
cc := ccs[p.Intn(len(ccs))].client
86+
invokeClientIDCallback(ctx, clientIDMapper(cc))
87+
88+
return f(cc)
8389
}
8490

8591
if eo == execOneRoundRobin {
86-
return f(ccs[int(c.rrCounter.next())%len(ccs)].client)
92+
cc := ccs[int(c.rrCounter.next())%len(ccs)].client
93+
invokeClientIDCallback(ctx, clientIDMapper(cc))
94+
95+
return f(cc)
8796
}
8897

8998
return slice.Reduce(slice.MapConcurrent(ccs, func(s *internalState) error {

exec_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package courier
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
@@ -39,7 +40,7 @@ type invalidExecOpt bool
3940
func (ieo invalidExecOpt) isExecOpt() {}
4041

4142
func TestClient_execMultiConn_invalidExecOption(t *testing.T) {
42-
assert.EqualError(t, new(Client).execMultiConn(func(mqtt.Client) error {
43+
assert.EqualError(t, new(Client).execMultiConn(context.Background(), func(mqtt.Client) error {
4344
return nil
4445
}, invalidExecOpt(true)), errInvalidExecOpt.Error())
4546
}

0 commit comments

Comments
 (0)