diff --git a/go.mod b/go.mod index be0925f..b0d4b0f 100644 --- a/go.mod +++ b/go.mod @@ -13,3 +13,5 @@ require ( golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect gopkg.in/inconshreveable/log15.v2 v2.0.0-20200109203555-b30bc20e4fd1 ) + +replace github.com/imdario/mergo => dario.cat/mergo latest diff --git a/tunnel/ws.go b/tunnel/ws.go index eba8dbe..c113aeb 100644 --- a/tunnel/ws.go +++ b/tunnel/ws.go @@ -3,12 +3,10 @@ package tunnel import ( - "bytes" "errors" "fmt" "html" "io" - "io/ioutil" "net/http" // imported per documentation - https://golang.org/pkg/net/http/pprof/ @@ -26,7 +24,7 @@ func httpError(log log15.Logger, w http.ResponseWriter, token, err string, code http.Error(w, html.EscapeString(err), code) } -//websocket error constants +// websocket error constants const ( wsReadClose = iota wsReadError = iota @@ -77,14 +75,12 @@ func wsHandler(t *WSTunnelServer, w http.ResponseWriter, r *http.Request) { go func() { rs.remoteName, rs.remoteWhois = ipAddrLookup(t.Log, rs.remoteAddr) }() - // Set safety limits - ws.SetReadLimit(100 * 1024 * 1024) // Start timeout handling wsSetPingHandler(t, ws, rs) // Create synchronization channel ch := make(chan int, 2) // Spawn goroutine to read responses - go wsReader(rs, ws, t.WSTimeout, ch) + go wsReader(rs, ws, ch) // Send requests wsWriter(rs, ws, ch) } @@ -136,7 +132,7 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) { continue } // write the request into the tunnel - ws.SetWriteDeadline(time.Now().Add(time.Minute)) + ws.SetWriteDeadline(time.Time{}) // no timeout, there's the ping-pong for that var w io.WriteCloser w, err = ws.NextWriter(websocket.BinaryMessage) // got an error, reply with a "hey, retry" to the request handler @@ -170,9 +166,17 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) { } // Read responses from the tunnel and fulfill pending requests -func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch chan int) { +func wsReader(rs *remoteServer, ws *websocket.Conn, ch chan int) { var err error logToken := cutToken(rs.token) + + // the mutex remains locked unless we are within Cond.Wait() + rs.readCond.L.Lock() + defer func() { + rs.readCond.L.Unlock() + rs.readCond.Signal() + }() + // continue reading until we get an error for { ws.SetReadDeadline(time.Time{}) // no timeout, there's the ping-pong for that @@ -187,21 +191,12 @@ func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch err = fmt.Errorf("non-binary message received, type=%d", t) break } - // give the sender a fixed time to get us the data - ws.SetReadDeadline(time.Now().Add(wsTimeout)) // get request id var id int16 _, err = fmt.Fscanf(io.LimitReader(r, 4), "%04x", &id) if err != nil { break } - // read request itself, the size is limited by the SetReadLimit on the websocket - var buf []byte - buf, err = ioutil.ReadAll(r) - if err != nil { - break - } - rs.log.Info("WS RCV", "id", id, "ws", wsp(ws), "len", len(buf)) // try to match request rs.requestSetMutex.Lock() req := rs.requestSet[id] @@ -209,11 +204,12 @@ func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch rs.requestSetMutex.Unlock() // let's see... if req != nil { - rb := responseBuffer{response: bytes.NewBuffer(buf)} + rb := responseBuffer{response: r} // try to enqueue response select { case req.replyChan <- rb: - // great! + rs.log.Info("WS RCV enqueued response", "id", id, "ws", wsp(ws)) + rs.readCond.Wait() // wait for response to be sent default: rs.log.Info("WS RCV can't enqueue response", "id", id, "ws", wsp(ws)) } diff --git a/tunnel/wstuncli.go b/tunnel/wstuncli.go index 1ffd00d..1d5e5fa 100644 --- a/tunnel/wstuncli.go +++ b/tunnel/wstuncli.go @@ -364,8 +364,6 @@ func (wsc *WSConnection) handleRequests() { wsc.Log.Warn("WS invalid message type", "type", typ) break } - // give the sender a minute to produce the request - wsc.ws.SetReadDeadline(time.Now().Add(time.Minute)) // read request id var id int16 _, err = fmt.Fscanf(io.LimitReader(r, 4), "%04x", &id) @@ -758,7 +756,7 @@ func (wsc *WSConnection) writeResponseMessage(id int16, resp *http.Response) { wsWriterMutex.Lock() defer wsWriterMutex.Unlock() // Write response into the tunnel - wsc.ws.SetWriteDeadline(time.Now().Add(time.Minute)) + wsc.ws.SetWriteDeadline(time.Time{}) // separate ping-pong routine does timeout w, err := wsc.ws.NextWriter(websocket.BinaryMessage) // got an error, reply with a "hey, retry" to the request handler if err != nil { diff --git a/tunnel/wstunsrv.go b/tunnel/wstunsrv.go index 034ff56..fcabf19 100644 --- a/tunnel/wstunsrv.go +++ b/tunnel/wstunsrv.go @@ -51,7 +51,7 @@ type token string type responseBuffer struct { err error - response *bytes.Buffer + response io.Reader } // A request for a remote server @@ -77,6 +77,8 @@ type remoteServer struct { requestSet map[int16]*remoteRequest // all requests in queue/flight indexed by ID requestSetMutex sync.Mutex log log15.Logger + readMutex sync.Mutex // ensure that no more than one goroutine calls the websocket read methods concurrently + readCond *sync.Cond // (NextReader, SetReadDeadline, SetPingHandler, ...) } //WSTunnelServer a wstunnel server construct @@ -349,8 +351,14 @@ func getResponse(t *WSTunnelServer, req *remoteRequest, w http.ResponseWriter, r } // Ensure we retire the request when we pop out of this function + // and signal the tunnel reader to continue defer func() { rs.RetireRequest(req) + if !retry { + rs.readCond.L.Lock() // make sure the reader is in Wait() + rs.readCond.Signal() + rs.readCond.L.Unlock() + } }() // enqueue request @@ -367,6 +375,7 @@ func getResponse(t *WSTunnelServer, req *remoteRequest, w http.ResponseWriter, r } req.log.Info("HTTP RCV", "verb", r.Method, "url", r.URL, "addr", req.remoteAddr, "x-host", r.Header.Get("X-Host"), "try", try) + // wait for response select { case resp := <-req.replyChan: @@ -426,6 +435,7 @@ func (t *WSTunnelServer) getRemoteServer(tok token, create bool) *remoteServer { requestSet: make(map[int16]*remoteRequest), log: log15.New("token", cutToken(tok)), } + rs.readCond = sync.NewCond(&rs.readMutex) t.serverRegistry[tok] = rs return rs } @@ -498,8 +508,8 @@ var censoredHeaders = []string{ } // Write an HTTP response from a byte buffer into a ResponseWriter -func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int { - resp, err := http.ReadResponse(bufio.NewReader(buf), nil) +func writeResponse(w http.ResponseWriter, r io.Reader) int { + resp, err := http.ReadResponse(bufio.NewReader(r), nil) if err != nil { log15.Info("WriteResponse: can't parse incoming response", "err", err) w.WriteHeader(506) @@ -512,6 +522,7 @@ func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int { copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) io.Copy(w, resp.Body) + resp.Body.Close() return resp.StatusCode }