Skip to content

Commit 5e7c199

Browse files
tanordheimlorenzodonini
authored andcommitted
allow overriding charge point id resolution
1 parent 7bee131 commit 5e7c199

File tree

3 files changed

+125
-18
lines changed

3 files changed

+125
-18
lines changed

ws/mocks/mock_Server.go

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ws/server.go

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ type Server interface {
9797
//
9898
// Duplicates will be removed automatically.
9999
AddSupportedSubprotocol(subProto string)
100+
// SetChargePointIdResolver sets the callback function to use for resolving the charge point ID of a charger connecting to
101+
// the websocket server. By default, this will just be the path in the URL used by the client.
102+
SetChargePointIdResolver(resolver func(r *http.Request) (string, error))
100103
// SetBasicAuthHandler enables HTTP Basic Authentication and requires clients to pass credentials.
101104
// The handler function is called whenever a new client attempts to connect, to check for credentials correctness.
102105
// The handler must return true if the credentials were correct, false otherwise.
@@ -127,21 +130,22 @@ type Server interface {
127130
//
128131
// Use the NewServer function to create a new server.
129132
type server struct {
130-
connections map[string]*webSocket
131-
httpServer *http.Server
132-
messageHandler func(ws Channel, data []byte) error
133-
checkClientHandler CheckClientHandler
134-
newClientHandler func(ws Channel)
135-
disconnectedHandler func(ws Channel)
136-
basicAuthHandler func(username string, password string) bool
137-
tlsCertificatePath string
138-
tlsCertificateKey string
139-
timeoutConfig ServerTimeoutConfig
140-
upgrader websocket.Upgrader
141-
errC chan error
142-
connMutex sync.RWMutex
143-
addr *net.TCPAddr
144-
httpHandler *mux.Router
133+
connections map[string]*webSocket
134+
httpServer *http.Server
135+
messageHandler func(ws Channel, data []byte) error
136+
chargePointIdResolver func(*http.Request) (string, error)
137+
checkClientHandler CheckClientHandler
138+
newClientHandler func(ws Channel)
139+
disconnectedHandler func(ws Channel)
140+
basicAuthHandler func(username string, password string) bool
141+
tlsCertificatePath string
142+
tlsCertificateKey string
143+
timeoutConfig ServerTimeoutConfig
144+
upgrader websocket.Upgrader
145+
errC chan error
146+
connMutex sync.RWMutex
147+
addr *net.TCPAddr
148+
httpHandler *mux.Router
145149
}
146150

147151
// ServerOpt is a function that can be used to set options on a server during creation.
@@ -183,6 +187,10 @@ func NewServer(opts ...ServerOpt) Server {
183187
timeoutConfig: NewServerTimeoutConfig(),
184188
upgrader: websocket.Upgrader{Subprotocols: []string{}},
185189
httpHandler: router,
190+
chargePointIdResolver: func(r *http.Request) (string, error) {
191+
url := r.URL
192+
return path.Base(url.Path), nil
193+
},
186194
}
187195
for _, o := range opts {
188196
o(s)
@@ -220,6 +228,10 @@ func (s *server) AddSupportedSubprotocol(subProto string) {
220228
s.upgrader.Subprotocols = append(s.upgrader.Subprotocols, subProto)
221229
}
222230

231+
func (s *server) SetChargePointIdResolver(resolver func(r *http.Request) (string, error)) {
232+
s.chargePointIdResolver = resolver
233+
}
234+
223235
func (s *server) SetBasicAuthHandler(handler func(username string, password string) bool) {
224236
s.basicAuthHandler = handler
225237
}
@@ -343,8 +355,12 @@ func (s *server) Write(webSocketId string, data []byte) error {
343355

344356
func (s *server) wsHandler(w http.ResponseWriter, r *http.Request) {
345357
responseHeader := http.Header{}
346-
url := r.URL
347-
id := path.Base(url.Path)
358+
id, err := s.chargePointIdResolver(r)
359+
if err != nil {
360+
s.error(fmt.Errorf("failed to resolve charge point id"))
361+
http.Error(w, "NotFound", http.StatusNotFound)
362+
return
363+
}
348364
log.Debugf("handling new connection for %s from %s", id, r.RemoteAddr)
349365
// Negotiate sub-protocol
350366
clientSubProtocols := websocket.Subprotocols(r)

ws/websocket_test.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,64 @@ func (s *WebSocketSuite) TestWebsocketEcho() {
297297
}
298298
}
299299

300+
func (s *WebSocketSuite) TestWebsocketChargePointIdResolver() {
301+
connected := make(chan string)
302+
s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) {
303+
s.Fail("no message should be received from client!")
304+
return nil, nil
305+
})
306+
s.server.SetChargePointIdResolver(func(*http.Request) (string, error) {
307+
return "my-custom-id", nil
308+
})
309+
s.server.SetNewClientHandler(func(ws Channel) {
310+
connected <- ws.ID()
311+
})
312+
go s.server.Start(serverPort, serverPath)
313+
time.Sleep(500 * time.Millisecond)
314+
315+
// Test message
316+
s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) {
317+
s.Fail("no message should be received from server!")
318+
return nil, nil
319+
})
320+
321+
host := fmt.Sprintf("localhost:%v", serverPort)
322+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
323+
// Attempt to connect and expect the custom resolved charge point id
324+
err := s.client.Start(u.String())
325+
s.NoError(err)
326+
result := <-connected
327+
s.Equal("my-custom-id", result)
328+
}
329+
330+
func (s *WebSocketSuite) TestWebsocketChargePointIdResolverFailure() {
331+
s.server = newWebsocketServer(s.T(), func(data []byte) ([]byte, error) {
332+
s.Fail("no message should be received from client!")
333+
return nil, nil
334+
})
335+
s.server.SetChargePointIdResolver(func(*http.Request) (string, error) {
336+
return "", fmt.Errorf("test error")
337+
})
338+
go s.server.Start(serverPort, serverPath)
339+
time.Sleep(500 * time.Millisecond)
340+
341+
// Test message
342+
s.client = newWebsocketClient(s.T(), func(data []byte) ([]byte, error) {
343+
s.Fail("no message should be received from server!")
344+
return nil, nil
345+
})
346+
347+
host := fmt.Sprintf("localhost:%v", serverPort)
348+
u := url.URL{Scheme: "ws", Host: host, Path: testPath}
349+
// Attempt to connect and expect the custom resolved charge point id
350+
err := s.client.Start(u.String())
351+
s.Error(err)
352+
httpErr, ok := err.(HttpConnectionError)
353+
s.True(ok)
354+
s.Equal(http.StatusNotFound, httpErr.HttpCode)
355+
s.Equal("websocket: bad handshake", httpErr.Message)
356+
}
357+
300358
func (s *WebSocketSuite) TestWebsocketBootRetries() {
301359
verifyConnection := func(client *client, connected bool) {
302360
maxAttempts := 20
@@ -1276,7 +1334,7 @@ func (s *WebSocketSuite) TestClientErrors() {
12761334
conn := s.server.connections[path.Base(testPath)]
12771335
s.NotNil(conn)
12781336
err = conn.WriteManual(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, ""))
1279-
//err = conn.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, ""))
1337+
// err = conn.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseUnsupportedData, ""))
12801338
s.NoError(err)
12811339
r = <-triggerC
12821340
s.NotNil(r)

0 commit comments

Comments
 (0)