Skip to content

Commit 44bab45

Browse files
RomanManzrmanzRoman Manz
authored
Connect related changes - add optional hooks (#504)
* adding return value to connectReqHandler * adding optional connectRespHandler * calling Hijack in handleHttps if defined * fixing PR comments * Update https.go --------- Co-authored-by: Roman Manz <rmanz@amadeus.com> Co-authored-by: Roman Manz <roman.manz.fsm@e.email>
1 parent 0003d27 commit 44bab45

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

examples/cascadeproxy/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ func main() {
4141
// socks5://localhost:8082
4242
return url.Parse("http://localhost:8082")
4343
}
44-
connectReqHandler := func(req *http.Request) {
44+
connectReqHandler := func(req *http.Request) error {
4545
SetBasicAuth(username, password, req)
46+
return nil
4647
}
4748
middleProxy.ConnectDial = middleProxy.NewConnectDialToProxyWithHandler("http://localhost:8082", connectReqHandler)
4849

https.go

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
143143
return
144144
}
145145
ctx.Logf("Accepting CONNECT to %s", host)
146-
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
146+
if todo.Hijack != nil {
147+
todo.Hijack(r, proxyClient, ctx)
148+
} else {
149+
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
150+
}
147151

148152
targetTCP, targetOK := targetSiteCon.(halfClosable)
149153
proxyClientTCP, clientOK := proxyClient.(halfClosable)
@@ -194,7 +198,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
194198
case ConnectHijack:
195199
todo.Hijack(r, proxyClient, ctx)
196200
case ConnectHTTPMitm:
197-
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
201+
if todo.Hijack != nil {
202+
todo.Hijack(r, proxyClient, ctx)
203+
} else {
204+
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
205+
}
198206
ctx.Logf("Assuming CONNECT is plain HTTP tunneling, mitm proxying it")
199207

200208
var targetSiteCon net.Conn
@@ -265,7 +273,11 @@ func (proxy *ProxyHttpServer) handleHttps(w http.ResponseWriter, r *http.Request
265273
}
266274
}
267275
case ConnectMitm:
268-
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
276+
if todo.Hijack != nil {
277+
todo.Hijack(r, proxyClient, ctx)
278+
} else {
279+
_, _ = proxyClient.Write([]byte("HTTP/1.0 200 OK\r\n\r\n"))
280+
}
269281
ctx.Logf("Assuming CONNECT is TLS, mitm proxying it")
270282
// this goes in a separate goroutine, so that the net/http server won't think we're
271283
// still handling the request even after hijacking the connection. Those HTTP CONNECT
@@ -534,7 +546,15 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxy(httpsProxy string) func(netw
534546

535547
func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
536548
httpsProxy string,
537-
connectReqHandler func(req *http.Request),
549+
connectReqHandler func(req *http.Request) error,
550+
) func(network, addr string) (net.Conn, error) {
551+
return proxy.NewConnectDialToProxyWithMoreHandlers(httpsProxy, connectReqHandler, nil)
552+
}
553+
554+
func (proxy *ProxyHttpServer) NewConnectDialToProxyWithMoreHandlers(
555+
httpsProxy string,
556+
connectReqHandler func(req *http.Request) error,
557+
connectRespHandler func(req *http.Response) error,
538558
) func(network, addr string) (net.Conn, error) {
539559
u, err := url.Parse(httpsProxy)
540560
if err != nil {
@@ -552,7 +572,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
552572
Header: make(http.Header),
553573
}
554574
if connectReqHandler != nil {
555-
connectReqHandler(connectReq)
575+
if err := connectReqHandler(connectReq); err != nil {
576+
return nil, err
577+
}
556578
}
557579
c, err := proxy.dial(&ProxyCtx{Req: &http.Request{}}, network, u.Host)
558580
if err != nil {
@@ -569,7 +591,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
569591
return nil, err
570592
}
571593
defer resp.Body.Close()
572-
if resp.StatusCode != http.StatusOK {
594+
if connectRespHandler != nil {
595+
if err := connectRespHandler(resp); err != nil {
596+
c.Close()
597+
return nil, err
598+
}
599+
} else if resp.StatusCode != http.StatusOK {
573600
resp, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength))
574601
if err != nil {
575602
return nil, err
@@ -603,7 +630,9 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
603630
Header: make(http.Header),
604631
}
605632
if connectReqHandler != nil {
606-
connectReqHandler(connectReq)
633+
if err := connectReqHandler(connectReq); err != nil {
634+
return nil, err
635+
}
607636
}
608637
_ = connectReq.Write(c)
609638
// Read response.
@@ -616,7 +645,12 @@ func (proxy *ProxyHttpServer) NewConnectDialToProxyWithHandler(
616645
return nil, err
617646
}
618647
defer resp.Body.Close()
619-
if resp.StatusCode != http.StatusOK {
648+
if connectRespHandler != nil {
649+
if err := connectRespHandler(resp); err != nil {
650+
c.Close()
651+
return nil, err
652+
}
653+
} else if resp.StatusCode != http.StatusOK {
620654
body, err := io.ReadAll(io.LimitReader(resp.Body, _errorRespMaxLength))
621655
if err != nil {
622656
return nil, err

0 commit comments

Comments
 (0)