Skip to content

Commit ab7fe30

Browse files
authored
Merge pull request #2 from lestrrat-go/http
Create HTTP handler
2 parents 7c5778d + 73502ee commit ab7fe30

File tree

16 files changed

+1791
-384
lines changed

16 files changed

+1791
-384
lines changed

.github/workflows/codeql.yml

Lines changed: 0 additions & 76 deletions
This file was deleted.

component/resolver.go

Lines changed: 136 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"net/url"
78
"strings"
89
)
910

1011
type modeKey struct{}
11-
type requestKey struct{}
12-
type responseKey struct{}
12+
type requestInfoKey struct{}
13+
type responseInfoKey struct{}
1314

1415
type Mode int
1516

@@ -18,6 +19,24 @@ const (
1819
ModeResponse
1920
)
2021

22+
// RequestInfo contains the discrete components needed for request signature resolution
23+
type RequestInfo struct {
24+
Headers http.Header
25+
Method string
26+
Scheme string
27+
Authority string
28+
Path string
29+
RawQuery string
30+
TargetURI string
31+
}
32+
33+
// ResponseInfo contains the discrete components needed for response signature resolution
34+
type ResponseInfo struct {
35+
Headers http.Header
36+
StatusCode int
37+
Request *RequestInfo // For response components that need request info
38+
}
39+
2140
// WithMode adds a mode to the context for later retrieval. IF unspecified,
2241
// the default mode is to resolve components for HTTP requests.
2342
func WithMode(ctx context.Context, mode Mode) context.Context {
@@ -32,23 +51,95 @@ func ModeFromContext(ctx context.Context) Mode {
3251
return mode
3352
}
3453

35-
// WithRequest adds an HTTP request to the context for later retrieval.
36-
func WithRequest(ctx context.Context, req *http.Request) context.Context {
37-
return context.WithValue(ctx, requestKey{}, req)
54+
55+
// WithRequestInfo adds request information to the context using discrete values
56+
func WithRequestInfo(ctx context.Context, headers http.Header, method, scheme, authority, path, rawQuery, targetURI string) context.Context {
57+
info := &RequestInfo{
58+
Headers: headers,
59+
Method: method,
60+
Scheme: scheme,
61+
Authority: authority,
62+
Path: path,
63+
RawQuery: rawQuery,
64+
TargetURI: targetURI,
65+
}
66+
return context.WithValue(ctx, requestInfoKey{}, info)
67+
}
68+
69+
func RequestInfoFromContext(ctx context.Context) (*RequestInfo, bool) {
70+
info, ok := ctx.Value(requestInfoKey{}).(*RequestInfo)
71+
return info, ok
72+
}
73+
74+
// WithResponseInfo adds response information to the context using discrete values
75+
func WithResponseInfo(ctx context.Context, headers http.Header, statusCode int, requestInfo *RequestInfo) context.Context {
76+
info := &ResponseInfo{
77+
Headers: headers,
78+
StatusCode: statusCode,
79+
Request: requestInfo,
80+
}
81+
return context.WithValue(ctx, responseInfoKey{}, info)
82+
}
83+
84+
func ResponseInfoFromContext(ctx context.Context) (*ResponseInfo, bool) {
85+
info, ok := ctx.Value(responseInfoKey{}).(*ResponseInfo)
86+
return info, ok
3887
}
3988

40-
func RequestFromContext(ctx context.Context) (*http.Request, bool) {
41-
req, ok := ctx.Value(requestKey{}).(*http.Request)
42-
return req, ok
89+
// Helper function to create RequestInfo from http.Request
90+
func RequestInfoFromHTTP(req *http.Request) *RequestInfo {
91+
if req == nil || req.URL == nil {
92+
return nil
93+
}
94+
95+
return &RequestInfo{
96+
Headers: req.Header,
97+
Method: req.Method,
98+
Scheme: req.URL.Scheme,
99+
Authority: req.URL.Host,
100+
Path: req.URL.Path,
101+
RawQuery: req.URL.RawQuery,
102+
TargetURI: req.URL.String(),
103+
}
43104
}
44105

45-
func WithResponse(ctx context.Context, resp *http.Response) context.Context {
46-
return context.WithValue(ctx, responseKey{}, resp)
106+
// WithRequestInfoFromHTTP is a convenience function that extracts request info from an http.Request
107+
// and adds it to the context.
108+
func WithRequestInfoFromHTTP(ctx context.Context, req *http.Request) context.Context {
109+
reqInfo := RequestInfoFromHTTP(req)
110+
if reqInfo == nil {
111+
return ctx
112+
}
113+
return WithRequestInfo(ctx, reqInfo.Headers, reqInfo.Method, reqInfo.Scheme,
114+
reqInfo.Authority, reqInfo.Path, reqInfo.RawQuery, reqInfo.TargetURI)
47115
}
48116

49-
func ResponseFromContext(ctx context.Context) (*http.Response, bool) {
50-
resp, ok := ctx.Value(responseKey{}).(*http.Response)
51-
return resp, ok
117+
// Helper function to create ResponseInfo from http.Response
118+
func ResponseInfoFromHTTP(resp *http.Response) *ResponseInfo {
119+
if resp == nil {
120+
return nil
121+
}
122+
123+
var requestInfo *RequestInfo
124+
if resp.Request != nil {
125+
requestInfo = RequestInfoFromHTTP(resp.Request)
126+
}
127+
128+
return &ResponseInfo{
129+
Headers: resp.Header,
130+
StatusCode: resp.StatusCode,
131+
Request: requestInfo,
132+
}
133+
}
134+
135+
// WithResponseInfoFromHTTP is a convenience function that extracts response info from an http.Response
136+
// and adds it to the context.
137+
func WithResponseInfoFromHTTP(ctx context.Context, resp *http.Response) context.Context {
138+
respInfo := ResponseInfoFromHTTP(resp)
139+
if respInfo == nil {
140+
return ctx
141+
}
142+
return WithResponseInfo(ctx, respInfo.Headers, respInfo.StatusCode, respInfo.Request)
52143
}
53144

54145
// Resolve resolves the component identifier to its value. Since the resolution
@@ -67,66 +158,49 @@ func Resolve(ctx context.Context, comp Identifier) (string, error) {
67158
}
68159

69160
func resolveRequest(ctx context.Context, comp Identifier) (string, error) {
70-
req, ok := RequestFromContext(ctx)
161+
reqInfo, ok := RequestInfoFromContext(ctx)
71162
if !ok {
72-
return "", fmt.Errorf("no request available in context")
163+
return "", fmt.Errorf("no request information available in context")
73164
}
74165

75166
compName := comp.name
76167
if strings.HasPrefix(compName, "@") {
77-
return resolveRequestDerivedComponent(ctx, comp)
168+
return resolveRequestDerivedComponentFromInfo(ctx, comp, reqInfo)
78169
}
79-
80-
return resolveHeader(ctx, comp, req.Header)
170+
return resolveHeader(ctx, comp, reqInfo.Headers)
81171
}
82172

83-
func resolveRequestDerivedComponent(ctx context.Context, comp Identifier) (string, error) {
84-
req, ok := RequestFromContext(ctx)
85-
if !ok {
86-
return "", fmt.Errorf("no request available in context")
87-
}
88-
173+
func resolveRequestDerivedComponentFromInfo(ctx context.Context, comp Identifier, reqInfo *RequestInfo) (string, error) {
89174
switch comp.name {
90175
case "@method":
91-
return req.Method, nil
176+
return reqInfo.Method, nil
92177
case "@scheme":
93-
if req.URL == nil {
94-
return "", fmt.Errorf("request URL is nil")
95-
}
96-
return req.URL.Scheme, nil
178+
return reqInfo.Scheme, nil
97179
case "@authority":
98-
if req.URL == nil {
99-
return "", fmt.Errorf("request URL is nil")
100-
}
101-
return req.URL.Host, nil
180+
return reqInfo.Authority, nil
102181
case "@path":
103-
if req.URL == nil {
104-
return "", fmt.Errorf("request URL is nil")
105-
}
106-
return req.URL.Path, nil
182+
return reqInfo.Path, nil
107183
case "@query":
108-
if req.URL == nil {
109-
return "", fmt.Errorf("request URL is nil")
110-
}
111-
if req.URL.RawQuery == "" {
184+
if reqInfo.RawQuery == "" {
112185
return "", fmt.Errorf("query component not found")
113186
}
114-
return "?" + req.URL.RawQuery, nil
187+
return "?" + reqInfo.RawQuery, nil
115188
case "@target-uri":
116-
if req.URL == nil {
117-
return "", fmt.Errorf("request URL is nil")
118-
}
119-
return req.URL.String(), nil
189+
return reqInfo.TargetURI, nil
120190
case "@query-param":
121-
if req.URL == nil {
122-
return "", fmt.Errorf("request URL is nil")
123-
}
124191
// Get the "name" parameter
125192
var paramName string
126193
if err := comp.GetParameter("name", &paramName); err != nil {
127194
return "", fmt.Errorf("@query-param requires 'name' parameter: %w", err)
128195
}
129-
values := req.URL.Query()[paramName]
196+
197+
// Parse query string to extract parameter
198+
queryValues, err := url.ParseQuery(reqInfo.RawQuery)
199+
if err != nil {
200+
return "", fmt.Errorf("failed to parse query: %w", err)
201+
}
202+
203+
values := queryValues[paramName]
130204
if len(values) == 0 {
131205
return "", fmt.Errorf("query parameter %q not found", paramName)
132206
}
@@ -136,26 +210,21 @@ func resolveRequestDerivedComponent(ctx context.Context, comp Identifier) (strin
136210
}
137211
}
138212

213+
139214
func resolveResponse(ctx context.Context, comp Identifier) (string, error) {
140-
resp, ok := ResponseFromContext(ctx)
141-
if !ok || resp == nil {
142-
return "", fmt.Errorf("no response available in context")
215+
respInfo, ok := ResponseInfoFromContext(ctx)
216+
if !ok {
217+
return "", fmt.Errorf("no response information available in context")
143218
}
144219

145220
compName := comp.name
146221
if strings.HasPrefix(compName, "@") {
147-
return resolveResponseDerivedComponent(ctx, comp)
222+
return resolveResponseDerivedComponentFromInfo(ctx, comp, respInfo)
148223
}
149-
150-
return resolveHeader(ctx, comp, resp.Header)
224+
return resolveHeader(ctx, comp, respInfo.Headers)
151225
}
152226

153-
func resolveResponseDerivedComponent(ctx context.Context, comp Identifier) (string, error) {
154-
resp, ok := ResponseFromContext(ctx)
155-
if !ok || resp == nil {
156-
return "", fmt.Errorf("no response available in context")
157-
}
158-
227+
func resolveResponseDerivedComponentFromInfo(ctx context.Context, comp Identifier, respInfo *ResponseInfo) (string, error) {
159228
switch comp.name {
160229
case "@method", "@scheme", "@authority", "@path", "@query":
161230
// Make sure that the ;req parameter is set
@@ -167,17 +236,18 @@ func resolveResponseDerivedComponent(ctx context.Context, comp Identifier) (stri
167236
return "", fmt.Errorf("'req' parameter must be true for %q component", comp.name)
168237
}
169238

170-
if _, ok := RequestFromContext(ctx); !ok {
171-
ctx = context.WithValue(ctx, requestKey{}, resp.Request)
239+
if respInfo.Request == nil {
240+
return "", fmt.Errorf("no request information available for %q component", comp.name)
172241
}
173-
return resolveRequestDerivedComponent(ctx, comp)
242+
return resolveRequestDerivedComponentFromInfo(ctx, comp, respInfo.Request)
174243
case "@status":
175-
return fmt.Sprintf("%d", resp.StatusCode), nil
244+
return fmt.Sprintf("%d", respInfo.StatusCode), nil
176245
default:
177246
return "", fmt.Errorf("unknown derived component: %s", comp.name)
178247
}
179248
}
180249

250+
181251
func resolveHeader(_ context.Context, comp Identifier, hdr http.Header) (string, error) {
182252
// Get header values (case-insensitive)
183253
values := hdr.Values(comp.name)

example_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ func ExampleSign() {
4949
inputValue := input.NewValueBuilder().AddDefinition(def).MustBuild()
5050

5151
// Sign the request
52-
err = htmsig.Sign(context.Background(), req, inputValue, privateKey)
52+
ctx := component.WithRequestInfoFromHTTP(context.Background(), req)
53+
err = htmsig.SignRequest(ctx, req.Header, inputValue, privateKey)
5354
if err != nil {
5455
panic(err)
5556
}
@@ -95,7 +96,8 @@ func ExampleVerify() {
9596
}
9697

9798
inputValue := input.NewValueBuilder().AddDefinition(def).MustBuild()
98-
err = htmsig.Sign(context.Background(), req, inputValue, privateKey)
99+
ctx := component.WithRequestInfoFromHTTP(context.Background(), req)
100+
err = htmsig.SignRequest(ctx, req.Header, inputValue, privateKey)
99101
if err != nil {
100102
panic(err)
101103
}
@@ -108,7 +110,8 @@ func ExampleVerify() {
108110
}
109111

110112
// Verify the request signature
111-
err = htmsig.Verify(context.Background(), req, keyResolver)
113+
ctx = component.WithRequestInfoFromHTTP(context.Background(), req)
114+
err = htmsig.VerifyRequest(ctx, req.Header, keyResolver)
112115
if err != nil {
113116
fmt.Printf("Verification failed: %v\n", err)
114117
return

0 commit comments

Comments
 (0)