Skip to content

Commit b58f6c3

Browse files
authored
Merge pull request #13 from functionalfoundry/jannis/auth-token-and-user
Add support for auth tokens and users associated with connections
2 parents 246b6af + 73fa495 commit b58f6c3

File tree

3 files changed

+106
-47
lines changed

3 files changed

+106
-47
lines changed

connections.go

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package graphqlws
33
import (
44
"encoding/json"
55
"errors"
6+
"fmt"
67
"time"
78

89
"github.com/google/uuid"
@@ -30,6 +31,12 @@ const (
3031
writeTimeout = 10 * time.Second
3132
)
3233

34+
// InitMessagePayload defines the parameters of a connection
35+
// init message.
36+
type InitMessagePayload struct {
37+
AuthToken string `json:"authToken"`
38+
}
39+
3340
// StartMessagePayload defines the parameters of an operation that
3441
// a client requests to be started.
3542
type StartMessagePayload struct {
@@ -59,6 +66,10 @@ func (msg OperationMessage) String() string {
5966
return "<invalid>"
6067
}
6168

69+
// UserFromAuthTokenFunc is a function that resolves an auth token
70+
// into a user (or returns an error if that isn't possible).
71+
type UserFromAuthTokenFunc func(token string) (interface{}, error)
72+
6273
// ConnectionEventHandlers define the event handlers for a connection.
6374
// Event handlers allow other system components to react to events such
6475
// as the connection closing or an operation being started or stopped.
@@ -81,12 +92,22 @@ type ConnectionEventHandlers struct {
8192
StopOperation func(Connection, string)
8293
}
8394

95+
// ConnectionConfig defines the configuration parameters of a
96+
// GraphQL WebSocket connection.
97+
type ConnectionConfig struct {
98+
UserFromAuthToken UserFromAuthTokenFunc
99+
EventHandlers ConnectionEventHandlers
100+
}
101+
84102
// Connection is an interface to represent GraphQL WebSocket connections.
85103
// Each connection is associated with an ID that is unique to the server.
86104
type Connection interface {
87105
// ID returns the unique ID of the connection.
88106
ID() string
89107

108+
// User returns the user associated with the connection (or nil).
109+
User() interface{}
110+
90111
// SendData sends results of executing an operation (typically a
91112
// subscription) to the client.
92113
SendData(string, *DataMessagePayload)
@@ -100,11 +121,12 @@ type Connection interface {
100121
*/
101122

102123
type connection struct {
103-
id string
104-
ws *websocket.Conn
105-
eventHandlers *ConnectionEventHandlers
106-
logger *log.Entry
107-
outgoing chan OperationMessage
124+
id string
125+
ws *websocket.Conn
126+
config ConnectionConfig
127+
logger *log.Entry
128+
outgoing chan OperationMessage
129+
user interface{}
108130
}
109131

110132
func operationMessageForType(messageType string) OperationMessage {
@@ -116,11 +138,11 @@ func operationMessageForType(messageType string) OperationMessage {
116138
// NewConnection establishes a GraphQL WebSocket connection. It implements
117139
// the GraphQL WebSocket protocol by managing its internal state and handling
118140
// the client-server communication.
119-
func NewConnection(ws *websocket.Conn, eventHandlers *ConnectionEventHandlers) Connection {
141+
func NewConnection(ws *websocket.Conn, config ConnectionConfig) Connection {
120142
conn := new(connection)
121143
conn.id = uuid.New().String()
122144
conn.ws = ws
123-
conn.eventHandlers = eventHandlers
145+
conn.config = config
124146
conn.logger = NewLogger("connection/" + conn.id)
125147

126148
conn.outgoing = make(chan OperationMessage)
@@ -137,6 +159,10 @@ func (conn *connection) ID() string {
137159
return conn.id
138160
}
139161

162+
func (conn *connection) User() interface{} {
163+
return conn.user
164+
}
165+
140166
func (conn *connection) SendData(opID string, data *DataMessagePayload) {
141167
msg := operationMessageForType(gqlData)
142168
msg.ID = opID
@@ -162,8 +188,8 @@ func (conn *connection) close() {
162188
close(conn.outgoing)
163189

164190
// Notify event handlers
165-
if conn.eventHandlers != nil {
166-
conn.eventHandlers.Close(conn)
191+
if conn.config.EventHandlers.Close != nil {
192+
conn.config.EventHandlers.Close(conn)
167193
}
168194

169195
conn.logger.Info("Closed connection")
@@ -238,16 +264,31 @@ func (conn *connection) readLoop() {
238264

239265
// When the GraphQL WS connection is initiated, send an ACK back
240266
case gqlConnectionInit:
241-
conn.outgoing <- operationMessageForType(gqlConnectionAck)
267+
data := InitMessagePayload{}
268+
if err := json.Unmarshal(rawPayload, &data); err != nil {
269+
conn.SendError(errors.New("Invalid GQL_CONNECTION_INIT payload"))
270+
} else {
271+
if conn.config.UserFromAuthToken != nil {
272+
user, err := conn.config.UserFromAuthToken(data.AuthToken)
273+
if err != nil {
274+
conn.SendError(fmt.Errorf("Failed to authenticate user: %v", err))
275+
} else {
276+
conn.user = user
277+
conn.outgoing <- operationMessageForType(gqlConnectionAck)
278+
}
279+
} else {
280+
conn.outgoing <- operationMessageForType(gqlConnectionAck)
281+
}
282+
}
242283

243284
// Let event handlers deal with starting operations
244285
case gqlStart:
245-
if conn.eventHandlers != nil {
286+
if conn.config.EventHandlers.StartOperation != nil {
246287
data := StartMessagePayload{}
247288
if err := json.Unmarshal(rawPayload, &data); err != nil {
248289
conn.SendError(errors.New("Invalid GQL_START payload"))
249290
} else {
250-
errs := conn.eventHandlers.StartOperation(conn, msg.ID, &data)
291+
errs := conn.config.EventHandlers.StartOperation(conn, msg.ID, &data)
251292
if errs != nil {
252293
conn.sendOperationErrors(msg.ID, errs)
253294
}
@@ -256,8 +297,8 @@ func (conn *connection) readLoop() {
256297

257298
// Let event handlers deal with stopping operations
258299
case gqlStop:
259-
if conn.eventHandlers != nil {
260-
conn.eventHandlers.StopOperation(conn, msg.ID)
300+
if conn.config.EventHandlers.StopOperation != nil {
301+
conn.config.EventHandlers.StopOperation(conn, msg.ID)
261302
}
262303

263304
// When the GraphQL WS connection is terminated by the client,

examples/simple-server/main.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ func main() {
3131

3232
// Create subscription manager and GraphQL WS handler
3333
subscriptionManager := graphqlws.NewSubscriptionManager(&schema)
34-
websocketHandler := graphqlws.NewHandler(subscriptionManager)
34+
websocketHandler := graphqlws.NewHandler(graphqlws.HandlerConfig{
35+
SubscriptionManager: subscriptionManager,
36+
UserFromAuthToken: func(token string) (interface{}, error) {
37+
return "Default user", nil
38+
},
39+
})
3540

3641
// Serve the GraphQL WS endpoint
3742
http.Handle("/subscriptions", websocketHandler)

handler.go

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,16 @@ import (
77
log "github.com/sirupsen/logrus"
88
)
99

10+
// HandlerConfig stores the configuration of a GraphQL WebSocket handler.
11+
type HandlerConfig struct {
12+
SubscriptionManager SubscriptionManager
13+
UserFromAuthToken UserFromAuthTokenFunc
14+
}
15+
1016
// NewHandler creates a WebSocket handler for GraphQL WebSocket connections.
1117
// This handler takes a SubscriptionManager and adds/removes subscriptions
1218
// as they are started/stopped by the client.
13-
func NewHandler(subscriptionManager SubscriptionManager) http.Handler {
19+
func NewHandler(config HandlerConfig) http.Handler {
1420
// Create a WebSocket upgrader that requires clients to implement
1521
// the "graphql-ws" protocol
1622
var upgrader = websocket.Upgrader{
@@ -19,6 +25,7 @@ func NewHandler(subscriptionManager SubscriptionManager) http.Handler {
1925
}
2026

2127
logger := NewLogger("handler")
28+
subscriptionManager := config.SubscriptionManager
2229

2330
// Create a map (used like a set) to manage client connections
2431
var connections = make(map[Connection]bool)
@@ -42,40 +49,46 @@ func NewHandler(subscriptionManager SubscriptionManager) http.Handler {
4249
}
4350

4451
// Establish a GraphQL WebSocket connection
45-
conn := NewConnection(ws, &ConnectionEventHandlers{
46-
Close: func(conn Connection) {
47-
logger.WithFields(log.Fields{
48-
"conn": conn.ID(),
49-
}).Debug("Closing connection")
52+
conn := NewConnection(ws, ConnectionConfig{
53+
UserFromAuthToken: config.UserFromAuthToken,
54+
EventHandlers: ConnectionEventHandlers{
55+
Close: func(conn Connection) {
56+
logger.WithFields(log.Fields{
57+
"conn": conn.ID(),
58+
"user": conn.User(),
59+
}).Debug("Closing connection")
5060

51-
subscriptionManager.RemoveSubscriptions(conn)
61+
subscriptionManager.RemoveSubscriptions(conn)
5262

53-
delete(connections, conn)
54-
},
55-
StartOperation: func(
56-
conn Connection,
57-
opID string,
58-
data *StartMessagePayload,
59-
) []error {
60-
logger.WithFields(log.Fields{
61-
"conn": conn.ID(),
62-
"op": opID,
63-
}).Debug("Start operation")
63+
delete(connections, conn)
64+
},
65+
StartOperation: func(
66+
conn Connection,
67+
opID string,
68+
data *StartMessagePayload,
69+
) []error {
70+
logger.WithFields(log.Fields{
71+
"conn": conn.ID(),
72+
"op": opID,
73+
"user": conn.User(),
74+
}).Debug("Start operation")
6475

65-
return subscriptionManager.AddSubscription(conn, &Subscription{
66-
ID: opID,
67-
Query: data.Query,
68-
Variables: data.Variables,
69-
OperationName: data.OperationName,
70-
SendData: func(subscription *Subscription, data *DataMessagePayload) {
71-
conn.SendData(opID, data)
72-
},
73-
})
74-
},
75-
StopOperation: func(conn Connection, opID string) {
76-
subscriptionManager.RemoveSubscription(conn, &Subscription{
77-
ID: opID,
78-
})
76+
return subscriptionManager.AddSubscription(conn, &Subscription{
77+
ID: opID,
78+
Query: data.Query,
79+
Variables: data.Variables,
80+
OperationName: data.OperationName,
81+
Connection: conn,
82+
SendData: func(subscription *Subscription, data *DataMessagePayload) {
83+
conn.SendData(opID, data)
84+
},
85+
})
86+
},
87+
StopOperation: func(conn Connection, opID string) {
88+
subscriptionManager.RemoveSubscription(conn, &Subscription{
89+
ID: opID,
90+
})
91+
},
7992
},
8093
})
8194
connections[conn] = true

0 commit comments

Comments
 (0)