Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ require (
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect
gopkg.in/inconshreveable/log15.v2 v2.0.0-20200109203555-b30bc20e4fd1
)

replace github.com/imdario/mergo => dario.cat/mergo latest
34 changes: 15 additions & 19 deletions tunnel/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
package tunnel

import (
"bytes"
"errors"
"fmt"
"html"
"io"
"io/ioutil"
"net/http"

// imported per documentation - https://golang.org/pkg/net/http/pprof/
Expand All @@ -26,7 +24,7 @@ func httpError(log log15.Logger, w http.ResponseWriter, token, err string, code
http.Error(w, html.EscapeString(err), code)
}

//websocket error constants
// websocket error constants
const (
wsReadClose = iota
wsReadError = iota
Expand Down Expand Up @@ -77,14 +75,12 @@ func wsHandler(t *WSTunnelServer, w http.ResponseWriter, r *http.Request) {
go func() {
rs.remoteName, rs.remoteWhois = ipAddrLookup(t.Log, rs.remoteAddr)
}()
// Set safety limits
ws.SetReadLimit(100 * 1024 * 1024)
// Start timeout handling
wsSetPingHandler(t, ws, rs)
// Create synchronization channel
ch := make(chan int, 2)
// Spawn goroutine to read responses
go wsReader(rs, ws, t.WSTimeout, ch)
go wsReader(rs, ws, ch)
// Send requests
wsWriter(rs, ws, ch)
}
Expand Down Expand Up @@ -136,7 +132,7 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) {
continue
}
// write the request into the tunnel
ws.SetWriteDeadline(time.Now().Add(time.Minute))
ws.SetWriteDeadline(time.Time{}) // no timeout, there's the ping-pong for that
var w io.WriteCloser
w, err = ws.NextWriter(websocket.BinaryMessage)
// got an error, reply with a "hey, retry" to the request handler
Expand Down Expand Up @@ -170,9 +166,17 @@ func wsWriter(rs *remoteServer, ws *websocket.Conn, ch chan int) {
}

// Read responses from the tunnel and fulfill pending requests
func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch chan int) {
func wsReader(rs *remoteServer, ws *websocket.Conn, ch chan int) {
var err error
logToken := cutToken(rs.token)

// the mutex remains locked unless we are within Cond.Wait()
rs.readCond.L.Lock()
defer func() {
rs.readCond.L.Unlock()
rs.readCond.Signal()
}()

// continue reading until we get an error
for {
ws.SetReadDeadline(time.Time{}) // no timeout, there's the ping-pong for that
Expand All @@ -187,33 +191,25 @@ func wsReader(rs *remoteServer, ws *websocket.Conn, wsTimeout time.Duration, ch
err = fmt.Errorf("non-binary message received, type=%d", t)
break
}
// give the sender a fixed time to get us the data
ws.SetReadDeadline(time.Now().Add(wsTimeout))
// get request id
var id int16
_, err = fmt.Fscanf(io.LimitReader(r, 4), "%04x", &id)
if err != nil {
break
}
// read request itself, the size is limited by the SetReadLimit on the websocket
var buf []byte
buf, err = ioutil.ReadAll(r)
if err != nil {
break
}
rs.log.Info("WS RCV", "id", id, "ws", wsp(ws), "len", len(buf))
// try to match request
rs.requestSetMutex.Lock()
req := rs.requestSet[id]
rs.lastActivity = time.Now()
rs.requestSetMutex.Unlock()
// let's see...
if req != nil {
rb := responseBuffer{response: bytes.NewBuffer(buf)}
rb := responseBuffer{response: r}
// try to enqueue response
select {
case req.replyChan <- rb:
// great!
rs.log.Info("WS RCV enqueued response", "id", id, "ws", wsp(ws))
rs.readCond.Wait() // wait for response to be sent
default:
rs.log.Info("WS RCV can't enqueue response", "id", id, "ws", wsp(ws))
}
Expand Down
4 changes: 1 addition & 3 deletions tunnel/wstuncli.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ func (wsc *WSConnection) handleRequests() {
wsc.Log.Warn("WS invalid message type", "type", typ)
break
}
// give the sender a minute to produce the request
wsc.ws.SetReadDeadline(time.Now().Add(time.Minute))
// read request id
var id int16
_, err = fmt.Fscanf(io.LimitReader(r, 4), "%04x", &id)
Expand Down Expand Up @@ -758,7 +756,7 @@ func (wsc *WSConnection) writeResponseMessage(id int16, resp *http.Response) {
wsWriterMutex.Lock()
defer wsWriterMutex.Unlock()
// Write response into the tunnel
wsc.ws.SetWriteDeadline(time.Now().Add(time.Minute))
wsc.ws.SetWriteDeadline(time.Time{}) // separate ping-pong routine does timeout
w, err := wsc.ws.NextWriter(websocket.BinaryMessage)
// got an error, reply with a "hey, retry" to the request handler
if err != nil {
Expand Down
17 changes: 14 additions & 3 deletions tunnel/wstunsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type token string

type responseBuffer struct {
err error
response *bytes.Buffer
response io.Reader
}

// A request for a remote server
Expand All @@ -77,6 +77,8 @@ type remoteServer struct {
requestSet map[int16]*remoteRequest // all requests in queue/flight indexed by ID
requestSetMutex sync.Mutex
log log15.Logger
readMutex sync.Mutex // ensure that no more than one goroutine calls the websocket read methods concurrently
readCond *sync.Cond // (NextReader, SetReadDeadline, SetPingHandler, ...)
}

//WSTunnelServer a wstunnel server construct
Expand Down Expand Up @@ -349,8 +351,14 @@ func getResponse(t *WSTunnelServer, req *remoteRequest, w http.ResponseWriter, r
}

// Ensure we retire the request when we pop out of this function
// and signal the tunnel reader to continue
defer func() {
rs.RetireRequest(req)
if !retry {
rs.readCond.L.Lock() // make sure the reader is in Wait()
rs.readCond.Signal()
rs.readCond.L.Unlock()
}
}()

// enqueue request
Expand All @@ -367,6 +375,7 @@ func getResponse(t *WSTunnelServer, req *remoteRequest, w http.ResponseWriter, r
}
req.log.Info("HTTP RCV", "verb", r.Method, "url", r.URL,
"addr", req.remoteAddr, "x-host", r.Header.Get("X-Host"), "try", try)

// wait for response
select {
case resp := <-req.replyChan:
Expand Down Expand Up @@ -426,6 +435,7 @@ func (t *WSTunnelServer) getRemoteServer(tok token, create bool) *remoteServer {
requestSet: make(map[int16]*remoteRequest),
log: log15.New("token", cutToken(tok)),
}
rs.readCond = sync.NewCond(&rs.readMutex)
t.serverRegistry[tok] = rs
return rs
}
Expand Down Expand Up @@ -498,8 +508,8 @@ var censoredHeaders = []string{
}

// Write an HTTP response from a byte buffer into a ResponseWriter
func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int {
resp, err := http.ReadResponse(bufio.NewReader(buf), nil)
func writeResponse(w http.ResponseWriter, r io.Reader) int {
resp, err := http.ReadResponse(bufio.NewReader(r), nil)
if err != nil {
log15.Info("WriteResponse: can't parse incoming response", "err", err)
w.WriteHeader(506)
Expand All @@ -512,6 +522,7 @@ func writeResponse(w http.ResponseWriter, buf *bytes.Buffer) int {
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
resp.Body.Close()
return resp.StatusCode
}

Expand Down