@@ -143,7 +143,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
143143 return
144144 }
145145 ctx .Logf ("Accepting CONNECT to %s" , host )
146- _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 Connection established\r \n \r \n " ))
146+ if todo .Hijack != nil {
147+ todo .Hijack (r , proxyClient , ctx )
148+ } else {
149+ _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 Connection established\r \n \r \n " ))
150+ }
147151
148152 targetTCP , targetOK := targetSiteCon .(halfClosable )
149153 proxyClientTCP , clientOK := proxyClient .(halfClosable )
@@ -194,7 +198,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
194198 case ConnectHijack :
195199 todo .Hijack (r , proxyClient , ctx )
196200 case ConnectHTTPMitm :
197- _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 OK\r \n \r \n " ))
201+ if todo .Hijack != nil {
202+ todo .Hijack (r , proxyClient , ctx )
203+ } else {
204+ _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 OK\r \n \r \n " ))
205+ }
198206 ctx .Logf ("Assuming CONNECT is plain HTTP tunneling, mitm proxying it" )
199207
200208 var targetSiteCon net.Conn
@@ -265,7 +273,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
265273 }
266274 }
267275 case ConnectMitm :
268- _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 OK\r \n \r \n " ))
276+ if todo .Hijack != nil {
277+ todo .Hijack (r , proxyClient , ctx )
278+ } else {
279+ _ , _ = proxyClient .Write ([]byte ("HTTP/1.0 200 OK\r \n \r \n " ))
280+ }
269281 ctx .Logf ("Assuming CONNECT is TLS, mitm proxying it" )
270282 // this goes in a separate goroutine, so that the net/http server won't think we're
271283 // still handling the request even after hijacking the connection. Those HTTP CONNECT
@@ -534,7 +546,15 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxy(httpsProxy string) func(netw
534546
535547func (proxy * ProxyHttpServer ) NewConnectDialToProxyWithHandler (
536548 httpsProxy string ,
537- connectReqHandler func (req * http.Request ),
549+ connectReqHandler func (req * http.Request ) error ,
550+ ) func (network , addr string ) (net.Conn , error ) {
551+ return proxy .NewConnectDialToProxyWithMoreHandlers (httpsProxy , connectReqHandler , nil )
552+ }
553+
554+ func (proxy * ProxyHttpServer ) NewConnectDialToProxyWithMoreHandlers (
555+ httpsProxy string ,
556+ connectReqHandler func (req * http.Request ) error ,
557+ connectRespHandler func (req * http.Response ) error ,
538558) func (network , addr string ) (net.Conn , error ) {
539559 u , err := url .Parse (httpsProxy )
540560 if err != nil {
@@ -552,7 +572,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
552572 Header : make (http.Header ),
553573 }
554574 if connectReqHandler != nil {
555- connectReqHandler (connectReq )
575+ if err := connectReqHandler (connectReq ); err != nil {
576+ return nil , err
577+ }
556578 }
557579 c , err := proxy .dial (& ProxyCtx {Req : & http.Request {}}, network , u .Host )
558580 if err != nil {
@@ -569,7 +591,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
569591 return nil , err
570592 }
571593 defer resp .Body .Close ()
572- if resp .StatusCode != http .StatusOK {
594+ if connectRespHandler != nil {
595+ if err := connectRespHandler (resp ); err != nil {
596+ c .Close ()
597+ return nil , err
598+ }
599+ } else if resp .StatusCode != http .StatusOK {
573600 resp , err := io .ReadAll (io .LimitReader (resp .Body , _errorRespMaxLength ))
574601 if err != nil {
575602 return nil , err
@@ -603,7 +630,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
603630 Header : make (http.Header ),
604631 }
605632 if connectReqHandler != nil {
606- connectReqHandler (connectReq )
633+ if err := connectReqHandler (connectReq ); err != nil {
634+ return nil , err
635+ }
607636 }
608637 _ = connectReq .Write (c )
609638 // Read response.
@@ -616,7 +645,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
616645 return nil , err
617646 }
618647 defer resp .Body .Close ()
619- if resp .StatusCode != http .StatusOK {
648+ if connectRespHandler != nil {
649+ if err := connectRespHandler (resp ); err != nil {
650+ c .Close ()
651+ return nil , err
652+ }
653+ } else if resp .StatusCode != http .StatusOK {
620654 body , err := io .ReadAll (io .LimitReader (resp .Body , _errorRespMaxLength ))
621655 if err != nil {
622656 return nil , err
0 commit comments