@@ -3,11 +3,12 @@ package handler
33import (
44 "context"
55 "fmt"
6+ "io"
7+ "log"
68 "net"
79 "net/http"
810 "strings"
911 "sync"
10- "time"
1112
1213 "github.com/SenseUnit/dumbproxy/auth"
1314 "github.com/SenseUnit/dumbproxy/dialer"
@@ -21,36 +22,50 @@ type HandlerDialer interface {
2122}
2223
2324type ProxyHandler struct {
24- timeout time.Duration
2525 auth auth.Auth
2626 logger * clog.CondLogger
2727 dialer HandlerDialer
28+ forward func (ctx context.Context , username string , incoming , outgoing io.ReadWriteCloser ) error
2829 httptransport http.RoundTripper
2930 outbound map [string ]string
3031 outboundMux sync.RWMutex
3132 userIPHints bool
3233}
3334
34- func NewProxyHandler (timeout time.Duration , auth auth.Auth , dialer HandlerDialer ,
35- userIPHints bool , logger * clog.CondLogger ) * ProxyHandler {
35+ func NewProxyHandler (config * Config ) * ProxyHandler {
36+ d := config .Dialer
37+ if d == nil {
38+ d = dialer .NewBoundDialer (nil , "" )
39+ }
3640 httptransport := & http.Transport {
37- DialContext : dialer .DialContext ,
41+ DialContext : d .DialContext ,
3842 DisableKeepAlives : true ,
3943 }
44+ a := config .Auth
45+ if a == nil {
46+ a = auth.NoAuth {}
47+ }
48+ l := config .Logger
49+ if l == nil {
50+ l = clog .NewCondLogger (log .New (io .Discard , "" , 0 ), 0 )
51+ }
52+ f := config .Forward
53+ if f == nil {
54+ f = PairConnections
55+ }
4056 return & ProxyHandler {
41- timeout : timeout ,
42- auth : auth ,
43- logger : logger ,
44- dialer : dialer ,
57+ auth : a ,
58+ logger : l ,
59+ dialer : d ,
60+ forward : f ,
4561 httptransport : httptransport ,
4662 outbound : make (map [string ]string ),
47- userIPHints : userIPHints ,
63+ userIPHints : config . UserIPHints ,
4864 }
4965}
5066
51- func (s * ProxyHandler ) HandleTunnel (wr http.ResponseWriter , req * http.Request ) {
52- ctx , _ := context .WithTimeout (req .Context (), s .timeout )
53- conn , err := s .dialer .DialContext (ctx , "tcp" , req .RequestURI )
67+ func (s * ProxyHandler ) HandleTunnel (wr http.ResponseWriter , req * http.Request , username string ) {
68+ conn , err := s .dialer .DialContext (req .Context (), "tcp" , req .RequestURI )
5469 if err != nil {
5570 s .logger .Error ("Can't satisfy CONNECT request: %v" , err )
5671 http .Error (wr , "Can't satisfy CONNECT request" , http .StatusBadGateway )
@@ -81,21 +96,27 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
8196 // Inform client connection is built
8297 fmt .Fprintf (localconn , "HTTP/%d.%d 200 OK\r \n \r \n " , req .ProtoMajor , req .ProtoMinor )
8398
84- proxy (req .Context (), localconn , conn )
99+ s . forward (req .Context (), username , localconn , conn )
85100 } else if req .ProtoMajor == 2 {
86101 wr .Header ()["Date" ] = nil
87102 wr .WriteHeader (http .StatusOK )
88103 flush (wr )
89- proxyh2 (req .Context (), req .Body , wr , conn )
104+ s . forward (req .Context (), username , wrapH2 ( req .Body , wr ) , conn )
90105 } else {
91106 s .logger .Error ("Unsupported protocol version: %s" , req .Proto )
92107 http .Error (wr , "Unsupported protocol version." , http .StatusBadRequest )
93108 return
94109 }
95110}
96111
97- func (s * ProxyHandler ) HandleRequest (wr http.ResponseWriter , req * http.Request ) {
112+ func (s * ProxyHandler ) HandleRequest (wr http.ResponseWriter , req * http.Request , username string ) {
98113 req .RequestURI = ""
114+ forwardReqBody := newH1ReqBodyPipe ()
115+ origBody := req .Body
116+ req .Body = forwardReqBody .Body ()
117+ go func () {
118+ s .forward (req .Context (), username , wrapH1ReqBody (origBody ), forwardReqBody )
119+ }()
99120 if req .ProtoMajor == 2 {
100121 req .URL .Scheme = "http" // We can't access :scheme pseudo-header, so assume http
101122 req .URL .Host = req .Host
@@ -112,7 +133,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request)
112133 copyHeader (wr .Header (), resp .Header )
113134 wr .WriteHeader (resp .StatusCode )
114135 flush (wr )
115- copyBody ( wr , resp .Body )
136+ s . forward ( req . Context (), username , wrapH1RespWriter ( wr ), wrapH1ReqBody ( resp .Body ) )
116137}
117138
118139func (s * ProxyHandler ) isLoopback (req * http.Request ) (string , bool ) {
@@ -160,9 +181,9 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
160181 req = req .WithContext (newCtx )
161182 delHopHeaders (req .Header )
162183 if isConnect {
163- s .HandleTunnel (wr , req )
184+ s .HandleTunnel (wr , req , username )
164185 } else {
165- s .HandleRequest (wr , req )
186+ s .HandleRequest (wr , req , username )
166187 }
167188}
168189
0 commit comments