Skip to content

Commit b63b419

Browse files
authored
Merge pull request #735 from mjkim610/fix-upstream-proxy-bug
Fix upstream proxy bug
2 parents f2381b5 + 6b6c4cb commit b63b419

File tree

3 files changed

+299
-0
lines changed

3 files changed

+299
-0
lines changed

proxy.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,23 @@ func NewProxy(options *Options) (*Proxy, error) {
144144
}
145145
fastdialerOptions.BaseResolvers = []string{"127.0.0.1" + options.ListenDNSAddr}
146146
}
147+
148+
if len(options.UpstreamHTTPProxies) > 0 {
149+
proxyDialer, err := newHTTPProxyRoundRobinDialer(options.UpstreamHTTPProxies)
150+
if err != nil {
151+
return nil, err
152+
}
153+
fastdialerOptions.ProxyDialer = &proxyDialer
154+
}
155+
156+
if len(options.UpstreamSOCKS5Proxies) > 0 {
157+
dialer, err := newSOCKS5ProxyRoundRobinDialer(options.UpstreamSOCKS5Proxies)
158+
if err != nil {
159+
return nil, err
160+
}
161+
fastdialerOptions.ProxyDialer = &dialer
162+
}
163+
147164
dialer, err := fastdialer.NewDialer(fastdialerOptions)
148165
if err != nil {
149166
return nil, err

upstream.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
package proxify
2+
3+
import (
4+
"bufio"
5+
"encoding/base64"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"net/url"
10+
11+
rbtransport "github.com/projectdiscovery/roundrobin/transport"
12+
"golang.org/x/net/proxy"
13+
)
14+
15+
type httpProxyDialer struct {
16+
proxyURL *url.URL
17+
forward proxy.Dialer
18+
}
19+
20+
// Dial connects to the address using the HTTP proxy.
21+
func (d *httpProxyDialer) Dial(_, addr string) (net.Conn, error) {
22+
conn, err := d.forward.Dial("tcp", d.proxyURL.Host)
23+
if err != nil {
24+
return nil, err
25+
}
26+
27+
connectReq := &http.Request{
28+
Method: "CONNECT",
29+
URL: &url.URL{Opaque: addr},
30+
Host: addr,
31+
Header: make(http.Header),
32+
}
33+
if d.proxyURL.User != nil {
34+
encodedUserinfo := base64.StdEncoding.EncodeToString([]byte(d.proxyURL.User.String()))
35+
connectReq.Header.Set("Proxy-Authorization", "Basic "+encodedUserinfo)
36+
}
37+
38+
if err := connectReq.Write(conn); err != nil {
39+
_ = conn.Close()
40+
return nil, err
41+
}
42+
43+
br := bufio.NewReader(conn)
44+
resp, err := http.ReadResponse(br, connectReq)
45+
if err != nil {
46+
_ = conn.Close()
47+
return nil, err
48+
}
49+
if resp.StatusCode != http.StatusOK {
50+
_ = conn.Close()
51+
return nil, fmt.Errorf("unexpected response from proxy: %s", resp.Status)
52+
}
53+
54+
return conn, nil
55+
}
56+
57+
type httpProxyRoundRobinDialer struct {
58+
proxyDialers map[string]httpProxyDialer
59+
transport *rbtransport.RoundTransport
60+
}
61+
62+
// Dial connects to the address on the named network via one of the HTTP proxies using round-robin scheduling.
63+
func (d *httpProxyRoundRobinDialer) Dial(network, addr string) (net.Conn, error) {
64+
nextProxyURL := d.transport.Next()
65+
dialer, ok := d.proxyDialers[nextProxyURL]
66+
if !ok {
67+
return nil, fmt.Errorf("no matching proxy dialer found")
68+
}
69+
return dialer.Dial(network, addr)
70+
}
71+
72+
func newHTTPProxyRoundRobinDialer(upstreamProxies []string) (proxy.Dialer, error) {
73+
if len(upstreamProxies) == 0 {
74+
return nil, fmt.Errorf("proxy URLs cannot be empty")
75+
}
76+
77+
proxyURLs := make([]*url.URL, 0, len(upstreamProxies))
78+
dialers := make(map[string]httpProxyDialer)
79+
for _, proxyAddr := range upstreamProxies {
80+
proxyURL, err := url.Parse(proxyAddr)
81+
if err != nil {
82+
return nil, err
83+
}
84+
proxyURLs = append(proxyURLs, proxyURL)
85+
dialer := httpProxyDialer{proxyURL: proxyURL, forward: proxy.Direct}
86+
dialers[proxyURL.String()] = dialer
87+
}
88+
89+
robin, err := rbtransport.NewWithOptions(1, toStringSlice(proxyURLs)...)
90+
if err != nil {
91+
return nil, err
92+
}
93+
94+
return &httpProxyRoundRobinDialer{proxyDialers: dialers, transport: robin}, nil
95+
}
96+
97+
type socks5ProxyRoundRobinDialer struct {
98+
proxyDialers map[string]proxy.Dialer
99+
robin *rbtransport.RoundTransport
100+
}
101+
102+
// Dial connects to the address on the named network via one of the SOCKS5 proxies using round-robin scheduling.
103+
func (d *socks5ProxyRoundRobinDialer) Dial(network, addr string) (net.Conn, error) {
104+
nextProxyURL := d.robin.Next()
105+
dialer, ok := d.proxyDialers[nextProxyURL]
106+
if !ok {
107+
return nil, fmt.Errorf("no matching proxy dialer found")
108+
}
109+
return dialer.Dial(network, addr)
110+
}
111+
112+
func newSOCKS5ProxyRoundRobinDialer(upstreamProxies []string) (proxy.Dialer, error) {
113+
if len(upstreamProxies) == 0 {
114+
return nil, fmt.Errorf("proxy URLs cannot be empty")
115+
}
116+
117+
proxyURLs := make([]*url.URL, 0, len(upstreamProxies))
118+
dialers := make(map[string]proxy.Dialer)
119+
for _, proxyAddr := range upstreamProxies {
120+
proxyURL, err := url.Parse(proxyAddr)
121+
if err != nil {
122+
return nil, err
123+
}
124+
proxyURLs = append(proxyURLs, proxyURL)
125+
var auth *proxy.Auth
126+
if proxyURL.User != nil {
127+
password, _ := proxyURL.User.Password()
128+
auth = &proxy.Auth{
129+
User: proxyURL.User.Username(),
130+
Password: password,
131+
}
132+
}
133+
dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
134+
if err != nil {
135+
return nil, err
136+
}
137+
dialers[proxyAddr] = dialer
138+
}
139+
140+
robin, err := rbtransport.NewWithOptions(1, toStringSlice(proxyURLs)...)
141+
if err != nil {
142+
return nil, err
143+
}
144+
145+
return &socks5ProxyRoundRobinDialer{proxyDialers: dialers, robin: robin}, nil
146+
}
147+
148+
func toStringSlice(urls []*url.URL) []string {
149+
s := make([]string, len(urls))
150+
for i, u := range urls {
151+
s[i] = u.String()
152+
}
153+
return s
154+
}

upstream_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package proxify
2+
3+
import (
4+
"net/url"
5+
"reflect"
6+
"testing"
7+
)
8+
9+
func TestNewHTTPProxyRoundRobinDialer(t *testing.T) {
10+
tests := []struct {
11+
name string
12+
upstreamProxies []string
13+
shouldThrowErr bool
14+
}{
15+
{
16+
name: "empty",
17+
upstreamProxies: []string{},
18+
shouldThrowErr: true,
19+
},
20+
{
21+
name: "one",
22+
upstreamProxies: []string{"http://localhost:7777"},
23+
shouldThrowErr: false,
24+
},
25+
{
26+
name: "multiple",
27+
upstreamProxies: []string{"http://localhost:7777", "http://localhost:9999"},
28+
shouldThrowErr: false,
29+
},
30+
{
31+
name: "invalid",
32+
upstreamProxies: []string{"http://:invalid"},
33+
shouldThrowErr: true,
34+
},
35+
}
36+
for _, test := range tests {
37+
t.Run(test.name, func(t *testing.T) {
38+
actual, actualErr := newHTTPProxyRoundRobinDialer(test.upstreamProxies)
39+
if (actualErr != nil) != test.shouldThrowErr {
40+
t.Errorf("newHTTPProxyRoundRobinDialer() actualErr = %v, shouldThrowErr = %v", actualErr, test.shouldThrowErr)
41+
return
42+
}
43+
if !test.shouldThrowErr && actual == nil {
44+
t.Errorf("newHTTPProxyRoundRobinDialer() actual = %v, expected non-nil", actual)
45+
}
46+
})
47+
}
48+
}
49+
50+
func TestNewSOCKS5ProxyRoundRobinDialer(t *testing.T) {
51+
tests := []struct {
52+
name string
53+
upstreamProxies []string
54+
shouldThrowErr bool
55+
}{
56+
{
57+
name: "empty",
58+
upstreamProxies: []string{},
59+
shouldThrowErr: true,
60+
},
61+
{
62+
name: "one",
63+
upstreamProxies: []string{"socks5://localhost:10070"},
64+
shouldThrowErr: false,
65+
},
66+
{
67+
name: "multiple",
68+
upstreamProxies: []string{"socks5://localhost:10070", "socks5://localhost:10090"},
69+
shouldThrowErr: false,
70+
},
71+
{
72+
name: "invalid",
73+
upstreamProxies: []string{"socks5://:invalid"},
74+
shouldThrowErr: true,
75+
},
76+
{
77+
name: "auth",
78+
upstreamProxies: []string{"socks5://user:pass@localhost:10070"},
79+
shouldThrowErr: false,
80+
},
81+
}
82+
for _, test := range tests {
83+
t.Run(test.name, func(t *testing.T) {
84+
actual, actualErr := newSOCKS5ProxyRoundRobinDialer(test.upstreamProxies)
85+
if (actualErr != nil) != test.shouldThrowErr {
86+
t.Errorf("newSOCKS5ProxyRoundRobinDialer() actualErr = %v, shouldThrowErr = %v", actualErr, test.shouldThrowErr)
87+
return
88+
}
89+
if !test.shouldThrowErr && actual == nil {
90+
t.Errorf("newSOCKS5ProxyRoundRobinDialer() actual = %v, expected non-nil", actual)
91+
}
92+
})
93+
}
94+
}
95+
96+
func TestToStringSlice(t *testing.T) {
97+
tests := []struct {
98+
name string
99+
urls []*url.URL
100+
expected []string
101+
}{
102+
{
103+
name: "single",
104+
urls: []*url.URL{{Scheme: "http", Host: "localhost:8080"}},
105+
expected: []string{"http://localhost:8080"},
106+
},
107+
{
108+
name: "multiple",
109+
urls: []*url.URL{
110+
{Scheme: "http", Host: "localhost:8080"},
111+
{Scheme: "socks5", User: url.UserPassword("user", "pass"), Host: "localhost:8081"},
112+
},
113+
expected: []string{"http://localhost:8080", "socks5://user:pass@localhost:8081"},
114+
},
115+
{
116+
name: "empty",
117+
urls: []*url.URL{},
118+
expected: []string{},
119+
},
120+
}
121+
for _, test := range tests {
122+
t.Run(test.name, func(t *testing.T) {
123+
if actual := toStringSlice(test.urls); !reflect.DeepEqual(actual, test.expected) {
124+
t.Errorf("toStringSlice() actual = %v, expected = %v", actual, test.expected)
125+
}
126+
})
127+
}
128+
}

0 commit comments

Comments
 (0)