Skip to content

Commit c62b25f

Browse files
committed
feat(courierhttp/client): h2c supported
1 parent aeb0b7b commit c62b25f

File tree

4 files changed

+103
-105
lines changed

4 files changed

+103
-105
lines changed

example/cmd/example/main_test.go

Lines changed: 59 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,127 +3,93 @@ package main_test
33
import (
44
"bytes"
55
"context"
6-
"crypto/tls"
76
"fmt"
87
"io"
9-
"net"
108
"net/http"
9+
"net/http/httptest"
10+
"net/http/httputil"
1111
"testing"
12-
"time"
1312

1413
"github.com/go-courier/logr"
1514
"github.com/go-courier/logr/slog"
1615
"github.com/octohelm/courier/example/apis"
1716
"github.com/octohelm/courier/example/client/example"
1817
domainorg "github.com/octohelm/courier/example/pkg/domain/org"
1918
"github.com/octohelm/courier/internal/testingutil"
20-
"github.com/octohelm/courier/pkg/courierhttp/client"
2119
"github.com/octohelm/courier/pkg/courierhttp/handler/httprouter"
2220
testingx "github.com/octohelm/x/testing"
23-
"golang.org/x/net/http2"
21+
"github.com/octohelm/x/testing/bdd"
2422
)
2523

26-
var htLogger = client.HttpTransportFunc(func(req *http.Request, next client.RoundTrip) (*http.Response, error) {
27-
startedAt := time.Now()
28-
29-
ctx, logger := logr.Start(req.Context(), "Request")
30-
defer logger.End()
31-
32-
resp, err := next(req.WithContext(ctx))
33-
34-
defer func() {
35-
cost := time.Since(startedAt)
36-
37-
logger := logger.WithValues(
38-
"cost", fmt.Sprintf("%0.3fms", float64(cost/time.Millisecond)),
39-
"method", req.Method,
40-
"url", req.URL.String(),
41-
"metadata", req.Header,
42-
"http.content-length", req.ContentLength,
43-
)
44-
45-
if err == nil {
46-
logger.WithValues("response.proto", resp.Proto).Info("success")
47-
} else {
48-
logger.Warn(fmt.Errorf("http request failed: %w", err))
49-
}
50-
}()
51-
52-
return resp, err
53-
})
54-
5524
func TestAll(t *testing.T) {
56-
h, err := httprouter.New(apis.R, "example")
57-
testingx.Expect(t, err, testingx.BeNil[error]())
25+
h := bdd.Must(httprouter.New(apis.R, "example"))
5826

59-
srv := testingutil.Serve(t, h)
27+
hh := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
28+
raw, _ := httputil.DumpRequest(request, true)
29+
fmt.Println(string(raw))
6030

61-
c := &example.Client{
62-
Endpoint: srv.URL,
63-
HttpTransports: []client.HttpTransport{htLogger},
64-
}
65-
ctx := c.InjectContext(context.Background())
66-
ctx = logr.WithLogger(ctx, slog.Logger(slog.Default()))
31+
h.ServeHTTP(writer, request)
32+
})
6733

