Skip to content

Commit 3fa7097

Browse files
lkedzioraglatosinski
authored andcommitted
[#72940] linux-client: Extract websocket methods into separate package
This extracts the RDFM-specific WebSocket connection into a separate package. This is in preparation for re-using existing RDFM WS connection code for shell output WS. The existing WebSocket connection management struct was renamed to DeviceManagementConnection to avoid confusion between the two. Signed-off-by: Łukasz Kędziora <lkedziora@antmicro.com>
1 parent 4599df4 commit 3fa7097

File tree

5 files changed

+81
-94
lines changed

5 files changed

+81
-94
lines changed

devices/linux-client/daemon/device.go

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@ import (
2626
"github.com/antmicro/rdfm/actions"
2727
"github.com/antmicro/rdfm/app"
2828
"github.com/antmicro/rdfm/conf"
29+
"github.com/antmicro/rdfm/serverws"
2930
"github.com/antmicro/rdfm/telemetry"
30-
"github.com/gorilla/websocket"
3131

3232
netUtils "github.com/antmicro/rdfm/daemon/net_utils"
33-
requests "github.com/antmicro/rdfm/daemon/requests"
3433
)
3534

3635
const RSA_DEVICE_KEY_SIZE = 4096
@@ -47,11 +46,11 @@ type Device struct {
4746
tokenMutex sync.Mutex
4847
httpTransport *http.Transport
4948
logManager *telemetry.LogManager
50-
conn *DeviceConnection
49+
conn *serverws.DeviceManagementConnection
5150
actionRunner *actions.ActionRunner
5251
}
5352

54-
func (d *Device) handleRequest(msg []byte) (requests.Request, error) {
53+
func (d *Device) handleRequest(msg []byte) (serverws.Request, error) {
5554
var msgMap map[string]interface{}
5655

5756
err := json.Unmarshal(msg, &msgMap)
@@ -63,18 +62,18 @@ func (d *Device) handleRequest(msg []byte) (requests.Request, error) {
6362

6463
log.Infof("Handling '%s' request...", requestName)
6564

66-
request, err := requests.Parse(string(msg[:]))
65+
request, err := serverws.Parse(string(msg[:]))
6766
if err != nil {
6867
return nil, err
6968
}
7069

7170
switch r := request.(type) {
72-
case requests.Alert:
71+
case serverws.Alert:
7372
for key, val := range r.Alert {
7473
log.Printf("Server sent %s: %s", key, val)
7574
}
76-
case requests.ActionExec:
77-
response := requests.ActionExecControl{
75+
case serverws.ActionExec:
76+
response := serverws.ActionExecControl{
7877
Method: "action_exec_control",
7978
ExecutionId: r.ExecutionId,
8079
Status: "ok",
@@ -89,11 +88,11 @@ func (d *Device) handleRequest(msg []byte) (requests.Request, error) {
8988
}
9089

9190
return response, nil
92-
case requests.ActionListQuery:
91+
case serverws.ActionListQuery:
9392
actions := d.actionRunner.List()
94-
var reqActions []requests.Action
93+
var reqActions []serverws.Action
9594
for _, action := range actions {
96-
reqAction := requests.Action{
95+
reqAction := serverws.Action{
9796
ActionId: action.Id,
9897
ActionName: action.Name,
9998
Description: action.Description,
@@ -103,22 +102,22 @@ func (d *Device) handleRequest(msg []byte) (requests.Request, error) {
103102

104103
reqActions = append(reqActions, reqAction)
105104
}
106-
response := requests.ActionListUpdate{
105+
response := serverws.ActionListUpdate{
107106
Method: "action_list_update",
108107
Actions: reqActions,
109108
}
110109
return response, nil
111-
//case requests.DeviceAttachToManager:
110+
//case serverws.DeviceAttachToManager:
112111
// TODO: Handle shell_attach
113112
default:
114113
log.Warnf("Request '%s' is unsupported", requestName)
115-
response := requests.CantHandleRequest()
114+
response := serverws.CantHandleRequest()
116115
return response, nil
117116
}
118117
return nil, nil
119118
}
120119

121-
func (d *Device) marshalSendRetry(req requests.Request, cancelCtx context.Context) error {
120+
func (d *Device) marshalSendRetry(req serverws.Request, cancelCtx context.Context) error {
122121
msg, err := json.Marshal(req)
123122
if err != nil {
124123
return err
@@ -175,14 +174,6 @@ func (d *Device) prepareHttpTransport(tlsConf *tls.Config) *http.Transport {
175174
}
176175
}
177176

178-
func (d *Device) prepareWsDialer(tlsConf *tls.Config) *websocket.Dialer {
179-
if tlsConf != nil {
180-
return &websocket.Dialer{TLSClientConfig: tlsConf}
181-
} else {
182-
return websocket.DefaultDialer
183-
}
184-
}
185-
186177
func (d *Device) setupConnection() error {
187178
// Get MAC address
188179
mac, err := netUtils.GetMacAddr()
@@ -210,8 +201,7 @@ func (d *Device) setupConnection() error {
210201
}
211202

212203
d.httpTransport = d.prepareHttpTransport(tlsConf)
213-
wsDialer := d.prepareWsDialer(tlsConf)
214-
d.conn = NewDeviceConnection(serverUrl, *wsDialer, 1024)
204+
d.conn = serverws.NewDeviceConnection(serverUrl, tlsConf, 1024)
215205
// TODO: Hardcode capabilities right now. This should be read from the
216206
// config file instead, but we don't have any configuration options that
217207
// determine whether action/shell should be enabled or not.
@@ -230,7 +220,7 @@ func (d *Device) setupActionRunner() error {
230220
}
231221

232222
func (d *Device) actionResultCallback(result actions.ActionResult, cancelCtx context.Context) bool {
233-
res := requests.ActionExecResult{
223+
res := serverws.ActionExecResult{
234224
Method: "action_exec_result",
235225
ExecutionId: result.ExecId,
236226
StatusCode: result.StatusCode,

devices/linux-client/daemon/net_utils/net_utils.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"encoding/json"
55
"errors"
66
"net"
7-
"net/url"
87
"regexp"
98
"time"
109

@@ -13,19 +12,6 @@ import (
1312
log "github.com/sirupsen/logrus"
1413
)
1514

16-
func HostWithOrWithoutPort(addr string, withPort bool) (string, error) {
17-
host, err := url.Parse(addr)
18-
if err != nil {
19-
return "", err
20-
}
21-
if !withPort {
22-
re := regexp.MustCompile(`:\d+`)
23-
result := re.Split(host.Host, -1)
24-
return result[0], nil
25-
}
26-
return host.Host, nil
27-
}
28-
2915
func ShouldEncryptProxy(addr string) (bool, error) {
3016
pattern := `http`
3117
re := regexp.MustCompile(pattern)

devices/linux-client/daemon/device_connection.go renamed to devices/linux-client/serverws/mgmt_connection.go

Lines changed: 22 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
package daemon
1+
package serverws
22

33
import (
44
"context"
5+
"crypto/tls"
56
"encoding/json"
6-
"net/http"
7-
"net/url"
87
"sync"
98

10-
netUtils "github.com/antmicro/rdfm/daemon/net_utils"
119
"github.com/gorilla/websocket"
1210
log "github.com/sirupsen/logrus"
1311
)
@@ -19,7 +17,7 @@ const (
1917
stateConnected
2018
)
2119

22-
type DeviceConnection struct {
20+
type DeviceManagementConnection struct {
2321
ws *websocket.Conn
2422
rx chan []byte
2523
txMut sync.Mutex
@@ -39,7 +37,7 @@ func (e *ConnClosedError) Error() string {
3937
return e.message
4038
}
4139

42-
func (d *DeviceConnection) startRecvLoop(cancelCtx context.Context) error {
40+
func (d *DeviceManagementConnection) startRecvLoop(cancelCtx context.Context) error {
4341
for {
4442
d.wsMut.RLock()
4543
if d.ws == nil {
@@ -58,7 +56,7 @@ func (d *DeviceConnection) startRecvLoop(cancelCtx context.Context) error {
5856
}
5957
}
6058

61-
func (d *DeviceConnection) Close() error {
59+
func (d *DeviceManagementConnection) Close() error {
6260
d.stateCnd.L.Lock()
6361
defer d.stateCnd.L.Unlock()
6462
err := d.tryClose()
@@ -68,7 +66,7 @@ func (d *DeviceConnection) Close() error {
6866
return err
6967
}
7068

71-
func (d *DeviceConnection) tryClose() error {
69+
func (d *DeviceManagementConnection) tryClose() error {
7270
d.wsMut.RLock()
7371
defer d.wsMut.RUnlock()
7472
if d.ws != nil {
@@ -77,18 +75,18 @@ func (d *DeviceConnection) tryClose() error {
7775
return nil
7876
}
7977

80-
func (d *DeviceConnection) setState(state deviceConnectionState) {
78+
func (d *DeviceManagementConnection) setState(state deviceConnectionState) {
8179
d.stateCnd.L.Lock()
8280
d.state = state
8381
d.stateCnd.Broadcast()
8482
d.stateCnd.L.Unlock()
8583
}
8684

87-
func (d *DeviceConnection) announceCapabilities() error {
85+
func (d *DeviceManagementConnection) announceCapabilities() error {
8886
res := map[string]interface{}{
8987
"method": "capability_report",
9088
"capabilities": d.capabilities,
91-
}
89+
}
9290

9391
msg, err := json.Marshal(res)
9492
if err != nil {
@@ -99,15 +97,19 @@ func (d *DeviceConnection) announceCapabilities() error {
9997
return nil
10098
}
10199

102-
func (d *DeviceConnection) CreateConnection(deviceToken string, cancelCtx context.Context) error {
100+
func (d *DeviceManagementConnection) CreateConnection(deviceToken string, cancelCtx context.Context) error {
103101
var wg sync.WaitGroup
104102

105103
defer func() {
106104
d.tryClose()
107105
d.setState(stateDisconnected)
108106
}()
109107

110-
err := d.prepareWs(deviceToken)
108+
endpoint, err := formatDeviceWsUrl(d.serverUrl)
109+
if err != nil {
110+
return err
111+
}
112+
d.ws, err = ConnectToRdfmWs(d.dialer, endpoint, deviceToken)
111113
if err != nil {
112114
return err
113115
}
@@ -156,18 +158,18 @@ func (d *DeviceConnection) CreateConnection(deviceToken string, cancelCtx contex
156158
return err
157159
}
158160

159-
func NewDeviceConnection(serverUrl string, dialer websocket.Dialer, buffer_size int) *DeviceConnection {
160-
dc := new(DeviceConnection)
161+
func NewDeviceConnection(serverUrl string, tlsConf *tls.Config, buffer_size int) *DeviceManagementConnection {
162+
dc := new(DeviceManagementConnection)
161163
dc.rx = make(chan []byte, buffer_size)
162164
dc.serverUrl = serverUrl
163-
dc.dialer = dialer
165+
dc.dialer = *prepareWsDialer(tlsConf)
164166
dc.state = stateDisconnected
165167
dc.stateCnd = sync.NewCond(&sync.Mutex{})
166168
dc.capabilities = make(map[string]bool)
167169
return dc
168170
}
169171

170-
func (d *DeviceConnection) Recv(cancelCtx context.Context) []byte {
172+
func (d *DeviceManagementConnection) Recv(cancelCtx context.Context) []byte {
171173
select {
172174
case msg, ok := <-d.rx:
173175
if ok {
@@ -178,7 +180,7 @@ func (d *DeviceConnection) Recv(cancelCtx context.Context) []byte {
178180
return nil
179181
}
180182

181-
func (d *DeviceConnection) Send(msg []byte) error {
183+
func (d *DeviceManagementConnection) Send(msg []byte) error {
182184
d.txMut.Lock()
183185
d.wsMut.RLock()
184186
defer d.txMut.Unlock()
@@ -195,47 +197,14 @@ func (d *DeviceConnection) Send(msg []byte) error {
195197
return nil
196198
}
197199

198-
func (d *DeviceConnection) EnsureReady() {
200+
func (d *DeviceManagementConnection) EnsureReady() {
199201
d.stateCnd.L.Lock()
200202
defer d.stateCnd.L.Unlock()
201203
for d.state != stateConnected {
202204
d.stateCnd.Wait()
203205
}
204206
}
205207

206-
func (d *DeviceConnection) prepareWs(deviceToken string) error {
207-
d.wsMut.Lock()
208-
defer d.wsMut.Unlock()
209-
210-
// Get the endpoint URL
211-
addr, err := netUtils.HostWithOrWithoutPort(d.serverUrl, true)
212-
if err != nil {
213-
return err
214-
}
215-
216-
scheme := "ws"
217-
if d.dialer.TLSClientConfig != nil {
218-
scheme = "wss"
219-
}
220-
221-
u := url.URL{
222-
Scheme: scheme,
223-
Host: addr,
224-
Path: "/api/v1/devices/ws",
225-
}
226-
227-
// Connect to the endpoint
228-
log.Infoln("Connecting to", u.String())
229-
230-
authHeader := http.Header{
231-
"Authorization": []string{"Bearer token=" + deviceToken},
232-
}
233-
234-
ws, _, err := d.dialer.Dial(u.String(), authHeader)
235-
d.ws = ws
236-
return err
237-
}
238-
239-
func (d *DeviceConnection) SetCapability(cap string, value bool) {
208+
func (d *DeviceManagementConnection) SetCapability(cap string, value bool) {
240209
d.capabilities[cap] = value
241210
}

devices/linux-client/daemon/requests/requests.go renamed to devices/linux-client/serverws/requests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package requests
1+
package serverws
22

33
import (
44
"encoding/json"
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package serverws
2+
3+
import (
4+
"crypto/tls"
5+
"net/http"
6+
"net/url"
7+
8+
"github.com/gorilla/websocket"
9+
log "github.com/sirupsen/logrus"
10+
)
11+
12+
func formatDeviceWsUrl(serverUrl string) (string, error) {
13+
return url.JoinPath(serverUrl, "/api/v1/devices/ws")
14+
}
15+
16+
func prepareWsDialer(tlsConf *tls.Config) *websocket.Dialer {
17+
if tlsConf != nil {
18+
return &websocket.Dialer{TLSClientConfig: tlsConf}
19+
} else {
20+
return websocket.DefaultDialer
21+
}
22+
}
23+
24+
func ConnectToRdfmWs(dialer websocket.Dialer, endpoint string, deviceToken string) (*websocket.Conn, error) {
25+
scheme := "ws"
26+
if dialer.TLSClientConfig != nil {
27+
scheme = "wss"
28+
}
29+
url, err := url.Parse(endpoint)
30+
if err != nil {
31+
return nil, err
32+
}
33+
url.Scheme = scheme
34+
wsEndpoint := url.String()
35+
36+
log.Infoln("Connecting to", wsEndpoint)
37+
authHeader := http.Header{
38+
"Authorization": []string{"Bearer token=" + deviceToken},
39+
}
40+
ws, _, err := dialer.Dial(wsEndpoint, authHeader)
41+
return ws, err
42+
}

0 commit comments

Comments
 (0)