Skip to content

Commit 35bfe03

Browse files
feat: avoid error when subscription ID is not present (#387)
This PR removes the error triggered when a websocket message does not contain a subscription ID. Check this issue https://github.com/issues?q=sort%3Aupdated-desc+is%3Aissue+is%3Aopen+author%3A%40me+archived%3Afalse&issue=Khan%7Cgenqlient%7C383 I have: - [x] Written a clear PR title and description (above) - [x] Signed the [Khan Academy CLA](https://www.khanacademy.org/r/cla) - [x] Added tests covering my changes, if applicable - [x] Included a link to the issue fixed, if applicable - [ ] Included documentation, for new features - [x] Added an entry to the changelog --------- Co-authored-by: Ben Kraft <[email protected]>
1 parent b1fe42c commit 35bfe03

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

docs/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ This release fixes a bug introduced in v0.8.0 breaking path resolution on Window
4141
- fixed documentation link in `introduction.md`
4242
- upgraded version of alexflint/go-arg from 1.4.2 to 1.5.1
4343
- fixed a typo in the struct + fragment error message
44+
- avoid error when a subscription message is received without a subscription ID
4445
- avoid closing subscription channels more than once, which could cause a panic in some cases
4546

4647
## v0.8.0

graphql/websocket.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ func (w *webSocketClient) forwardWebSocketData(message []byte) error {
136136
if err != nil {
137137
return err
138138
}
139+
if wsMsg.ID == "" { // e.g. keep-alive messages
140+
return nil
141+
}
139142
sub, ok := w.subscriptions.Read(wsMsg.ID)
140143
if !ok {
141144
return fmt.Errorf("received message for unknown subscription ID '%s'", wsMsg.ID)

graphql/websocket_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package graphql
2+
3+
import (
4+
"encoding/json"
5+
"sync"
6+
"testing"
7+
)
8+
9+
const testSubscriptionID = "test-subscription-id"
10+
11+
func forgeTestWebSocketClient(hasBeenUnsubscribed bool) *webSocketClient {
12+
return &webSocketClient{
13+
subscriptions: subscriptionMap{
14+
RWMutex: sync.RWMutex{},
15+
map_: map[string]subscription{
16+
testSubscriptionID: {
17+
hasBeenUnsubscribed: hasBeenUnsubscribed,
18+
interfaceChan: make(chan any),
19+
forwardDataFunc: func(interfaceChan any, jsonRawMsg json.RawMessage) error {
20+
return nil
21+
},
22+
},
23+
},
24+
},
25+
}
26+
}
27+
28+
func Test_webSocketClient_forwardWebSocketData(t *testing.T) {
29+
type args struct {
30+
message []byte
31+
}
32+
tests := []struct {
33+
wc *webSocketClient
34+
name string
35+
args args
36+
wantErr bool
37+
}{
38+
{
39+
name: "empty message",
40+
args: args{message: []byte{}},
41+
wc: forgeTestWebSocketClient(false),
42+
wantErr: true,
43+
},
44+
{
45+
name: "nil message",
46+
args: args{message: nil},
47+
wc: forgeTestWebSocketClient(false),
48+
wantErr: true,
49+
},
50+
{
51+
name: "unknown subscription id",
52+
args: args{message: []byte(`{"type":"next","id":"unknown-id","payload":{}}`)},
53+
wc: forgeTestWebSocketClient(false),
54+
wantErr: true,
55+
},
56+
{
57+
name: "void subscription ID",
58+
args: args{message: []byte(`{"type":"next","id":"","payload":{}}`)},
59+
wc: forgeTestWebSocketClient(false),
60+
wantErr: false,
61+
},
62+
{
63+
name: "unsubscribed subscription",
64+
args: args{message: []byte(`{"type":"next","id":"test-subscription-id","payload":{}}`)},
65+
wc: forgeTestWebSocketClient(true),
66+
wantErr: false,
67+
},
68+
{
69+
name: "complete message closes channel",
70+
args: args{message: []byte(`{"type":"complete","id":"test-subscription-id","payload":{}}`)},
71+
wc: forgeTestWebSocketClient(false),
72+
wantErr: false,
73+
},
74+
{
75+
name: "valid next message",
76+
args: args{message: []byte(`{"type":"next","id":"test-subscription-id","payload":{"foo":"bar"}}`)},
77+
wc: forgeTestWebSocketClient(false),
78+
wantErr: false,
79+
},
80+
}
81+
for i := range tests {
82+
tt := &tests[i]
83+
t.Run(tt.name, func(t *testing.T) {
84+
t.Logf("Running test: %s", tt.name)
85+
86+
if err := tt.wc.forwardWebSocketData(tt.args.message); (err != nil) != tt.wantErr {
87+
t.Errorf("%s: webSocketClient.forwardWebSocketData() error = %v, wantErr %v", tt.name, err, tt.wantErr)
88+
return
89+
}
90+
})
91+
}
92+
}

0 commit comments

Comments
 (0)