68-
t.Run("Do Some Request", func(t *testing.T) {
69-
org := &example.GetOrg{}
70-
org.OrgName = "test"
34+
for i, srv := range []*httptest.Server{
35+
testingutil.ServeWithH2C(t, hh),
36+
testingutil.Serve(t, hh),
37+
} {
38+
t.Run(fmt.Sprintf("serve http/%d", 2-i), func(t *testing.T) {
39+
c := &example.Client{
40+
Endpoint: srv.URL,
41+
SupportH2C: i == 0,
42+
}
7143

72-
resp, err := example.Do(ctx, org)
73-
testingx.Expect(t, err, testingx.BeNil[error]())
74-
testingx.Expect(t, resp.Name, testingx.Be(org.OrgName))
75-
testingx.Expect(t, resp.Type, testingx.Be(domainorg.TYPE__GOV))
76-
})
44+
ctx := c.InjectContext(context.Background())
45+
ctx = logr.WithLogger(ctx, slog.Logger(slog.Default()))
7746

78-
t.Run("Do Some Request with h2", func(t *testing.T) {
79-
org := &example.GetOrg{}
80-
org.OrgName = "test"
47+
t.Run("Do Some Request", func(t *testing.T) {
48+
org := &example.GetOrg{}
49+
org.OrgName = "test"
8150

82-
resp, err := example.Do(client.ContextWithRoundTripperCreator(ctx, func() http.RoundTripper {
83-
return &http2.Transport{
84-
AllowHTTP: true,
85-
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
86-
return net.Dial(network, addr)
87-
},
88-
}
89-
}), org)
90-
testingx.Expect(t, err, testingx.BeNil[error]())
91-
testingx.Expect(t, resp.Name, testingx.Be(org.OrgName))
92-
})
51+
resp, err := example.Do(ctx, org)
52+
testingx.Expect(t, err, testingx.BeNil[error]())
9353

94-
t.Run("Upload", func(t *testing.T) {
95-
v := &example.UploadBlob{}
96-
v.RequestBody = io.NopCloser(bytes.NewBufferString("1234567"))
54+
testingx.Expect(t, resp.Name, testingx.Be(org.OrgName))
55+
testingx.Expect(t, resp.Type, testingx.Be(domainorg.TYPE__GOV))
56+
})
9757

98-
_, err := example.Do(ctx, v)
99-
testingx.Expect(t, err, testingx.BeNil[error]())
100-
})
58+
t.Run("Upload", func(t *testing.T) {
59+
v := &example.UploadBlob{}
60+
v.RequestBody = io.NopCloser(bytes.NewBufferString("1234567"))
10161

102-
t.Run("UploadStoreBlob", func(t *testing.T) {
103-
v := &example.UploadStoreBlob{}
104-
v.Scope = "a/b/c"
105-
v.RequestBody = io.NopCloser(bytes.NewBufferString("1234567"))
62+
_, err := example.Do(ctx, v)
63+
testingx.Expect(t, err, testingx.BeNil[error]())
64+
})
10665

107-
_, err := example.Do(ctx, v)
108-
testingx.Expect(t, err, testingx.BeNil[error]())
109-
})
66+
t.Run("UploadStoreBlob", func(t *testing.T) {
67+
v := &example.UploadStoreBlob{}
68+
v.Scope = "a/b/c"
69+
v.RequestBody = io.NopCloser(bytes.NewBufferString("1234567"))
11070

111-
t.Run("GetStoreBlob", func(t *testing.T) {
112-
v := &example.GetStoreBlob{}
113-
v.Scope = "a/b/c"
114-
v.Digest = "xxx"
71+
_, err := example.Do(ctx, v)
72+
testingx.Expect(t, err, testingx.BeNil[error]())
73+
})
11574

116-
resp, err := example.Do(ctx, v)
117-
testingx.Expect(t, err, testingx.BeNil[error]())
118-
testingx.Expect(t, *resp, testingx.Be("a/b/c@xxx"))
119-
})
75+
t.Run("GetStoreBlob", func(t *testing.T) {
76+
v := &example.GetStoreBlob{}
77+
v.Scope = "a/b/c"
78+
v.Digest = "xxx"
12079

121-
t.Run("GetFile", func(t *testing.T) {
122-
v := &example.GetFile{}
123-
v.Path = "a/b/c"
80+
resp, err := example.Do(ctx, v)
81+
testingx.Expect(t, err, testingx.BeNil[error]())
82+
testingx.Expect(t, *resp, testingx.Be("a/b/c@xxx"))
83+
})
12484

125-
resp, err := example.Do(ctx, v)
126-
testingx.Expect(t, err, testingx.BeNil[error]())
127-
testingx.Expect(t, *resp, testingx.Be("a/b/c"))
128-
})
85+
t.Run("GetFile", func(t *testing.T) {
86+
v := &example.GetFile{}
87+
v.Path = "a/b/c"
88+
89+
resp, err := example.Do(ctx, v)
90+
testingx.Expect(t, err, testingx.BeNil[error]())
91+
testingx.Expect(t, *resp, testingx.Be("a/b/c"))
92+
})
93+
})
94+
}
12995
}

