Skip to content

Commit 0bb9b43

Browse files
authored
Merge pull request #6 from lxzan/dev
Connection Pool Parameter Optimization
2 parents bb20a42 + d209223 commit 0bb9b43

File tree

3 files changed

+85
-32
lines changed

3 files changed

+85
-32
lines changed

client_test.go

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"net/url"
78
"strconv"
89
"sync/atomic"
910
"testing"
@@ -14,8 +15,6 @@ import (
1415
"github.com/stretchr/testify/assert"
1516
)
1617

17-
const testURL = "https://api.github.com"
18-
1918
var _port = int64(10086)
2019

2120
func nextAddr() string {
@@ -159,6 +158,21 @@ func TestRequest_SetQuery(t *testing.T) {
159158
resp := Get("http://%s", addr).SetQuery(nil).Send(nil)
160159
assert.Error(t, resp.Err())
161160
})
161+
162+
t.Run("", func(t *testing.T) {
163+
req := Get("http://%s", addr).SetQuery(url.Values{
164+
"name": []string{"xxx"},
165+
})
166+
assert.Equal(t, req.url, "http://"+addr+"?name=xxx")
167+
})
168+
169+
t.Run("", func(t *testing.T) {
170+
type Req struct {
171+
Name string `form:"name"`
172+
}
173+
req := Get("http://%s", addr).SetQuery(Req{Name: "xxx"})
174+
assert.Equal(t, req.url, "http://"+addr+"?name=xxx")
175+
})
162176
}
163177

