Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions mcp/event.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

// This file is for SSE events.
// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events.

package mcp

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"iter"
"net/http"
"strings"
)

// An Event is a server-sent event.
// See https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#fields.
type Event struct {
Name string // the "event" field
ID string // the "id" field
Data []byte // the "data" field
}

// Empty reports whether the Event is empty.
func (e Event) Empty() bool {
return e.Name == "" && e.ID == "" && len(e.Data) == 0
}

// writeEvent writes the event to w, and flushes.
func writeEvent(w io.Writer, evt Event) (int, error) {
var b bytes.Buffer
if evt.Name != "" {
fmt.Fprintf(&b, "event: %s\n", evt.Name)
}
if evt.ID != "" {
fmt.Fprintf(&b, "id: %s\n", evt.ID)
}
fmt.Fprintf(&b, "data: %s\n\n", string(evt.Data))
n, err := w.Write(b.Bytes())
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return n, err
}

// scanEvents iterates SSE events in the given scanner. The iterated error is
// terminal: if encountered, the stream is corrupt or broken and should no
// longer be used.
//
// TODO(rfindley): consider a different API here that makes failure modes more
// apparent.
func scanEvents(r io.Reader) iter.Seq2[Event, error] {
scanner := bufio.NewScanner(r)
const maxTokenSize = 1 * 1024 * 1024 // 1 MiB max line size
scanner.Buffer(nil, maxTokenSize)

// TODO: investigate proper behavior when events are out of order, or have
// non-standard names.
var (
eventKey = []byte("event")
idKey = []byte("id")
dataKey = []byte("data")
)

return func(yield func(Event, error) bool) {
// iterate event from the wire.
// https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#examples
//
// - `key: value` line records.
// - Consecutive `data: ...` fields are joined with newlines.
// - Unrecognized fields are ignored. Since we only care about 'event', 'id', and
// 'data', these are the only three we consider.
// - Lines starting with ":" are ignored.
// - Records are terminated with two consecutive newlines.
var (
evt Event
dataBuf *bytes.Buffer // if non-nil, preceding field was also data
)
flushData := func() {
if dataBuf != nil {
evt.Data = dataBuf.Bytes()
dataBuf = nil
}
}
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
flushData()
// \n\n is the record delimiter
if !evt.Empty() && !yield(evt, nil) {
return
}
evt = Event{}
continue
}
before, after, found := bytes.Cut(line, []byte{':'})
if !found {
yield(Event{}, fmt.Errorf("malformed line in SSE stream: %q", string(line)))
return
}
if !bytes.Equal(before, dataKey) {
flushData()
}
switch {
case bytes.Equal(before, eventKey):
evt.Name = strings.TrimSpace(string(after))
case bytes.Equal(before, idKey):
evt.ID = strings.TrimSpace(string(after))
case bytes.Equal(before, dataKey):
data := bytes.TrimSpace(after)
if dataBuf != nil {
dataBuf.WriteByte('\n')
dataBuf.Write(data)
} else {
dataBuf = new(bytes.Buffer)
dataBuf.Write(data)
}
}
}
if err := scanner.Err(); err != nil {
if errors.Is(err, bufio.ErrTooLong) {
err = fmt.Errorf("event exceeded max line length of %d", maxTokenSize)
}
if !yield(Event{}, err) {
return
}
}
flushData()
if !evt.Empty() {
yield(evt, nil)
}
}
}
99 changes: 99 additions & 0 deletions mcp/event_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package mcp

import (
"strings"
"testing"
)

func TestScanEvents(t *testing.T) {
tests := []struct {
name string
input string
want []Event
wantErr string
}{
{
name: "simple event",
input: "event: message\nid: 1\ndata: hello\n\n",
want: []Event{
{Name: "message", ID: "1", Data: []byte("hello")},
},
},
{
name: "multiple data lines",
input: "data: line 1\ndata: line 2\n\n",
want: []Event{
{Data: []byte("line 1\nline 2")},
},
},
{
name: "multiple events",
input: "data: first\n\nevent: second\ndata: second\n\n",
want: []Event{
{Data: []byte("first")},
{Name: "second", Data: []byte("second")},
},
},
{
name: "no trailing newline",
input: "data: hello",
want: []Event{
{Data: []byte("hello")},
},
},
{
name: "malformed line",
input: "invalid line\n\n",
wantErr: "malformed line",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := strings.NewReader(tt.input)
var got []Event
var err error
for e, err2 := range scanEvents(r) {
if err2 != nil {
err = err2
break
}
got = append(got, e)
}

if tt.wantErr != "" {
if err == nil {
t.Fatalf("scanEvents() got nil error, want error containing %q", tt.wantErr)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("scanEvents() error = %q, want containing %q", err, tt.wantErr)
}
return
}

if err != nil {
t.Fatalf("scanEvents() returned unexpected error: %v", err)
}

if len(got) != len(tt.want) {
t.Fatalf("scanEvents() got %d events, want %d", len(got), len(tt.want))
}

for i := range got {
if g, w := got[i].Name, tt.want[i].Name; g != w {
t.Errorf("event %d: name = %q, want %q", i, g, w)
}
if g, w := got[i].ID, tt.want[i].ID; g != w {
t.Errorf("event %d: id = %q, want %q", i, g, w)
}
if g, w := string(got[i].Data), string(tt.want[i].Data); g != w {
t.Errorf("event %d: data = %q, want %q", i, g, w)
}
}
})
}
}
Loading
Loading