Skip to content

Commit fe9434a

Browse files
authored
mcp: configure heartbeats to be less chatty (#1294)
**Description** We were emitting heartbeats every 1s and that was very chatty, especially when we're generating traces for each heartbeat. This PR: * Makes the heartbeat configurable and allows even disabling it by setting the `MCP_PROXY_HEARTBEAT_INTERVAL=0`. * Changes the default heartbeat interval to 1 minute. * Eagerly sends an initial ping to tackle the Goose issue. With this, heartbeats can be disabled and Goose works fine as well. **Related Issues/PRs (if applicable)** N/A **Special notes for reviewers (if applicable)** N/A --------- Signed-off-by: Ignasi Barrera <[email protected]>
1 parent 782b0ed commit fe9434a

File tree

4 files changed

+140
-54
lines changed

4 files changed

+140
-54
lines changed

internal/mcpproxy/session.go

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"io"
1616
"log/slog"
1717
"net/http"
18+
"os"
1819
"strings"
1920
"sync"
2021
"time"
@@ -179,8 +180,22 @@ func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter
179180
//
180181
// TODO: no idea exactly why this is necessary. Goose shouldn't block on the first event.
181182

182-
heartbeatTicker := time.NewTicker(1 * time.Second)
183-
defer heartbeatTicker.Stop()
183+
var (
184+
heartbeats <-chan time.Time
185+
heartbeatTicker *time.Ticker
186+
)
187+
if heartbeatInterval > 0 {
188+
heartbeatTicker = time.NewTicker(heartbeatInterval)
189+
defer heartbeatTicker.Stop()
190+
heartbeats = heartbeatTicker.C
191+
} else {
192+
heartbeats = make(chan time.Time) // never ticks
193+
}
194+
195+
// Eagerly send an initial heartbeat event to unblock Goose
196+
heartBeatEvent := &sseEvent{event: "message", messages: []jsonrpc.Message{newHeartBeatPingMessage()}}
197+
heartBeatEvent.writeAndMaybeFlush(w)
198+
184199
for {
185200
select {
186201
case event, ok := <-backendMsgs:
@@ -208,7 +223,12 @@ func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter
208223
s.proxy.recordResponse(ctx, _msg)
209224
}
210225
event.writeAndMaybeFlush(w)
211-
case <-heartbeatTicker.C:
226+
// Reset the heartbeat ticker so that the next heartbeat will be sent after the full interval.
227+
// This avoids sending heartbeats too frequently when there are events.
228+
if heartbeatTicker != nil {
229+
heartbeatTicker.Reset(heartbeatInterval)
230+
}
231+
case <-heartbeats:
212232
heartBeatEvent := &sseEvent{event: "message", messages: []jsonrpc.Message{newHeartBeatPingMessage()}}
213233
heartBeatEvent.writeAndMaybeFlush(w)
214234
case <-ctx.Done():
@@ -218,6 +238,20 @@ func (s *session) streamNotifications(ctx context.Context, w http.ResponseWriter
218238
}
219239
}
220240

241+
// heartbeatInterval is computed at startup to avoid the locks in os.Getenv() to be called on the request path.
242+
var heartbeatInterval = getHeartbeatInterval(1 * time.Minute)
243+
244+
// getHeartbeatInterval returns the heartbeat interval configured via the MCP_HEARTBEAT_INTERVAL environment variable.
245+
// If the environment variable is not set or invalid, it returns the default value of 1 minute.
246+
// This value is intentionally hidden under an environment variable as it is unclear if it is generally useful.
247+
func getHeartbeatInterval(def time.Duration) time.Duration {
248+
hbi, err := time.ParseDuration(os.Getenv("MCP_PROXY_HEARTBEAT_INTERVAL"))
249+
if err != nil {
250+
return def
251+
}
252+
return hbi
253+
}
254+
221255
// sendToAllBackends sends an HTTP request to all backends in this session and returns a channel that streams
222256
// the response events from all backends.
223257
func (s *session) sendToAllBackends(ctx context.Context, httpMethod string, request *jsonrpc.Request, span tracing.MCPSpan) <-chan *sseEvent {

internal/mcpproxy/session_test.go

Lines changed: 99 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"log/slog"
1414
"net/http"
1515
"net/http/httptest"
16+
"strings"
1617
"sync/atomic"
1718
"testing"
1819
"time"
@@ -150,51 +151,83 @@ func TestHandleNotificationsPerBackend_SSE(t *testing.T) {
150151
}
151152

152153
func TestSession_StreamNotifications(t *testing.T) {
153-
// Single backend streaming two events with valid messages.
154-
id1, _ := jsonrpc.MakeID("1")
155-
id2, _ := jsonrpc.MakeID("2")
156-
msg1, _ := jsonrpc.EncodeMessage(&jsonrpc.Request{Method: "a1", ID: id1})
157-
msg2, _ := jsonrpc.EncodeMessage(&jsonrpc.Request{Method: "a2", ID: id2})
158-
body := "event: a1\n" + "data: " + string(msg1) + "\n\n" + "event: a2\n" + "data: " + string(msg2) + "\n\n"
159-
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160-
if r.Method != http.MethodGet {
161-
w.WriteHeader(http.StatusBadRequest)
162-
return
163-
}
164-
if r.Header.Get(internalapi.MCPBackendHeader) != "backend1" {
165-
w.WriteHeader(http.StatusBadRequest)
166-
return
167-
}
168-
w.Header().Set("Content-Type", "text/event-stream")
169-
for _, b := range []byte(body) {
170-
_, _ = w.Write([]byte{b})
171-
if f, ok := w.(http.Flusher); ok {
172-
f.Flush()
154+
tests := []struct {
155+
name string
156+
eventInterval time.Duration
157+
deadline time.Duration
158+
heartbeatInterval time.Duration
159+
wantHeartbeats bool
160+
}{
161+
// the default heartbeat interval is 1 second, but the events will come faster, so
162+
// we don't expect any heartbeats.
163+
{"fast events", 10 * time.Millisecond, 5 * time.Second, 10 * time.Second, false},
164+
// configure a heartbeat interval faster than the event interval, so we expect heartbeats.
165+
{"slow events", 20 * time.Millisecond, 5 * time.Second, 10 * time.Millisecond, true},
166+
// disable heartbeats. Even though events come in slowly, we don't expect heartbeats.
167+
{"no heartbeats", 50 * time.Millisecond, 25 * time.Second, 0, false},
168+
}
169+
170+
for _, tc := range tests {
171+
t.Run(tc.name, func(t *testing.T) {
172+
// Override the default heartbeat interval for testing.
173+
originalHeartbeatInterval := heartbeatInterval
174+
heartbeatInterval = tc.heartbeatInterval
175+
t.Cleanup(func() { heartbeatInterval = originalHeartbeatInterval })
176+
177+
// Single backend streaming two events with valid messages.
178+
id1, _ := jsonrpc.MakeID("1")
179+
id2, _ := jsonrpc.MakeID("2")
180+
msg1, _ := jsonrpc.EncodeMessage(&jsonrpc.Request{Method: "a1", ID: id1})
181+
msg2, _ := jsonrpc.EncodeMessage(&jsonrpc.Request{Method: "a2", ID: id2})
182+
body := "event: a1\n" + "data: " + string(msg1) + "\n\n" + "event: a2\n" + "data: " + string(msg2) + "\n\n"
183+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
184+
if r.Method != http.MethodGet {
185+
w.WriteHeader(http.StatusBadRequest)
186+
return
187+
}
188+
if r.Header.Get(internalapi.MCPBackendHeader) != "backend1" {
189+
w.WriteHeader(http.StatusBadRequest)
190+
return
191+
}
192+
w.Header().Set("Content-Type", "text/event-stream")
193+
for _, b := range []byte(body) {
194+
_, _ = w.Write([]byte{b})
195+
if f, ok := w.(http.Flusher); ok {
196+
f.Flush()
197+
}
198+
time.Sleep(tc.eventInterval)
199+
}
200+
}))
201+
defer srv.Close()
202+
proxy := newTestMCPProxy()
203+
proxy.backendListenerAddr = srv.URL
204+
205+
s := &session{
206+
proxy: proxy,
207+
perBackendSessions: map[filterapi.MCPBackendName]*compositeSessionEntry{
208+
"backend1": {
209+
sessionID: "s1",
210+
},
211+
},
212+
route: "test-route",
173213
}
174-
time.Sleep(10 * time.Millisecond)
175-
}
176-
}))
177-
defer srv.Close()
178-
proxy := newTestMCPProxy()
179-
proxy.backendListenerAddr = srv.URL
214+
rr := httptest.NewRecorder()
215+
ctx, cancel := context.WithTimeout(t.Context(), tc.deadline)
216+
defer cancel()
217+
err2 := s.streamNotifications(ctx, rr)
218+
require.NoError(t, err2)
219+
out := rr.Body.String()
220+
require.Contains(t, out, "event: a1")
221+
require.Contains(t, out, "event: a2")
222+
heartbeatCount := strings.Count(out, `"method":"ping"`)
180223

181-
s := &session{
182-
proxy: proxy,
183-
perBackendSessions: map[filterapi.MCPBackendName]*compositeSessionEntry{
184-
"backend1": {
185-
sessionID: "s1",
186-
},
187-
},
188-
route: "test-route",
224+
if tc.wantHeartbeats {
225+
require.Greater(t, heartbeatCount, 1, "expected some heartbeats after the initial one")
226+
} else {
227+
require.Equal(t, 1, heartbeatCount, "expected only the initial heartbeat")
228+
}
229+
})
189230
}
190-
rr := httptest.NewRecorder()
191-
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
192-
defer cancel()
193-
err2 := s.streamNotifications(ctx, rr)
194-
require.NoError(t, err2)
195-
out := rr.Body.String()
196-
require.Contains(t, out, "event: a1")
197-
require.Contains(t, out, "event: a2")
198231
}
199232

