diff --git a/events/apigw.go b/events/apigw.go index 05bf02a4..7aa48932 100644 --- a/events/apigw.go +++ b/events/apigw.go @@ -2,6 +2,13 @@ package events +import ( + "bytes" + "encoding/json" + "errors" + "io" +) + // APIGatewayProxyRequest contains data coming from the API Gateway proxy type APIGatewayProxyRequest struct { Resource string `json:"resource"` // The resource path defined in API Gateway @@ -27,6 +34,64 @@ type APIGatewayProxyResponse struct { IsBase64Encoded bool `json:"isBase64Encoded,omitempty"` } +// APIGatewayProxyStreamingResponse configures the response to be returned by API Gateway for the request. +// - integration type must be AWS_PROXY +// - integration uri must be arn::apigateway::lambda:path/2021-11-15/functions//response-streaming-invocations +// - integration response transfer mode must be STREAM +// +// If not using the above streaming integration, use APIGatewayProxyResponse instead +type APIGatewayProxyStreamingResponse struct { + prelude *bytes.Buffer + + StatusCode int + Headers map[string]string + MultiValueHeaders map[string][]string + Body io.Reader + Cookies []string +} + +func (r *APIGatewayProxyStreamingResponse) Read(p []byte) (n int, err error) { + if r.prelude == nil { + b, err := json.Marshal(struct { + StatusCode int `json:"statusCode,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + MultiValueHeaders map[string][]string `json:"multiValueHeaders,omitempty"` + Cookies []string `json:"cookies,omitempty"` + }{ + StatusCode: r.StatusCode, + Headers: r.Headers, + MultiValueHeaders: r.MultiValueHeaders, + Cookies: r.Cookies, + }) + if err != nil { + return 0, err + } + r.prelude = bytes.NewBuffer(append(b, 0, 0, 0, 0, 0, 0, 0, 0)) + } + if r.prelude.Len() > 0 { + return r.prelude.Read(p) + } + if r.Body == nil { + return 0, io.EOF + } + return r.Body.Read(p) +} + +func (r *APIGatewayProxyStreamingResponse) Close() error { + if closer, ok := r.Body.(io.ReadCloser); ok { + return closer.Close() + } + return nil +} + +func (r *APIGatewayProxyStreamingResponse) MarshalJSON() ([]byte, error) { + return nil, errors.New("not json") +} + +func (r *APIGatewayProxyStreamingResponse) ContentType() string { + return "application/vnd.awslambda.http-integration-response" +} + // APIGatewayProxyRequestContext contains the information to identify the AWS account and resources invoking the // Lambda function. It also includes Cognito identity information for the caller. type APIGatewayProxyRequestContext struct { diff --git a/events/apigw_test.go b/events/apigw_test.go index 00611c90..651718ab 100644 --- a/events/apigw_test.go +++ b/events/apigw_test.go @@ -4,11 +4,15 @@ package events import ( "encoding/json" + "errors" "io/ioutil" //nolint: staticcheck + "net/http" + "strings" "testing" "github.com/aws/aws-lambda-go/events/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestApiGatewayRequestMarshaling(t *testing.T) { @@ -83,6 +87,80 @@ func TestApiGatewayResponseMalformedJson(t *testing.T) { test.TestMalformedJson(t, APIGatewayProxyResponse{}) } +func TestAPIGatewayProxyStreamingResponseMarshaling(t *testing.T) { + for _, test := range []struct { + name string + response *APIGatewayProxyStreamingResponse + expectedHead string + expectedBody string + }{ + { + "empty", + &APIGatewayProxyStreamingResponse{}, + `{}`, + "", + }, + { + "just the status code", + &APIGatewayProxyStreamingResponse{ + StatusCode: http.StatusTeapot, + }, + `{"statusCode":418}`, + "", + }, + { + "status and headers and cookies and body", + &APIGatewayProxyStreamingResponse{ + StatusCode: http.StatusTeapot, + Headers: map[string]string{"hello": "world"}, + MultiValueHeaders: map[string][]string{"hi": {"1", "2"}}, + Cookies: []string{"cookies", "are", "yummy"}, + Body: strings.NewReader(`Hello Hello`), + }, + `{"statusCode":418, "headers":{"hello":"world"}, "multiValueHeaders":{"hi":["1","2"]}, "cookies":["cookies","are","yummy"]}`, + `Hello Hello`, + }, + } { + t.Run(test.name, func(t *testing.T) { + response, err := ioutil.ReadAll(test.response) + require.NoError(t, err) + sep := "\x00\x00\x00\x00\x00\x00\x00\x00" + responseParts := strings.Split(string(response), sep) + require.Len(t, responseParts, 2) + head := string(responseParts[0]) + body := string(responseParts[1]) + assert.JSONEq(t, test.expectedHead, head) + assert.Equal(t, test.expectedBody, body) + assert.NoError(t, test.response.Close()) + }) + } +} + +func TestAPIGatewayProxyStreamingResponsePropogatesInnerClose(t *testing.T) { + for _, test := range []struct { + name string + closer *readCloser + err error + }{ + { + "closer no err", + &readCloser{}, + nil, + }, + { + "closer with err", + &readCloser{err: errors.New("yolo")}, + errors.New("yolo"), + }, + } { + t.Run(test.name, func(t *testing.T) { + response := &APIGatewayProxyStreamingResponse{Body: test.closer} + assert.Equal(t, test.err, response.Close()) + assert.True(t, test.closer.closed) + }) + } +} + func TestApiGatewayCustomAuthorizerRequestMarshaling(t *testing.T) { // read json from file diff --git a/events/example_apigw_test.go b/events/example_apigw_test.go new file mode 100644 index 00000000..fb266b02 --- /dev/null +++ b/events/example_apigw_test.go @@ -0,0 +1,20 @@ +package events_test + +import ( + "strings" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambda" +) + +func ExampleAPIGatewayProxyStreamingResponse() { + lambda.Start(func() (*events.APIGatewayProxyStreamingResponse, error) { + return &events.APIGatewayProxyStreamingResponse{ + StatusCode: 200, + Headers: map[string]string{ + "Content-Type": "text/html", + }, + Body: strings.NewReader("Hello World!"), + }, nil + }) +} diff --git a/events/example_lambda_function_urls_test.go b/events/example_lambda_function_urls_test.go new file mode 100644 index 00000000..7643bfc7 --- /dev/null +++ b/events/example_lambda_function_urls_test.go @@ -0,0 +1,20 @@ +package events_test + +import ( + "strings" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambda" +) + +func ExampleLambdaFunctionURLStreamingResponse() { + lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) { + return &events.LambdaFunctionURLStreamingResponse{ + StatusCode: 200, + Headers: map[string]string{ + "Content-Type": "text/html", + }, + Body: strings.NewReader("Hello World!"), + }, nil + }) +} diff --git a/events/lambda_function_urls.go b/events/lambda_function_urls.go index 52a48e83..1d5ac6ff 100644 --- a/events/lambda_function_urls.go +++ b/events/lambda_function_urls.go @@ -71,18 +71,6 @@ type LambdaFunctionURLResponse struct { // LambdaFunctionURLStreamingResponse models the response to a Lambda Function URL when InvokeMode is RESPONSE_STREAM. // If the InvokeMode of the Function URL is BUFFERED (default), use LambdaFunctionURLResponse instead. // -// Example: -// -// lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) { -// return &events.LambdaFunctionURLStreamingResponse{ -// StatusCode: 200, -// Headers: map[string]string{ -// "Content-Type": "text/html", -// }, -// Body: strings.NewReader("Hello World!"), -// }, nil -// }) -// // Note: This response type requires compiling with `-tags lambda.norpc`, or choosing the `provided` or `provided.al2` runtime. type LambdaFunctionURLStreamingResponse struct { prelude *bytes.Buffer