Skip to content

Commit 2dc9581

Browse files
committed
feat: true async streaming handling for lambda function urls
1 parent 847f289 commit 2dc9581

File tree

9 files changed

+328
-172
lines changed

9 files changed

+328
-172
lines changed

README.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ import (
218218
func main() {
219219
adapter := [...] // see above
220220
h := [...] // see above
221-
h = WrapWithRecover(h, func(ctx context.Context, event events.APIGatewayV2HTTPRequest, panicValue any) (events.APIGatewayV2HTTPResponse, error) {
221+
h = handler.WrapWithRecover(h, func(ctx context.Context, event events.APIGatewayV2HTTPRequest, panicValue any) (events.APIGatewayV2HTTPResponse, error) {
222222
return events.APIGatewayV2HTTPResponse{
223223
StatusCode: 500,
224224
Headers: make(map[string]string),
@@ -240,4 +240,18 @@ Have a look at the existing event handlers:
240240
Have a look at the existing adapters:
241241
- [net/http](./adapter/vanilla/vanilla.go)
242242
- [Echo](./adapter/echo/echo.go)
243-
- [Fiber](./adapter/fiber/fiber.go)
243+
- [Fiber](./adapter/fiber/fiber.go)
244+
245+
## Build Tags
246+
You can opt-in to enable partial build by using the build-tag `lambdahttpadapter.partial`.
247+
248+
Once this build-tag is present, the following build-tags are available:
249+
- `lambdahttpadapter.vanilla` (enables the vanilla adapter)
250+
- `lambdahttpadapter.echo` (enables the echo adapter)
251+
- `lambdahttpadapter.fiber` (enables the fiber adapter)
252+
- `lambdahttpadapter.apigwv1` (enables API Gateway V1 handler)
253+
- `lambdahttpadapter.apigwv2` (enables API Gateway V2 handler)
254+
- `lambdahttpadapter.functionurl` (enables Lambda Function URL handler)
255+
256+
Also note that Lambda Function URL in Streaming-Mode requires the following build-tag to be set:
257+
- `lambda.norpc`

adapter/echo/echo.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.echo)
2+
13
package echo
24

35
import (

adapter/fiber/fiber.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.fiber)
2+
13
package fiber
24

35
import (
@@ -63,6 +65,8 @@ func (a adapter) adapterFunc(ctx context.Context, r *http.Request, w http.Respon
6365

6466
var fctx fasthttp.RequestCtx
6567
fctx.Init(httpReq, remoteAddr, nil)
68+
defer fasthttp.ReleaseResponse(&fctx.Response)
69+
6670
fctx.SetUserValue(sourceEventUserValueKey, handler.GetSourceEvent(ctx))
6771

6872
a.app.Handler()(&fctx)
@@ -76,9 +80,10 @@ func (a adapter) adapterFunc(ctx context.Context, r *http.Request, w http.Respon
7680
})
7781

7882
w.WriteHeader(fctx.Response.StatusCode())
79-
_, _ = w.Write(fctx.Response.Body())
83+
// release handled in defer
84+
_, err = io.Copy(w, fctx.Response.BodyStream())
8085

81-
return nil
86+
return err
8287
}
8388

8489
func NewAdapter(delegate *fiber.App) handler.AdapterFunc {

adapter/vanilla/vanilla.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.vanilla)
2+
13
package vanilla
24

35
import (

handler/api.go

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,16 @@ import (
77

88
type httpContextKey string
99

10-
var sourceEventContextKey httpContextKey = "github.com/its-felix/aws-lambda-go-adapter/httpadapter::sourceEventContextKey"
10+
var sourceEventContextKey httpContextKey = "github.com/its-felix/aws-lambda-go-http-adapter/api::sourceEventContextKey"
1111

12-
type RequestConverterFunc[In any] func(ctx context.Context, event In) (*http.Request, error)
13-
type ResponseInitializerFunc[W http.ResponseWriter] func(ctx context.Context) W
14-
type ResponseFinalizerFunc[W http.ResponseWriter, Out any] func(ctx context.Context, w W) (Out, error)
1512
type AdapterFunc func(ctx context.Context, r *http.Request, w http.ResponseWriter) error
16-
13+
type HandlerFunc[In any, Out any] func(ctx context.Context, event In, adapter AdapterFunc) (Out, error)
1714
type RecoverFunc[In any, Out any] func(ctx context.Context, event In, panicValue any) (Out, error)
1815

19-
func NewHandler[In any, W http.ResponseWriter, Out any](reqConverter RequestConverterFunc[In], resInitializer ResponseInitializerFunc[W], resFinalizer ResponseFinalizerFunc[W, Out], adapter AdapterFunc) func(context.Context, In) (Out, error) {
16+
func NewHandler[In any, Out any](handlerFunc HandlerFunc[In, Out], adapter AdapterFunc) func(context.Context, In) (Out, error) {
2017
return func(ctx context.Context, event In) (Out, error) {
21-
ctx = WithSourceEvent(ctx, event)
22-
23-
r, err := reqConverter(ctx, event)
24-
if err != nil {
25-
var def Out
26-
return def, err
27-
}
28-
29-
w := resInitializer(ctx)
30-
if err = adapter(ctx, r, w); err != nil {
31-
var def Out
32-
return def, err
33-
}
34-
35-
return resFinalizer(ctx, w)
18+
ctx = context.WithValue(ctx, sourceEventContextKey, event)
19+
return handlerFunc(ctx, event, adapter)
3620
}
3721
}
3822

@@ -55,10 +39,6 @@ func WrapWithRecover[In any, Out any](handler func(context.Context, In) (Out, er
5539
}
5640
}
5741

58-
func WithSourceEvent(ctx context.Context, event any) context.Context {
59-
return context.WithValue(ctx, sourceEventContextKey, event)
60-
}
61-
6242
func GetSourceEvent(ctx context.Context) any {
6343
return ctx.Value(sourceEventContextKey)
6444
}

handler/apigwv1.go

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.apigwv1)
2+
13
package handler
24

35
import (
6+
"bytes"
47
"context"
58
"encoding/base64"
69
"github.com/aws/aws-lambda-go/events"
710
"net/http"
811
"net/url"
12+
"strconv"
913
"unicode/utf8"
1014
)
1115

12-
func apiGwV1RequestConverter(ctx context.Context, event events.APIGatewayProxyRequest) (*http.Request, error) {
16+
func convertApiGwV1Request(ctx context.Context, event events.APIGatewayProxyRequest) (*http.Request, error) {
1317
q := make(url.Values)
1418

1519
if len(event.MultiValueQueryStringParameters) > 0 {
@@ -54,41 +58,84 @@ func apiGwV1RequestConverter(ctx context.Context, event events.APIGatewayProxyRe
5458
return req, nil
5559
}
5660

57-
func apiGwV1ResponseInitializer(ctx context.Context) *ResponseWriterProxy {
58-
return NewResponseWriterProxy()
61+
type apiGwV1ResponseWriter struct {
62+
headersWritten bool
63+
contentTypeSet bool
64+
contentLengthSet bool
65+
headers http.Header
66+
body bytes.Buffer
67+
res events.APIGatewayProxyResponse
5968
}
6069

61-
func apiGwV1ResponseFinalizer(ctx context.Context, w *ResponseWriterProxy) (events.APIGatewayProxyResponse, error) {
62-
out := events.APIGatewayProxyResponse{
63-
StatusCode: w.Status,
64-
Headers: make(map[string]string),
65-
}
70+
func (w *apiGwV1ResponseWriter) Header() http.Header {
71+
return w.headers
72+
}
6673

67-
for k, values := range w.Headers {
68-
if len(values) == 0 {
69-
out.Headers[k] = ""
70-
} else if len(values) == 1 {
71-
out.Headers[k] = values[0]
72-
} else {
73-
if out.MultiValueHeaders == nil {
74-
out.MultiValueHeaders = make(map[string][]string)
75-
}
74+
func (w *apiGwV1ResponseWriter) Write(p []byte) (int, error) {
75+
w.WriteHeader(http.StatusOK)
76+
return w.body.Write(p)
77+
}
78+
79+
func (w *apiGwV1ResponseWriter) WriteHeader(statusCode int) {
80+
if !w.headersWritten {
81+
w.headersWritten = true
82+
w.res.StatusCode = statusCode
7683

77-
out.MultiValueHeaders[k] = values
84+
for k, values := range w.headers {
85+
if len(values) == 0 {
86+
w.res.Headers[k] = ""
87+
} else if len(values) == 1 {
88+
w.res.Headers[k] = values[0]
89+
} else {
90+
if w.res.MultiValueHeaders == nil {
91+
w.res.MultiValueHeaders = make(map[string][]string)
92+
}
93+
94+
w.res.MultiValueHeaders[k] = values
95+
}
7896
}
7997
}
98+
}
99+
100+
func handleApiGwV1(ctx context.Context, event events.APIGatewayProxyRequest, adapter AdapterFunc) (events.APIGatewayProxyResponse, error) {
101+
req, err := convertApiGwV1Request(ctx, event)
102+
if err != nil {
103+
var def events.APIGatewayProxyResponse
104+
return def, err
105+
}
106+
107+
w := apiGwV1ResponseWriter{
108+
headers: make(http.Header),
109+
res: events.APIGatewayProxyResponse{
110+
Headers: make(map[string]string),
111+
},
112+
}
113+
114+
if err = adapter(ctx, req, &w); err != nil {
115+
var def events.APIGatewayProxyResponse
116+
return def, err
117+
}
118+
119+
b := w.body.Bytes()
120+
121+
if !w.contentTypeSet {
122+
w.res.Headers["Content-Type"] = http.DetectContentType(b)
123+
}
124+
125+
if !w.contentLengthSet {
126+
w.res.Headers["Content-Length"] = strconv.Itoa(len(b))
127+
}
80128

81-
b := w.Body.Bytes()
82129
if utf8.Valid(b) {
83-
out.Body = string(b)
130+
w.res.Body = string(b)
84131
} else {
85-
out.IsBase64Encoded = true
86-
out.Body = base64.StdEncoding.EncodeToString(b)
132+
w.res.IsBase64Encoded = true
133+
w.res.Body = base64.StdEncoding.EncodeToString(b)
87134
}
88135

89-
return out, nil
136+
return w.res, nil
90137
}
91138

92139
func NewAPIGatewayV1Handler(adapter AdapterFunc) func(context.Context, events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
93-
return NewHandler(apiGwV1RequestConverter, apiGwV1ResponseInitializer, apiGwV1ResponseFinalizer, adapter)
140+
return NewHandler(handleApiGwV1, adapter)
94141
}

handler/apigwv2.go

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.apigwv2)
2+
13
package handler
24

35
import (
6+
"bytes"
47
"context"
58
"encoding/base64"
69
"github.com/aws/aws-lambda-go/events"
710
"net/http"
11+
"strconv"
812
"strings"
913
"unicode/utf8"
1014
)
1115

12-
func apiGwV2RequestConverter(ctx context.Context, event events.APIGatewayV2HTTPRequest) (*http.Request, error) {
16+
func convertApiGwV2Request(ctx context.Context, event events.APIGatewayV2HTTPRequest) (*http.Request, error) {
1317
url := buildFullRequestURL(event.RequestContext.DomainName, event.RawPath, event.RequestContext.HTTP.Path, buildQuery(event.RawQueryString, event.QueryStringParameters))
1418
req, err := http.NewRequestWithContext(ctx, event.RequestContext.HTTP.Method, url, getBody(event.Body, event.IsBase64Encoded))
1519
if err != nil {
@@ -40,46 +44,89 @@ func apiGwV2RequestConverter(ctx context.Context, event events.APIGatewayV2HTTPR
4044
return req, nil
4145
}
4246

43-
func apiGwV2ResponseInitializer(ctx context.Context) *ResponseWriterProxy {
44-
return NewResponseWriterProxy()
47+
type apiGwV2ResponseWriter struct {
48+
headersWritten bool
49+
contentTypeSet bool
50+
contentLengthSet bool
51+
headers http.Header
52+
body bytes.Buffer
53+
res events.APIGatewayV2HTTPResponse
4554
}
4655

47-
func apiGwV2ResponseFinalizer(ctx context.Context, w *ResponseWriterProxy) (events.APIGatewayV2HTTPResponse, error) {
48-
out := events.APIGatewayV2HTTPResponse{
49-
StatusCode: w.Status,
50-
Headers: make(map[string]string),
51-
Cookies: make([]string, 0),
52-
}
56+
func (w *apiGwV2ResponseWriter) Header() http.Header {
57+
return w.headers
58+
}
59+
60+
func (w *apiGwV2ResponseWriter) Write(p []byte) (int, error) {
61+
w.WriteHeader(http.StatusOK)
62+
return w.body.Write(p)
63+
}
64+
65+
func (w *apiGwV2ResponseWriter) WriteHeader(statusCode int) {
66+
if !w.headersWritten {
67+
w.headersWritten = true
68+
w.res.StatusCode = statusCode
5369

54-
for k, values := range w.Headers {
55-
if strings.EqualFold("set-cookie", k) {
56-
out.Cookies = values
57-
} else {
58-
if len(values) == 0 {
59-
out.Headers[k] = ""
60-
} else if len(values) == 1 {
61-
out.Headers[k] = values[0]
70+
for k, values := range w.headers {
71+
if strings.EqualFold("set-cookie", k) {
72+
w.res.Cookies = values
6273
} else {
63-
if out.MultiValueHeaders == nil {
64-
out.MultiValueHeaders = make(map[string][]string)
65-
}
74+
if len(values) == 0 {
75+
w.res.Headers[k] = ""
76+
} else if len(values) == 1 {
77+
w.res.Headers[k] = values[0]
78+
} else {
79+
if w.res.MultiValueHeaders == nil {
80+
w.res.MultiValueHeaders = make(map[string][]string)
81+
}
6682

67-
out.MultiValueHeaders[k] = values
83+
w.res.MultiValueHeaders[k] = values
84+
}
6885
}
6986
}
7087
}
88+
}
89+
90+
func handleApiGwV2(ctx context.Context, event events.APIGatewayV2HTTPRequest, adapter AdapterFunc) (events.APIGatewayV2HTTPResponse, error) {
91+
req, err := convertApiGwV2Request(ctx, event)
92+
if err != nil {
93+
var def events.APIGatewayV2HTTPResponse
94+
return def, err
95+
}
96+
97+
w := apiGwV2ResponseWriter{
98+
headers: make(http.Header),
99+
res: events.APIGatewayV2HTTPResponse{
100+
Headers: make(map[string]string),
101+
Cookies: make([]string, 0),
102+
},
103+
}
104+
105+
if err = adapter(ctx, req, &w); err != nil {
106+
var def events.APIGatewayV2HTTPResponse
107+
return def, err
108+
}
109+
110+
b := w.body.Bytes()
111+
112+
if !w.contentTypeSet {
113+
w.res.Headers["Content-Type"] = http.DetectContentType(b)
114+
}
115+
116+
if !w.contentLengthSet {
117+
w.res.Headers["Content-Length"] = strconv.Itoa(len(b))
118+
}
71119

72-
b := w.Body.Bytes()
73120
if utf8.Valid(b) {
74-
out.Body = string(b)
121+
w.res.Body = string(b)
75122
} else {
76-
out.IsBase64Encoded = true
77-
out.Body = base64.StdEncoding.EncodeToString(b)
123+
w.res.IsBase64Encoded = true
124+
w.res.Body = base64.StdEncoding.EncodeToString(b)
78125
}
79126

80-
return out, nil
127+
return w.res, nil
81128
}
82129

83130
func NewAPIGatewayV2Handler(adapter AdapterFunc) func(context.Context, events.APIGatewayV2HTTPRequest) (events.APIGatewayV2HTTPResponse, error) {
84-
return NewHandler(apiGwV2RequestConverter, apiGwV2ResponseInitializer, apiGwV2ResponseFinalizer, adapter)
131+
return NewHandler(handleApiGwV2, adapter)
85132
}

0 commit comments

Comments
 (0)