Skip to content

Commit 61f893e

Browse files
committed
Improve naive client interface
1 parent 548e259 commit 61f893e

File tree

3 files changed

+57
-49
lines changed

3 files changed

+57
-49
lines changed

naive_client.go

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,38 @@ import (
2020
N "github.com/sagernet/sing/common/network"
2121
)
2222

23-
type NaiveClientConfig struct {
24-
Context context.Context
25-
ServerAddress M.Socksaddr
26-
ServerName string
27-
Username string
28-
Password string
29-
Concurrency int
30-
ExtraHeaders map[string]string
31-
32-
TrustedRootCertificates string // PEM format
33-
CertificatePublicKeySHA256 [][]byte // SPKI SHA256 hashes
34-
35-
Dialer N.Dialer
36-
}
23+
var _ N.Dialer = (*NaiveClient)(nil)
3724

3825
type NaiveClient struct {
39-
ctx context.Context
40-
dialer N.Dialer
41-
serverAddress M.Socksaddr
42-
serverName string
43-
serverURL string
44-
authorization string
45-
extraHeaders map[string]string
46-
trustedRootCertificates string
47-
certificatePublicKeySHA256 [][]byte
48-
concurrency int
49-
counter atomic.Uint64
50-
engine Engine
51-
streamEngine StreamEngine
52-
activeConnections sync.WaitGroup
53-
proxyWaitGroup sync.WaitGroup
54-
proxyCancel context.CancelFunc
26+
ctx context.Context
27+
dialer N.Dialer
28+
serverAddress M.Socksaddr
29+
serverName string
30+
serverURL string
31+
authorization string
32+
extraHeaders map[string]string
33+
trustedRootCertificates string
34+
trustedCertificatePublicKeySHA256 [][]byte
35+
concurrency int
36+
counter atomic.Uint64
37+
engine Engine
38+
streamEngine StreamEngine
39+
activeConnections sync.WaitGroup
40+
proxyWaitGroup sync.WaitGroup
41+
proxyCancel context.CancelFunc
42+
}
43+
44+
type NaiveClientConfig struct {
45+
Context context.Context
46+
ServerAddress M.Socksaddr
47+
ServerName string
48+
Username string
49+
Password string
50+
InsecureConcurrency int
51+
ExtraHeaders map[string]string
52+
TrustedRootCertificates string // PEM format
53+
TrustedCertificatePublicKeySHA256 [][]byte // SPKI SHA256 hashes
54+
Dialer N.Dialer
5555
}
5656

5757
func NewNaiveClient(config NaiveClientConfig) (*NaiveClient, error) {
@@ -79,7 +79,7 @@ func NewNaiveClient(config NaiveClientConfig) (*NaiveClient, error) {
7979
[]byte(config.Username+":"+config.Password))
8080
}
8181

82-
concurrency := config.Concurrency
82+
concurrency := config.InsecureConcurrency
8383
if concurrency < 1 {
8484
concurrency = 1
8585
}
@@ -95,24 +95,24 @@ func NewNaiveClient(config NaiveClientConfig) (*NaiveClient, error) {
9595
}
9696

9797
return &NaiveClient{
98-
ctx: ctx,
99-
dialer: dialer,
100-
serverAddress: config.ServerAddress,
101-
serverName: serverName,
102-
serverURL: serverURL.String(),
103-
authorization: authorization,
104-
extraHeaders: config.ExtraHeaders,
105-
trustedRootCertificates: config.TrustedRootCertificates,
106-
certificatePublicKeySHA256: config.CertificatePublicKeySHA256,
107-
concurrency: concurrency,
98+
ctx: ctx,
99+
dialer: dialer,
100+
serverAddress: config.ServerAddress,
101+
serverName: serverName,
102+
serverURL: serverURL.String(),
103+
authorization: authorization,
104+
extraHeaders: config.ExtraHeaders,
105+
trustedRootCertificates: config.TrustedRootCertificates,
106+
trustedCertificatePublicKeySHA256: config.TrustedCertificatePublicKeySHA256,
107+
concurrency: concurrency,
108108
}, nil
109109
}
110110

