Skip to content

Commit 59c08ca

Browse files
authored
Merge pull request #89 from SenseUnit/handler_refactoring
Handler refactoring
2 parents 9e3e749 + 75e1804 commit 59c08ca

File tree

6 files changed

+255
-97
lines changed

6 files changed

+255
-97
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,6 @@ Usage of /home/user/go/bin/dumbproxy:
219219
bcrypt password cost (for -passwd mode) (default 4)
220220
-proxy value
221221
upstream proxy URL. Can be repeated multiple times to chain proxies. Examples: socks5h://127.0.0.1:9050; https://user:password@example.com:443
222-
-timeout duration
223-
timeout for network operations (default 10s)
224222
-user-ip-hints
225223
allow IP hints to be specified by user in X-Src-IP-Hints header
226224
-verbosity int

handler/adapter.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package handler
2+
3+
import (
4+
"io"
5+
"net/http"
6+
)
7+
8+
type wrappedH2 struct {
9+
r io.ReadCloser
10+
w io.Writer
11+
}
12+
13+
func wrapH2(r io.ReadCloser, w io.Writer) wrappedH2 {
14+
return wrappedH2{
15+
r: r,
16+
w: w,
17+
}
18+
}
19+
20+
func (w wrappedH2) Read(p []byte) (n int, err error) {
21+
return w.r.Read(p)
22+
}
23+
24+
func (w wrappedH2) Write(p []byte) (n int, err error) {
25+
n, err = w.w.Write(p)
26+
if err != nil {
27+
return
28+
}
29+
if f, ok := w.w.(http.Flusher); ok {
30+
f.Flush()
31+
}
32+
return
33+
}
34+
35+
func (w wrappedH2) Close() error {
36+
// can't really close response writer, but at least we can disrupt copy
37+
// closing Reader
38+
return w.r.Close()
39+
}
40+
41+
var _ io.ReadWriteCloser = wrappedH2{}
42+
43+
type wrappedH1ReqBody struct {
44+
r io.ReadCloser
45+
}
46+
47+
func wrapH1ReqBody(r io.ReadCloser) wrappedH1ReqBody {
48+
return wrappedH1ReqBody{
49+
r: r,
50+
}
51+
}
52+
53+
func (w wrappedH1ReqBody) Read(p []byte) (n int, err error) {
54+
return w.r.Read(p)
55+
}
56+
57+
func (w wrappedH1ReqBody) Write(p []byte) (n int, err error) {
58+
return len(p), nil
59+
}
60+
61+
func (w wrappedH1ReqBody) Close() error {
62+
return w.r.Close()
63+
}
64+
65+
func (w wrappedH1ReqBody) CloseWrite() error {
66+
return nil
67+
}
68+
69+
var _ io.ReadWriteCloser = wrappedH1ReqBody{}
70+
var _ interface{ CloseWrite() error } = wrappedH1ReqBody{}
71+
72+
type h1ReqBodyPipe struct {
73+
r *io.PipeReader
74+
w *io.PipeWriter
75+
}
76+
77+
func newH1ReqBodyPipe() h1ReqBodyPipe {
78+
r, w := io.Pipe()
79+
return h1ReqBodyPipe{
80+
r: r,
81+
w: w,
82+
}
83+
}
84+
85+
func (w h1ReqBodyPipe) Read(p []byte) (n int, err error) {
86+
return 0, io.EOF
87+
}
88+
89+
func (w h1ReqBodyPipe) Write(p []byte) (n int, err error) {
90+
return w.w.Write(p)
91+
}
92+
93+
func (w h1ReqBodyPipe) Close() error {
94+
return w.CloseWrite()
95+
}
96+
97+
func (w h1ReqBodyPipe) CloseWrite() error {
98+
return w.w.Close()
99+
}
100+
101+
func (w h1ReqBodyPipe) Body() io.ReadCloser {
102+
return w.r
103+
}
104+
105+
var _ io.ReadWriteCloser = h1ReqBodyPipe{}
106+
var _ interface{ CloseWrite() error } = h1ReqBodyPipe{}
107+
108+
type wrappedH1RespWriter struct {
109+
w io.Writer
110+
}
111+
112+
func wrapH1RespWriter(w io.Writer) wrappedH1RespWriter {
113+
return wrappedH1RespWriter{
114+
w: w,
115+
}
116+
}
117+
118+
func (w wrappedH1RespWriter) Read(p []byte) (n int, err error) {
119+
return 0, io.EOF
120+
}
121+
122+
func (w wrappedH1RespWriter) Write(p []byte) (n int, err error) {
123+
n, err = w.w.Write(p)
124+
if f, ok := w.w.(http.Flusher); ok {
125+
f.Flush()
126+
}
127+
return
128+
}
129+
130+
func (w wrappedH1RespWriter) Close() error {
131+
// can't really close response writer, just make copier return
132+
// and finish request
133+
return nil
134+
}
135+
136+
var _ io.ReadWriteCloser = wrappedH1RespWriter{}

