@@ -39,33 +39,52 @@ func (connection *connection) handleFrame(ctx context.Context) error {
39
39
return nil
40
40
}
41
41
42
+ type userConnections = map [* connection ]bool
43
+
44
+ func newUserConnections () userConnections {
45
+ return make (map [* connection ]bool )
46
+ }
47
+
42
48
type connections struct {
43
- data map [* connection ] bool
49
+ dict map [db. UserID ] userConnections
44
50
mutex sync.RWMutex
45
51
}
46
52
47
53
func newConnections () connections {
48
- return connections {data : make (map [* connection ] bool )}
54
+ return connections {dict : make (map [db. UserID ] userConnections )}
49
55
}
50
56
51
- func (connections * connections ) store (key * connection , value bool ) {
57
+ func (connections * connections ) store (userID db. UserID , connection * connection ) {
52
58
connections .mutex .Lock ()
53
- connections .data [key ] = value
54
- connections .mutex .Unlock ()
59
+ defer connections .mutex .Unlock ()
60
+
61
+ if userConnections , ok := connections .dict [userID ]; ok {
62
+ userConnections [connection ] = true
63
+ return
64
+ }
65
+
66
+ userConnections := newUserConnections ()
67
+ userConnections [connection ] = true
68
+ connections .dict [userID ] = userConnections
55
69
}
56
70
57
- func (connections * connections ) delete (key * connection ) {
71
+ func (connections * connections ) delete (userID db. UserID , connection * connection ) {
58
72
connections .mutex .Lock ()
59
- delete (connections .data , key )
60
- connections .mutex .Unlock ()
73
+ defer connections .mutex .Unlock ()
74
+
75
+ if userConnections , ok := connections .dict [userID ]; ok {
76
+ delete (userConnections , connection )
77
+ }
61
78
}
62
79
63
80
func (connections * connections ) close () {
64
81
connections .mutex .RLock ()
65
82
defer connections .mutex .RUnlock ()
66
83
67
- for connection := range connections .data {
68
- connection .Close (websocket .StatusNormalClosure , "server shutting down" )
84
+ for _ , userConnections := range connections .dict {
85
+ for connection := range userConnections {
86
+ connection .Close (websocket .StatusNormalClosure , "server shutting down" )
87
+ }
69
88
}
70
89
}
71
90
@@ -80,7 +99,7 @@ func newWsHandler() *wsHandler {
80
99
}
81
100
}
82
101
83
- func (handler * wsHandler ) handle (writer http.ResponseWriter , request * http.Request ) error {
102
+ func (handler * wsHandler ) handle (userID db. UserID , writer http.ResponseWriter , request * http.Request ) error {
84
103
defer handler .waitGroup .Done ()
85
104
86
105
options := websocket.AcceptOptions {InsecureSkipVerify : true }
@@ -92,8 +111,8 @@ func (handler *wsHandler) handle(writer http.ResponseWriter, request *http.Reque
92
111
connection := connection {conn , writer , request }
93
112
defer connection .CloseNow ()
94
113
95
- handler .connections .store (& connection , true )
96
- defer handler .connections .delete (& connection )
114
+ handler .connections .store (userID , & connection )
115
+ defer handler .connections .delete (userID , & connection )
97
116
98
117
log .Printf ("opened a connection with %v" , request .RemoteAddr )
99
118
@@ -113,7 +132,7 @@ func (handler *wsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Re
113
132
return
114
133
}
115
134
116
- _ , err = db .AuthenticateBySessionKey (sessionKey .Value )
135
+ userID , err : = db .AuthenticateBySessionKey (sessionKey .Value )
117
136
if err != nil {
118
137
switch err {
119
138
case db .Unathorized :
@@ -126,7 +145,7 @@ func (handler *wsHandler) ServeHTTP(writer http.ResponseWriter, request *http.Re
126
145
127
146
handler .waitGroup .Add (1 )
128
147
129
- err = handler .handle (writer , request )
148
+ err = handler .handle (userID , writer , request )
130
149
if err != nil {
131
150
log .Println (err )
132
151
return
0 commit comments