Skip to content

Commit 5ac2016

Browse files
committed
replace http.Error with return
1 parent 6b4027a commit 5ac2016

File tree

1 file changed

+30
-32
lines changed

1 file changed

+30
-32
lines changed

mcp/streamable.go

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -304,19 +304,25 @@ type idContextKey struct{}
304304

305305
// ServeHTTP handles a single HTTP request for the session.
306306
func (t *StreamableServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request) {
307+
status := 0
308+
message := ""
307309
switch req.Method {
308310
case http.MethodGet:
309-
t.serveGET(w, req)
311+
status, message = t.serveGET(w, req)
310312
case http.MethodPost:
311-
t.servePOST(w, req)
313+
status, message = t.servePOST(w, req)
312314
default:
313315
// Should not be reached, as this is checked in StreamableHTTPHandler.ServeHTTP.
314316
w.Header().Set("Allow", "GET, POST")
315-
http.Error(w, "unsupported method", http.StatusMethodNotAllowed)
317+
status = http.StatusMethodNotAllowed
318+
message = "unsupported method"
319+
}
320+
if status != 0 && status != http.StatusOK {
321+
http.Error(w, message, status)
316322
}
317323
}
318324

319-
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) {
325+
func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Request) (int, string) {
320326
// connID 0 corresponds to the default GET request.
321327
id := StreamID(0)
322328
// By default, we haven't seen a last index. Since indices start at 0, we represent
@@ -328,49 +334,42 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
328334
var ok bool
329335
id, lastIdx, ok = parseEventID(eid)
330336
if !ok {
331-
http.Error(w, fmt.Sprintf("malformed Last-Event-ID %q", eid), http.StatusBadRequest)
332-
return
337+
return http.StatusBadRequest, fmt.Sprintf("malformed Last-Event-ID %q", eid)
333338
}
334339
}
335340

336341
t.mu.Lock()
337342
stream, ok := t.streams[id]
338343
if !ok {
339-
http.Error(w, "unknown stream", http.StatusBadRequest)
340344
t.mu.Unlock()
341-
return
345+
return http.StatusBadRequest, "unknown stream"
342346
}
343347
if stream.signal != nil {
344-
http.Error(w, "stream ID conflicts with ongoing stream", http.StatusBadRequest)
345348
t.mu.Unlock()
346-
return
349+
return http.StatusBadRequest, "stream ID conflicts with ongoing stream"
347350
}
348351
stream.signal = make(chan struct{}, 1)
349352
t.mu.Unlock()
350353

351-
t.streamResponse(stream, w, req, lastIdx)
354+
return t.streamResponse(stream, w, req, lastIdx)
352355
}
353356

354-
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) {
357+
func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.Request) (int, string) {
355358
if len(req.Header.Values("Last-Event-ID")) > 0 {
356-
http.Error(w, "can't send Last-Event-ID for POST request", http.StatusBadRequest)
357-
return
359+
return http.StatusBadRequest, "can't send Last-Event-ID for POST request"
358360
}
359361

360362
// Read incoming messages.
361363
body, err := io.ReadAll(req.Body)
362364
if err != nil {
363-
http.Error(w, "failed to read body", http.StatusBadRequest)
364-
return
365+
return http.StatusBadRequest, "failed to read body"
365366
}
366367
if len(body) == 0 {
367-
http.Error(w, "POST requires a non-empty body", http.StatusBadRequest)
368-
return
368+
return http.StatusBadRequest, "POST requires a non-empty body"
369369
}
370370
incoming, _, err := readBatch(body)
371371
if err != nil {
372-
http.Error(w, fmt.Sprintf("malformed payload: %v", err), http.StatusBadRequest)
373-
return
372+
return http.StatusBadRequest, fmt.Sprintf("malformed payload: %v", err)
374373
}
375374
requests := make(map[jsonrpc.ID]struct{})
376375
for _, msg := range incoming {
@@ -401,11 +400,11 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
401400
// TODO(rfindley): consider optimizing for a single incoming request, by
402401
// responding with application/json when there is only a single message in
403402
// the response.
404-
t.streamResponse(stream, w, req, -1)
403+
return t.streamResponse(stream, w, req, -1)
405404
}
406405

407406
// lastIndex is the index of the last seen event if resuming, else -1.
408-
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) {
407+
func (t *StreamableServerTransport) streamResponse(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int) (int, string) {
409408
defer func() {
410409
t.mu.Lock()
411410
stream.signal = nil
@@ -431,7 +430,7 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
431430
}
432431
if _, err := writeEvent(w, e); err != nil {
433432
// Connection closed or broken.
434-
// TODO: log when we add server-side logging.
433+
// TODO(#170): log when we add server-side logging.
435434
return false
436435
}
437436
writes++
@@ -454,13 +453,12 @@ func (t *StreamableServerTransport) streamResponse(stream *stream, w http.Respon
454453
if errors.Is(err, ErrEventsPurged) {
455454
status = http.StatusInsufficientStorage
456455
}
457-
http.Error(w, err.Error(), status)
458-
return
456+
return status, err.Error()
459457
}
460458
// The iterator yields events beginning just after lastIndex, or it would have
461459
// yielded an error.
462460
if !write(data) {
463-
return
461+
return 0, ""
464462
}
465463
}
466464
}
@@ -475,11 +473,10 @@ stream:
475473

476474
for _, data := range outgoing {
477475
if err := t.opts.EventStore.Append(req.Context(), t.SessionID(), stream.id, data); err != nil {
478-
http.Error(w, err.Error(), http.StatusInternalServerError)
479-
return
476+
return http.StatusInternalServerError, err.Error()
480477
}
481478
if !write(data) {
482-
return
479+
return 0, ""
483480
}
484481
}
485482

@@ -489,22 +486,22 @@ stream:
489486
// If all requests have been handled and replied to, we should terminate this connection.
490487
// "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
491488
// §6.4, https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
492-
// TODO(jba): why not terminate regardless of http method?
489+
// TODO(jba,findleyr): why not terminate regardless of http method?
493490
if req.Method == http.MethodPost && nOutstanding == 0 {
494491
if writes == 0 {
495492
// Spec: If the server accepts the input, the server MUST return HTTP
496493
// status code 202 Accepted with no body.
497494
w.WriteHeader(http.StatusAccepted)
498495
}
499-
return
496+
return 0, ""
500497
}
501498

502499
select {
503500
case <-signal: // there are new outgoing messages
504501
// return to top of loop
505502
case <-t.done: // session is closed
506503
if writes == 0 {
507-
http.Error(w, "session terminated", http.StatusGone)
504+
return http.StatusGone, "session terminated"
508505
}
509506
break stream
510507
case <-req.Context().Done():
@@ -514,6 +511,7 @@ stream:
514511
break stream
515512
}
516513
}
514+
return 0, ""
517515
}
518516

519517
// Event IDs: encode both the logical connection ID and the index, as

0 commit comments

Comments
 (0)