164178
func TestRequest_Send(t *testing.T) {
@@ -215,24 +229,40 @@ func TestMiddleware(t *testing.T) {
215229
time.Sleep(100 * time.Millisecond)
216230

217231
t.Run("before", func(t *testing.T) {
218-
opt := WithBefore(func(ctx context.Context, request *http.Request) (context.Context, error) {
232+
before := func(ctx context.Context, request *http.Request) (context.Context, error) {
219233
return ctx, errors.New("status error")
220-
})
221-
cli, _ := NewClient(opt)
222-
resp := cli.Post("http://%s/404", addr).Send(nil)
223-
assert.Error(t, resp.Err())
234+
}
235+
236+
{
237+
cli, _ := NewClient(WithBefore(before))
238+
resp := cli.Post("http://%s/404", addr).Send(nil)
239+
assert.Error(t, resp.Err())
240+
}
241+
242+
{
243+
resp := Post("http://%s/404", addr).SetBefore(before).Send(nil)
244+
assert.Error(t, resp.Err())
245+
}
224246
})
225247

226248
t.Run("after", func(t *testing.T) {
227-
opt := WithAfter(func(ctx context.Context, response *http.Response) (context.Context, error) {
249+
after := func(ctx context.Context, response *http.Response) (context.Context, error) {
228250
if response.StatusCode != http.StatusOK {
229251
return ctx, errors.New("status error")
230252
}
231253
return ctx, nil
232-
})
233-
cli, _ := NewClient(opt)
234-
resp := cli.Post("http://%s/404", addr).Send(nil)
235-
assert.Error(t, resp.Err())
254+
}
255+
256+
{
257+
cli, _ := NewClient(WithAfter(after))
258+
resp := cli.Post("http://%s/404", addr).Send(nil)
259+
assert.Error(t, resp.Err())
260+
}
261+
262+
{
263+
resp := Post("http://%s/404", addr).SetAfter(after).Send(nil)
264+
assert.Error(t, resp.Err())
265+
}
236266
})
237267

238268
t.Run("latency", func(t *testing.T) {

config.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ import (
77
"time"
88
)
99

10-
const defaultTimeout = 30 * time.Second
10+
const (
11+
defaultTimeout = 30 * time.Second
12+
defaultMaxIdleConnsPerHost = 128
13+
defaultMaxConnsPerHost = 128
14+
)
1115

1216
type (
1317
BeforeFunc func(ctx context.Context, request *http.Request) (context.Context, error)
@@ -19,6 +23,10 @@ var (
1923

2024
defaultClient, _ = NewClient(WithHTTPClient(&http.Client{
2125
Timeout: defaultTimeout,
26+
Transport: &http.Transport{
27+
MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
28+
MaxConnsPerHost: defaultMaxConnsPerHost,
29+
},
2230
}))
2331

2432
defaultBeforeFunc BeforeFunc = func(ctx context.Context, request *http.Request) (context.Context, error) {
@@ -83,7 +91,11 @@ func withInitialize() Option {
8391

8492
if c.HTTPClient == nil {
8593
c.HTTPClient = &http.Client{
86-
Timeout: 30 * time.Second,
94+
Timeout: defaultTimeout,
95+
Transport: &http.Transport{
96+
MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
97+
MaxConnsPerHost: defaultMaxConnsPerHost,
98+
},
8799
}
88100
}
89101
}

request.go

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ package hasaki
22

33
import (
44
"context"
5-
"fmt"
6-
"io"
75
"net/http"
86
neturl "net/url"
97

@@ -50,6 +48,20 @@ func Delete(url string, args ...any) *Request {
5048
return defaultClient.Delete(url, args...)
5149
}
5250

51+
// SetBefore 设置请求前中间件
52+
// Setting up pre-request middleware
53+
func (c *Request) SetBefore(f BeforeFunc) *Request {
54+
c.before = f
55+
return c
56+
}
57+
58+
// SetAfter 设置请求后中间件
59+
// Setting up post-request middleware
60+
func (c *Request) SetAfter(f AfterFunc) *Request {
61+
c.after = f
62+
return c
63+
}
64+
5365
// SetEncoder 设置编码器
5466
// Set request body encoder
5567
func (c *Request) SetEncoder(encoder Encoder) *Request {
@@ -87,22 +99,22 @@ func (c *Request) SetQuery(query any) *Request {
8799
return c
88100
}
89101

90-
str := fmt.Sprintf("%s://%s%s", URL.Scheme, URL.Host, URL.Path)
91102
switch v := query.(type) {
92103
case string:
93104
if len(v) > 0 {
94-
str += "?" + v
105+
URL.RawQuery = v
95106
}
107+
case neturl.Values:
108+
URL.RawQuery = v.Encode()
96109
default:
97-
str += "?"
98110
values, err := form.NewEncoder().Encode(query)
99111
if err != nil {
100112
c.err = errors.WithStack(err)
101113
return c
102114
}
103-
str += values.Encode()
115+
URL.RawQuery = values.Encode()
104116
}
105-
c.url = str
117+
c.url = URL.String()
106118
return c
107119
}
108120

@@ -115,13 +127,10 @@ func (c *Request) Send(body any) *Response {
115127
return response
116128
}
117129

118-
reader, ok := body.(io.Reader)
119-
if !ok {
120-
reader, c.err = c.encoder.Encode(body)
121-
if c.err != nil {
122-
response.err = c.err
123-
return response
124-
}
130+
reader, err := c.encoder.Encode(body)
131+
if err != nil {
132+
response.err = err
133+
return response
125134
}
126135

127136
req, err1 := http.NewRequestWithContext(c.ctx, c.method, c.url, reader)
@@ -130,16 +139,18 @@ func (c *Request) Send(body any) *Response {
130139
return response
131140
}
132141

142+
if c.method == http.MethodGet && body == nil {
143+
c.headers.Del("Content-Type")
144+
}
145+
req.Header = c.headers
146+
133147
// 执行请求前中间件
134148
response.ctx, response.err = c.before(c.ctx, req)
135149
if response.err != nil {
136150
return response
137151
}
138152

139-
if c.method == http.MethodGet && body == nil {
140-
c.headers.Del("Content-Type")
141-
}
142-
req.Header = c.headers
153+
// 发起请求
143154
resp, err2 := c.client.Do(req)
144155
if err2 != nil {
145156
response.err = errors.WithStack(err2)

0 commit comments

Comments
 (0)