200233
func TestSendRequestPerBackend_ErrorStatus(t *testing.T) {
@@ -231,3 +264,27 @@ func TestSendRequestPerBackend_EOF(t *testing.T) {
231264
}, http.MethodGet, nil)
232265
require.True(t, err2 == nil || errors.Is(err2, io.EOF), "unexpected error: %v", err2)
233266
}
267+
268+
func TestGetHeartbeatInterval(t *testing.T) {
269+
defaultInterval := 1 * time.Minute
270+
271+
tests := []struct {
272+
name string
273+
env string
274+
want time.Duration
275+
}{
276+
{"unset", "", defaultInterval},
277+
{"invalid", "invalid", defaultInterval},
278+
{"zero", "0s", 0},
279+
{"value", "5m", 5 * time.Minute},
280+
}
281+
282+
for _, tt := range tests {
283+
t.Run(tt.name, func(t *testing.T) {
284+
if tt.env != "" {
285+
t.Setenv("MCP_PROXY_HEARTBEAT_INTERVAL", tt.env)
286+
}
287+
require.Equal(t, tt.want, getHeartbeatInterval(defaultInterval))
288+
})
289+
}
290+
}

tests/extproc/mcp/mcp_test.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,6 @@ func TestMCP_forceJSONResponse(t *testing.T) {
7979
}
8080
}
8181