internal/testingutil/serve.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@ import (
1010
)
1111

1212
func Serve(t testing.TB, handler http.Handler) *httptest.Server {
13-
srv := httptest.NewServer(h2c.NewHandler(handler, &http2.Server{}))
14-
t.Cleanup(func() {
15-
srv.Close()
16-
})
13+
srv := httptest.NewUnstartedServer(handler)
14+
srv.Start()
15+
t.Cleanup(srv.Close)
16+
return srv
17+
}
18+
19+
func ServeWithH2C(t testing.TB, handler http.Handler) *httptest.Server {
20+
srv := httptest.NewUnstartedServer(h2c.NewHandler(handler, &http2.Server{}))
21+
srv.Start()
22+
t.Cleanup(srv.Close)
1723
return srv
1824
}

pkg/courierhttp/client/client.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ func (h *httpTransportFunc) RoundTrip(request *http.Request) (*http.Response, er
5050
}
5151

5252
type Client struct {
53-
Endpoint string `flag:""`
53+
Endpoint string `flag:""`
54+
SupportH2C bool `flag:",omitzero"`
55+
5456
NewError func() error
5557
HttpTransports []HttpTransport
5658
}
@@ -70,14 +72,19 @@ func (c *Client) Do(ctx context.Context, req any, metas ...courier.Metadata) cou
7072

7173
httpClient := HttpClientFromContext(ctx)
7274
if httpClient == nil {
73-
httpClient = GetReasonableClientContext(ctx, c.HttpTransports...)
74-
} else {
75-
if httpClient.Transport == nil {
76-
httpClient.Transport = http.DefaultTransport
77-
}
78-
httpClient.Transport = WithHttpTransports(c.HttpTransports...)(httpClient.Transport)
75+
httpClient = GetReasonableClientContext(ctx)
76+
}
77+
78+
if httpClient.Transport == nil {
79+
httpClient.Transport = reasonableRoundTripper
7980
}
8081

82+
if c.SupportH2C {
83+
httpClient.Transport = UpgradeToSupportH2c(httpClient.Transport)
84+
}
85+
86+
httpClient.Transport = WithHttpTransports(c.HttpTransports...)(httpClient.Transport)
87+
8188
resp, err := httpClient.Do(httpReq)
8289
if err != nil {
8390
if errors.Is(err, context.Canceled) {

pkg/courierhttp/client/http_default.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"net"
77
"net/http"
88
"time"
9+
10+
"golang.org/x/net/http2"
911
)
1012

1113
var reasonableRoundTripper = &http.Transport{
@@ -29,7 +31,6 @@ var reasonableRoundTripper = &http.Transport{
2931
}
3032

3133
var defaultTlsConfig = &tls.Config{}
32-
3334
var defaultHosts = Hosts{}
3435

3536
func AddHostAlias(hostAliases ...HostAlias) {
@@ -89,3 +90,21 @@ func GetShortConnClientContext(ctx context.Context, httpTransports ...HttpTransp
8990

9091
return &http.Client{Transport: WithHttpTransports(httpTransports...)(t)}
9192
}
93+
94+
func UpgradeToSupportH2c(t http.RoundTripper) http.RoundTripper {
95+
if t1, ok := t.(*http.Transport); ok {
96+
if !t1.DisableKeepAlives {
97+
if t2, err := http2.ConfigureTransports(t1); err == nil {
98+
t2.AllowHTTP = true
99+
t2.DialTLSContext = func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
100+
return t1.DialContext(ctx, network, addr)
101+
}
102+
t2.ConnPool = nil
103+
104+
return t2
105+
}
106+
}
107+
}
108+
109+
return t
110+
}

0 commit comments

Comments
 (0)