Skip to content

Commit 7703025

Browse files
committed
Fix race condition in http req/res collector
1 parent 851fcf5 commit 7703025

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

pkg/mux/fragmentedConn.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ func (fc *fragmentedConnection) Close() error {
9696
case <-fc.done:
9797
default:
9898
close(fc.done)
99-
fc.onClose()
99+
if fc.onClose != nil {
100+
fc.onClose()
101+
}
100102
}
101103

102104
return nil

pkg/mux/multiplexer.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,17 @@ func (m *Multiplexer) collector(localAddr net.Addr) http.HandlerFunc {
175175
lck sync.Mutex
176176
)
177177

178+
cleanupConnection := func(id string, conn *fragmentedConnection) {
179+
lck.Lock()
180+
defer lck.Unlock()
181+
if _, exists := connections[id]; exists {
182+
delete(connections, id)
183+
if conn != nil {
184+
conn.Close()
185+
}
186+
}
187+
}
188+
178189
return func(w http.ResponseWriter, req *http.Request) {
179190
if req.Method != http.MethodHead && req.Method != http.MethodGet && req.Method != http.MethodPost {
180191
http.Error(w, "Bad Request", http.StatusBadRequest)
@@ -191,6 +202,7 @@ func (m *Multiplexer) collector(localAddr net.Addr) http.HandlerFunc {
191202
defer lck.Unlock()
192203

193204
if req.Method == http.MethodHead {
205+
// Check to make sure the public key is within the authorised keys file and if so create an ID for the client connection
194206

195207
if len(connections) > 2000 {
196208
log.Println("server has too many polling connections (", len(connections), " limit is 2k")
@@ -216,7 +228,7 @@ func (m *Multiplexer) collector(localAddr net.Addr) http.HandlerFunc {
216228
}
217229

218230
c, id, err = NewFragmentCollector(localAddr, realConn.RemoteAddr(), func() {
219-
delete(connections, id)
231+
cleanupConnection(id, nil)
220232
})
221233
if err != nil {
222234
log.Println("error generating new fragment collector: ", err)
@@ -269,7 +281,9 @@ func (m *Multiplexer) collector(localAddr net.Addr) http.HandlerFunc {
269281
if err == io.EOF {
270282
return
271283
}
272-
c.Close()
284+
// If there is an error rather than just us reaching the end of the current request, clear
285+
// the connection as it would otherwise cause an ssh protocol desync and closure anyway
286+
cleanupConnection(id, c)
273287
}
274288

275289
// Add data
@@ -279,7 +293,7 @@ func (m *Multiplexer) collector(localAddr net.Addr) http.HandlerFunc {
279293
if err == io.EOF {
280294
return
281295
}
282-
c.Close()
296+
cleanupConnection(id, c)
283297
}
284298
}
285299

0 commit comments

Comments
 (0)