82-
func TestMCP_customWriteTimeout(t *testing.T) {
83-
env := requireNewMCPEnv(t, false, 2*time.Second, defaultMCPPath)
84-
for _, tc := range tests {
85-
t.Run(tc.name+"/custom_write_timeout", func(t *testing.T) {
86-
tc.testFn(t, env)
87-
})
88-
}
89-
}
90-
9182
func TestMCP_differentPath(t *testing.T) {
9283
env := requireNewMCPEnv(t, false, 1200*time.Second, "/mcp/yet/another/path")
9384
t.Run("call", func(t *testing.T) {

tests/internal/testmcp/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"errors"
1111
"fmt"
1212
"log"
13+
"net"
1314
"net/http"
1415
"os"
1516
"time"
@@ -112,6 +113,9 @@ func NewServer(opts *Options) *http.Server {
112113
// Allow long-lived connections.
113114
WriteTimeout: opts.WriteTimeout,
114115
Handler: handler,
116+
ConnState: func(conn net.Conn, state http.ConnState) {
117+
log.Printf("MCP SERVER connection [%s] %s -> %s\n", state, conn.RemoteAddr(), conn.LocalAddr())
118+
},
115119
}
116120
go func() {
117121
log.Printf("starting MCP Streamable-HTTP server on :%d at /mcp", opts.Port)

0 commit comments

Comments
 (0)