Skip to content

Commit 75e1804

Browse files
committed
handler: option for custom stream forward func
1 parent ec1e351 commit 75e1804

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

handler/config.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package handler
22

33
import (
4+
"context"
5+
"io"
6+
47
"github.com/SenseUnit/dumbproxy/auth"
58
clog "github.com/SenseUnit/dumbproxy/log"
69
)
@@ -9,11 +12,14 @@ type Config struct {
912
// Dialer optionally specifies dialer to use for creating
1013
// connections originating from proxy.
1114
Dialer HandlerDialer
12-
// Auth specifies request validator used to verify users
13-
// and return their username
15+
// Auth optionally specifies request validator used to verify users
16+
// and return their username.
1417
Auth auth.Auth
15-
// Logger specifies optional custom logger
18+
// Logger specifies optional custom logger.
1619
Logger *clog.CondLogger
20+
// Forward optionally specifies custom connection pairing function
21+
// which does actual data forwarding.
22+
Forward func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
1723
// UserIPHints specifies whether allow IP hints set by user or not
1824
UserIPHints bool
1925
}

handler/handler.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type ProxyHandler struct {
2525
auth auth.Auth
2626
logger *clog.CondLogger
2727
dialer HandlerDialer
28+
forward func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
2829
httptransport http.RoundTripper
2930
outbound map[string]string
3031
outboundMux sync.RWMutex
@@ -48,10 +49,15 @@ func NewProxyHandler(config *Config) *ProxyHandler {
4849
if l == nil {
4950
l = clog.NewCondLogger(log.New(io.Discard, "", 0), 0)
5051
}
52+
f := config.Forward
53+
if f == nil {
54+
f = PairConnections
55+
}
5156
return &ProxyHandler{
5257
auth: a,
5358
logger: l,
5459
dialer: d,
60+
forward: f,
5561
httptransport: httptransport,
5662
outbound: make(map[string]string),
5763
userIPHints: config.UserIPHints,
@@ -90,12 +96,12 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, u
9096
// Inform client connection is built
9197
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
9298

93-
PairConnections(req.Context(), username, localconn, conn)
99+
s.forward(req.Context(), username, localconn, conn)
94100
} else if req.ProtoMajor == 2 {
95101
wr.Header()["Date"] = nil
96102
wr.WriteHeader(http.StatusOK)
97103
flush(wr)
98-
PairConnections(req.Context(), username, wrapH2(req.Body, wr), conn)
104+
s.forward(req.Context(), username, wrapH2(req.Body, wr), conn)
99105
} else {
100106
s.logger.Error("Unsupported protocol version: %s", req.Proto)
101107
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
@@ -109,7 +115,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request,
109115
origBody := req.Body
110116
req.Body = forwardReqBody.Body()
111117
go func() {
112-
PairConnections(req.Context(), username, wrapH1ReqBody(origBody), forwardReqBody)
118+
s.forward(req.Context(), username, wrapH1ReqBody(origBody), forwardReqBody)
113119
}()
114120
if req.ProtoMajor == 2 {
115121
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
@@ -127,7 +133,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request,
127133
copyHeader(wr.Header(), resp.Header)
128134
wr.WriteHeader(resp.StatusCode)
129135
flush(wr)
130-
PairConnections(req.Context(), username, wrapH1RespWriter(wr), wrapH1ReqBody(resp.Body))
136+
s.forward(req.Context(), username, wrapH1RespWriter(wr), wrapH1ReqBody(resp.Body))
131137
}
132138

133139
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {

0 commit comments

Comments
 (0)