Skip to content

Commit 0b4e711

Browse files
committed
feat: improve functionurl handling
1 parent 241e80d commit 0b4e711

File tree

1 file changed

+30
-22
lines changed

1 file changed

+30
-22
lines changed

handler/functionurl.go

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net/http"
1212
"strconv"
1313
"strings"
14+
"sync/atomic"
1415
"unicode/utf8"
1516
)
1617

@@ -111,7 +112,11 @@ func handleFunctionURL(ctx context.Context, event events.LambdaFunctionURLReques
111112
return def, err
112113
}
113114

114-
b := w.body.Bytes()
115+
b, err := io.ReadAll(&w.body)
116+
if err != nil {
117+
var def events.LambdaFunctionURLResponse
118+
return def, err
119+
}
115120

116121
if !w.contentTypeSet {
117122
w.res.Headers["Content-Type"] = http.DetectContentType(b)
@@ -139,10 +144,10 @@ func NewFunctionURLHandler(adapter AdapterFunc) func(context.Context, events.Lam
139144

140145
// region streaming
141146
type functionURLStreamingResponseWriter struct {
142-
headers http.Header
143-
body io.WriteCloser
144-
res *events.LambdaFunctionURLStreamingResponse
145-
resCh chan<- *events.LambdaFunctionURLStreamingResponse
147+
headers http.Header
148+
headersWritten int32
149+
body io.WriteCloser
150+
resCh chan<- events.LambdaFunctionURLStreamingResponse
146151
}
147152

148153
func (w *functionURLStreamingResponseWriter) Header() http.Header {
@@ -155,31 +160,33 @@ func (w *functionURLStreamingResponseWriter) Write(p []byte) (int, error) {
155160
}
156161

157162
func (w *functionURLStreamingResponseWriter) WriteHeader(statusCode int) {
158-
if w.res == nil {
163+
if atomic.CompareAndSwapInt32(&w.headersWritten, 0, 1) {
159164
pr, pw := io.Pipe()
160165
w.body = pw
161-
w.res = &events.LambdaFunctionURLStreamingResponse{
162-
StatusCode: statusCode,
163-
Headers: make(map[string]string),
164-
Body: pr,
165-
Cookies: make([]string, 0),
166-
}
166+
167+
headers := make(map[string]string)
168+
cookies := make([]string, 0)
167169

168170
for k, values := range w.headers {
169171
if strings.EqualFold("set-cookie", k) {
170-
w.res.Cookies = values
172+
cookies = values
171173
} else {
172174
if len(values) == 0 {
173-
w.res.Headers[k] = ""
175+
headers[k] = ""
174176
} else if len(values) == 1 {
175-
w.res.Headers[k] = values[0]
177+
headers[k] = values[0]
176178
} else {
177-
w.res.Headers[k] = strings.Join(values, ",")
179+
headers[k] = strings.Join(values, ",")
178180
}
179181
}
180182
}
181183

182-
w.resCh <- w.res
184+
w.resCh <- events.LambdaFunctionURLStreamingResponse{
185+
StatusCode: statusCode,
186+
Headers: headers,
187+
Body: pr,
188+
Cookies: cookies,
189+
}
183190
}
184191
}
185192

@@ -197,15 +204,15 @@ func handleFunctionURLStreaming(ctx context.Context, event events.LambdaFunction
197204
return nil, err
198205
}
199206

200-
resCh := make(chan *events.LambdaFunctionURLStreamingResponse)
207+
resCh := make(chan events.LambdaFunctionURLStreamingResponse)
201208
errCh := make(chan error)
202209
panicCh := make(chan any)
203210

204211
go processRequestFunctionURLStreaming(ctx, req, adapter, resCh, errCh, panicCh)
205212

206213
select {
207214
case res := <-resCh:
208-
return res, nil
215+
return &res, nil
209216
case err = <-errCh:
210217
return nil, err
211218
case panicV := <-panicCh:
@@ -215,7 +222,7 @@ func handleFunctionURLStreaming(ctx context.Context, event events.LambdaFunction
215222
}
216223
}
217224

218-
func processRequestFunctionURLStreaming(ctx context.Context, req *http.Request, adapter AdapterFunc, resCh chan<- *events.LambdaFunctionURLStreamingResponse, errCh chan<- error, panicCh chan<- any) {
225+
func processRequestFunctionURLStreaming(ctx context.Context, req *http.Request, adapter AdapterFunc, resCh chan<- events.LambdaFunctionURLStreamingResponse, errCh chan<- error, panicCh chan<- any) {
219226
ctx, cancel := context.WithCancel(ctx)
220227
defer func() {
221228
if panicV := recover(); panicV != nil {
@@ -229,8 +236,9 @@ func processRequestFunctionURLStreaming(ctx context.Context, req *http.Request,
229236
}()
230237

231238
w := functionURLStreamingResponseWriter{
232-
headers: make(http.Header),
233-
resCh: resCh,
239+
headers: make(http.Header),
240+
headersWritten: 0,
241+
resCh: resCh,
234242
}
235243

236244
defer w.Close()

0 commit comments

Comments
 (0)