Skip to content

Commit d786d45

Browse files
committed
fix: propagate panics in streaming mode
1 parent f5051e5 commit d786d45

File tree

2 files changed

+181
-42
lines changed

2 files changed

+181
-42
lines changed

functionurl_test.go

Lines changed: 166 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/base64"
66
"encoding/json"
7+
"errors"
78
"github.com/aws/aws-lambda-go/events"
89
"github.com/gofiber/fiber/v2"
910
"github.com/its-felix/aws-lambda-go-http-adapter/adapter"
@@ -67,6 +68,15 @@ func newVanillaAdapter() handler.AdapterFunc {
6768
return adapter.NewVanillaAdapter(mux)
6869
}
6970

71+
func newVanillaPanicAdapter() handler.AdapterFunc {
72+
mux := http.NewServeMux()
73+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
74+
panic("panic from test")
75+
})
76+
77+
return adapter.NewVanillaAdapter(mux)
78+
}
79+
7080
func newEchoAdapter() handler.AdapterFunc {
7181
app := echo.New()
7282
app.Any("*", func(c echo.Context) error {
@@ -89,6 +99,15 @@ func newEchoAdapter() handler.AdapterFunc {
8999
return adapter.NewEchoAdapter(app)
90100
}
91101

102+
func newEchoPanicAdapter() handler.AdapterFunc {
103+
app := echo.New()
104+
app.Any("*", func(c echo.Context) error {
105+
panic("panic from test")
106+
})
107+
108+
return adapter.NewEchoAdapter(app)
109+
}
110+
92111
func newFiberAdapter() handler.AdapterFunc {
93112
app := fiber.New()
94113
app.All("*", func(ctx *fiber.Ctx) error {
@@ -104,7 +123,66 @@ func newFiberAdapter() handler.AdapterFunc {
104123
return adapter.NewFiberAdapter(app)
105124
}
106125

107-
func TestFunctionURLGET(t *testing.T) {
126+
func newFiberPanicAdapter() handler.AdapterFunc {
127+
app := fiber.New()
128+
app.All("*", func(ctx *fiber.Ctx) error {
129+
panic("panic from test")
130+
})
131+
132+
return adapter.NewFiberAdapter(app)
133+
}
134+
135+
type extractor[T any] interface {
136+
StatusCode(T) int
137+
Headers(T) map[string]string
138+
IsBase64Encoded(T) bool
139+
Body(T) string
140+
}
141+
142+
type extractorNormal struct{}
143+
144+
func (extractorNormal) StatusCode(response events.LambdaFunctionURLResponse) int {
145+
return response.StatusCode
146+
}
147+
148+
func (extractorNormal) Headers(response events.LambdaFunctionURLResponse) map[string]string {
149+
return response.Headers
150+
}
151+
152+
func (extractorNormal) IsBase64Encoded(response events.LambdaFunctionURLResponse) bool {
153+
return response.IsBase64Encoded
154+
}
155+
156+
func (extractorNormal) Body(response events.LambdaFunctionURLResponse) string {
157+
return response.Body
158+
}
159+
160+
type extractorStreaming struct{}
161+
162+
func (extractorStreaming) StatusCode(response *events.LambdaFunctionURLStreamingResponse) int {
163+
return response.StatusCode
164+
}
165+
166+
func (extractorStreaming) Headers(response *events.LambdaFunctionURLStreamingResponse) map[string]string {
167+
return response.Headers
168+
}
169+
170+
func (extractorStreaming) IsBase64Encoded(*events.LambdaFunctionURLStreamingResponse) bool {
171+
return false
172+
}
173+
174+
func (extractorStreaming) Body(response *events.LambdaFunctionURLStreamingResponse) string {
175+
defer func() {
176+
if rc, ok := response.Body.(io.Closer); ok {
177+
_ = rc.Close()
178+
}
179+
}()
180+
181+
b, _ := io.ReadAll(response.Body)
182+
return string(b)
183+
}
184+
185+
func TestFunctionURLPOST(t *testing.T) {
108186
adapters := map[string]handler.AdapterFunc{
109187
"vanilla": newVanillaAdapter(),
110188
"echo": newEchoAdapter(),
@@ -113,41 +191,93 @@ func TestFunctionURLGET(t *testing.T) {
113191

114192
for name, a := range adapters {
115193
t.Run(name, func(t *testing.T) {
116-
h := handler.NewFunctionURLHandler(a)
117-
118-
req := newFunctionURLRequest()
119-
res, err := h(context.Background(), req)
120-
if err != nil {
121-
t.Error(err)
122-
}
123-
124-
if res.StatusCode != http.StatusOK {
125-
t.Error("expected status to be 200")
126-
}
127-
128-
if res.Headers["Content-Type"] != "application/json" {
129-
t.Error("expected Content-Type to be application/json")
130-
}
131-
132-
if res.IsBase64Encoded {
133-
t.Error("expected body not to be base64 encoded")
134-
}
135-
136-
body := make(map[string]string)
137-
_ = json.Unmarshal([]byte(res.Body), &body)
138-
139-
expectedBody := map[string]string{
140-
"Method": "POST",
141-
"URL": "https://0dhg9709da0dhg9709da0dhg9709da.lambda-url.eu-central-1.on.aws/example?key=value",
142-
"RemoteAddr": "127.0.0.1:http",
143-
"Body": "hello world",
144-
}
145-
146-
if !reflect.DeepEqual(body, expectedBody) {
147-
t.Logf("expected: %v", expectedBody)
148-
t.Logf("actual: %v", body)
149-
t.Error("request/response didnt match")
150-
}
194+
t.Run("normal", func(t *testing.T) {
195+
h := handler.NewFunctionURLHandler(a)
196+
runTestFunctionURLPOST[events.LambdaFunctionURLResponse](t, h, extractorNormal{})
197+
})
198+
199+
t.Run("streaming", func(t *testing.T) {
200+
h := handler.NewFunctionURLStreamingHandler(a)
201+
runTestFunctionURLPOST[*events.LambdaFunctionURLStreamingResponse](t, h, extractorStreaming{})
202+
})
151203
})
152204
}
153205
}
206+
207+
func runTestFunctionURLPOST[T any](t *testing.T, h func(context.Context, events.LambdaFunctionURLRequest) (T, error), ex extractor[T]) {
208+
req := newFunctionURLRequest()
209+
res, err := h(context.Background(), req)
210+
if err != nil {
211+
t.Error(err)
212+
}
213+
214+
if ex.StatusCode(res) != http.StatusOK {
215+
t.Error("expected status to be 200")
216+
}
217+
218+
if ex.Headers(res)["Content-Type"] != "application/json" {
219+
t.Error("expected Content-Type to be application/json")
220+
}
221+
222+
if ex.IsBase64Encoded(res) {
223+
t.Error("expected body not to be base64 encoded")
224+
}
225+
226+
body := make(map[string]string)
227+
_ = json.Unmarshal([]byte(ex.Body(res)), &body)
228+
229+
expectedBody := map[string]string{
230+
"Method": "POST",
231+
"URL": "https://0dhg9709da0dhg9709da0dhg9709da.lambda-url.eu-central-1.on.aws/example?key=value",
232+
"RemoteAddr": "127.0.0.1:http",
233+
"Body": "hello world",
234+
}
235+
236+
if !reflect.DeepEqual(body, expectedBody) {
237+
t.Logf("expected: %v", expectedBody)
238+
t.Logf("actual: %v", body)
239+
t.Error("request/response didnt match")
240+
}
241+
}
242+
243+
func TestFunctionURLWithPanicAndRecover(t *testing.T) {
244+
adapters := map[string]handler.AdapterFunc{
245+
"vanilla": newVanillaPanicAdapter(),
246+
"echo": newEchoPanicAdapter(),
247+
"fiber": newFiberPanicAdapter(),
248+
}
249+
250+
for name, a := range adapters {
251+
t.Run(name, func(t *testing.T) {
252+
t.Run("normal", func(t *testing.T) {
253+
h := handler.NewFunctionURLHandler(a)
254+
h = handler.WrapWithRecover(h, func(ctx context.Context, event events.LambdaFunctionURLRequest, panicValue any) (events.LambdaFunctionURLResponse, error) {
255+
return events.LambdaFunctionURLResponse{}, errors.New(panicValue.(string))
256+
})
257+
258+
runTestFunctionURLPanicAndRecover(t, h)
259+
})
260+
261+
t.Run("streaming", func(t *testing.T) {
262+
h := handler.NewFunctionURLStreamingHandler(a)
263+
h = handler.WrapWithRecover(h, func(ctx context.Context, event events.LambdaFunctionURLRequest, panicValue any) (*events.LambdaFunctionURLStreamingResponse, error) {
264+
return nil, errors.New(panicValue.(string))
265+
})
266+
267+
runTestFunctionURLPanicAndRecover(t, h)
268+
})
269+
})
270+
}
271+
}
272+
273+
func runTestFunctionURLPanicAndRecover[T any](t *testing.T, h func(context.Context, events.LambdaFunctionURLRequest) (T, error)) {
274+
req := newFunctionURLRequest()
275+
_, err := h(context.Background(), req)
276+
if err == nil {
277+
t.Error("expected to receive an error")
278+
}
279+
280+
if err.Error() != "panic from test" {
281+
t.Error("expected to receive error 'panic from test'")
282+
}
283+
}

handler/functionurl.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,25 +199,34 @@ func handleFunctionURLStreaming(ctx context.Context, event events.LambdaFunction
199199

200200
resCh := make(chan *events.LambdaFunctionURLStreamingResponse)
201201
errCh := make(chan error)
202+
panicCh := make(chan any)
202203

203-
go processRequestFunctionURLStreaming(ctx, req, adapter, resCh, errCh)
204+
go processRequestFunctionURLStreaming(ctx, req, adapter, resCh, errCh, panicCh)
204205

205206
select {
206207
case res := <-resCh:
207208
return res, nil
208209
case err = <-errCh:
209210
return nil, err
211+
case panicV := <-panicCh:
212+
panic(panicV)
210213
case <-ctx.Done():
211214
return nil, ctx.Err()
212215
}
213216
}
214217

215-
func processRequestFunctionURLStreaming(ctx context.Context, req *http.Request, adapter AdapterFunc, resCh chan<- *events.LambdaFunctionURLStreamingResponse, errCh chan<- error) {
216-
defer close(resCh)
217-
defer close(errCh)
218-
218+
func processRequestFunctionURLStreaming(ctx context.Context, req *http.Request, adapter AdapterFunc, resCh chan<- *events.LambdaFunctionURLStreamingResponse, errCh chan<- error, panicCh chan<- any) {
219219
ctx, cancel := context.WithCancel(ctx)
220-
defer cancel()
220+
defer func() {
221+
if panicV := recover(); panicV != nil {
222+
panicCh <- panicV
223+
}
224+
225+
close(panicCh)
226+
close(resCh)
227+
close(errCh)
228+
cancel()
229+
}()
221230

222231
w := functionURLStreamingResponseWriter{
223232
headers: make(http.Header),

0 commit comments

Comments
 (0)