handler/config.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package handler
2+
3+
import (
4+
"context"
5+
"io"
6+
7+
"github.com/SenseUnit/dumbproxy/auth"
8+
clog "github.com/SenseUnit/dumbproxy/log"
9+
)
10+
11+
type Config struct {
12+
// Dialer optionally specifies dialer to use for creating
13+
// connections originating from proxy.
14+
Dialer HandlerDialer
15+
// Auth optionally specifies request validator used to verify users
16+
// and return their username.
17+
Auth auth.Auth
18+
// Logger specifies optional custom logger.
19+
Logger *clog.CondLogger
20+
// Forward optionally specifies custom connection pairing function
21+
// which does actual data forwarding.
22+
Forward func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
23+
// UserIPHints specifies whether allow IP hints set by user or not
24+
UserIPHints bool
25+
}

handler/handler.go

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package handler
33
import (
44
"context"
55
"fmt"
6+
"io"
7+
"log"
68
"net"
79
"net/http"
810
"strings"
911
"sync"
10-
"time"
1112

1213
"github.com/SenseUnit/dumbproxy/auth"
1314
"github.com/SenseUnit/dumbproxy/dialer"
@@ -21,36 +22,50 @@ type HandlerDialer interface {
2122
}
2223

2324
type ProxyHandler struct {
24-
timeout time.Duration
2525
auth auth.Auth
2626
logger *clog.CondLogger
2727
dialer HandlerDialer
28+
forward func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser) error
2829
httptransport http.RoundTripper
2930
outbound map[string]string
3031
outboundMux sync.RWMutex
3132
userIPHints bool
3233
}
3334

34-
func NewProxyHandler(timeout time.Duration, auth auth.Auth, dialer HandlerDialer,
35-
userIPHints bool, logger *clog.CondLogger) *ProxyHandler {
35+
func NewProxyHandler(config *Config) *ProxyHandler {
36+
d := config.Dialer
37+
if d == nil {
38+
d = dialer.NewBoundDialer(nil, "")
39+
}
3640
httptransport := &http.Transport{
37-
DialContext: dialer.DialContext,
41+
DialContext: d.DialContext,
3842
DisableKeepAlives: true,
3943
}
44+
a := config.Auth
45+
if a == nil {
46+
a = auth.NoAuth{}
47+
}
48+
l := config.Logger
49+
if l == nil {
50+
l = clog.NewCondLogger(log.New(io.Discard, "", 0), 0)
51+
}
52+
f := config.Forward
53+
if f == nil {
54+
f = PairConnections
55+
}
4056
return &ProxyHandler{
41-
timeout: timeout,
42-
auth: auth,
43-
logger: logger,
44-
dialer: dialer,
57+
auth: a,
58+
logger: l,
59+
dialer: d,
60+
forward: f,
4561
httptransport: httptransport,
4662
outbound: make(map[string]string),
47-
userIPHints: userIPHints,
63+
userIPHints: config.UserIPHints,
4864
}
4965
}
5066

51-
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
52-
ctx, _ := context.WithTimeout(req.Context(), s.timeout)
53-
conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI)
67+
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, username string) {
68+
conn, err := s.dialer.DialContext(req.Context(), "tcp", req.RequestURI)
5469
if err != nil {
5570
s.logger.Error("Can't satisfy CONNECT request: %v", err)
5671
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
@@ -81,21 +96,27 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
8196
// Inform client connection is built
8297
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
8398

84-
proxy(req.Context(), localconn, conn)
99+
s.forward(req.Context(), username, localconn, conn)
85100
} else if req.ProtoMajor == 2 {
86101
wr.Header()["Date"] = nil
87102
wr.WriteHeader(http.StatusOK)
88103
flush(wr)
89-
proxyh2(req.Context(), req.Body, wr, conn)
104+
s.forward(req.Context(), username, wrapH2(req.Body, wr), conn)
90105
} else {
91106
s.logger.Error("Unsupported protocol version: %s", req.Proto)
92107
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
93108
return
94109
}
95110
}
96111

97-
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
112+
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request, username string) {
98113
req.RequestURI = ""
114+
forwardReqBody := newH1ReqBodyPipe()
115+
origBody := req.Body
116+
req.Body = forwardReqBody.Body()
117+
go func() {
118+
s.forward(req.Context(), username, wrapH1ReqBody(origBody), forwardReqBody)
119+
}()
99120
if req.ProtoMajor == 2 {
100121
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
101122
req.URL.Host = req.Host
@@ -112,7 +133,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request)
112133
copyHeader(wr.Header(), resp.Header)
113134
wr.WriteHeader(resp.StatusCode)
114135
flush(wr)
115-
copyBody(wr, resp.Body)
136+
s.forward(req.Context(), username, wrapH1RespWriter(wr), wrapH1ReqBody(resp.Body))
116137
}
117138

118139
func (s *ProxyHandler) isLoopback(req *http.Request) (string, bool) {
@@ -160,9 +181,9 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
160181
req = req.WithContext(newCtx)
161182
delHopHeaders(req.Header)
162183
if isConnect {
163-
s.HandleTunnel(wr, req)
184+
s.HandleTunnel(wr, req, username)
164185
} else {
165-
s.HandleRequest(wr, req)
186+
s.HandleRequest(wr, req, username)
166187
}
167188
}
168189

0 commit comments

Comments
 (0)