111111
func (c *NaiveClient) Start() error {
112112
engine := NewEngine()
113113

114-
if len(c.certificatePublicKeySHA256) > 0 {
115-
if !engine.SetCertVerifierWithPublicKeySHA256(c.certificatePublicKeySHA256) {
114+
if len(c.trustedCertificatePublicKeySHA256) > 0 {
115+
if !engine.SetCertVerifierWithPublicKeySHA256(c.trustedCertificatePublicKeySHA256) {
116116
return E.New("failed to set certificate public key SHA256 verifier")
117117
}
118118
} else if c.trustedRootCertificates != "" {
@@ -220,7 +220,10 @@ func (c *NaiveClient) DialEarly(destination M.Socksaddr) (NaiveConn, error) {
220220
}, nil
221221
}
222222

223-
func (c *NaiveClient) DialContext(ctx context.Context, destination M.Socksaddr) (NaiveConn, error) {
223+
func (c *NaiveClient) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
224+
if N.NetworkName(network) != N.NetworkTCP {
225+
return nil, os.ErrInvalid
226+
}
224227
conn, err := c.DialEarly(destination)
225228
if err != nil {
226229
return nil, err
@@ -233,6 +236,10 @@ func (c *NaiveClient) DialContext(ctx context.Context, destination M.Socksaddr)
233236
return conn, nil
234237
}
235238

239+
func (c *NaiveClient) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
240+
return nil, os.ErrInvalid
241+
}
242+
236243
func (c *NaiveClient) Close() error {
237244
if c.proxyCancel != nil {
238245
c.proxyCancel()

test/integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ func TestNaiveRapidOpenClose(t *testing.T) {
224224
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
225225
defer cancel()
226226

227-
conn, err := client.DialEarly(M.ParseSocksaddrHostPort("127.0.0.1", 17004))
227+
conn, err := client.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddrHostPort("127.0.0.1", 17004))
228228
if err != nil {
229229
t.Logf("iteration %d: dial failed (acceptable): %v", i, err)
230230
return

test/main_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
cronet "github.com/sagernet/cronet-go"
3131
"github.com/sagernet/sing/common/bufio"
3232
M "github.com/sagernet/sing/common/metadata"
33+
N "github.com/sagernet/sing/common/network"
3334

3435
"github.com/stretchr/testify/require"
3536
"go.uber.org/goleak"
@@ -80,7 +81,7 @@ func (e *testEnv) newNaiveClient(t *testing.T, config cronet.NaiveClientConfig)
8081
if config.Password == "" {
8182
config.Password = "test"
8283
}
83-
if config.TrustedRootCertificates == "" && len(config.CertificatePublicKeySHA256) == 0 {
84+
if config.TrustedRootCertificates == "" && len(config.TrustedCertificatePublicKeySHA256) == 0 {
8485
config.TrustedRootCertificates = string(e.caPEM)
8586
}
8687
client, err := cronet.NewNaiveClient(config)
@@ -168,7 +169,7 @@ func TestNaiveConcurrency(t *testing.T) {
168169
startIperf3ServerOnPort(t, uint16(iperf3Port+i))
169170
}
170171
env := setupTestEnv(t)
171-
client := env.newNaiveClient(t, cronet.NaiveClientConfig{Concurrency: 3})
172+
client := env.newNaiveClient(t, cronet.NaiveClientConfig{InsecureConcurrency: 3})
172173

173174
var waitGroup sync.WaitGroup
174175
for i := 0; i < concurrencyCount; i++ {
@@ -216,7 +217,7 @@ func TestNaivePublicKeySHA256(t *testing.T) {
216217
pinHash := sha256.Sum256(spkiBytes)
217218

218219
client := env.newNaiveClient(t, cronet.NaiveClientConfig{
219-
CertificatePublicKeySHA256: [][]byte{pinHash[:]},
220+
TrustedCertificatePublicKeySHA256: [][]byte{pinHash[:]},
220221
})
221222
startEchoServer(t, 15001)
222223

@@ -248,7 +249,7 @@ func TestNaiveCloseWhileReading(t *testing.T) {
248249
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
249250
defer cancel()
250251

251-
conn, err := client.DialEarly(M.ParseSocksaddrHostPort("127.0.0.1", 16000))
252+
conn, err := client.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddrHostPort("127.0.0.1", 16000))
252253
if err != nil {
253254
t.Logf("iteration %d: dial failed: %v", i, err)
254255
return

0 commit comments

Comments
 (0)