@@ -13,6 +13,7 @@ import (
1313 "log/slog"
1414 "net/http"
1515 "net/http/httptest"
16+ "strings"
1617 "sync/atomic"
1718 "testing"
1819 "time"
@@ -150,51 +151,83 @@ func TestHandleNotificationsPerBackend_SSE(t *testing.T) {
150151}
151152
152153func TestSession_StreamNotifications (t * testing.T ) {
153- // Single backend streaming two events with valid messages.
154- id1 , _ := jsonrpc .MakeID ("1" )
155- id2 , _ := jsonrpc .MakeID ("2" )
156- msg1 , _ := jsonrpc .EncodeMessage (& jsonrpc.Request {Method : "a1" , ID : id1 })
157- msg2 , _ := jsonrpc .EncodeMessage (& jsonrpc.Request {Method : "a2" , ID : id2 })
158- body := "event: a1\n " + "data: " + string (msg1 ) + "\n \n " + "event: a2\n " + "data: " + string (msg2 ) + "\n \n "
159- srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
160- if r .Method != http .MethodGet {
161- w .WriteHeader (http .StatusBadRequest )
162- return
163- }
164- if r .Header .Get (internalapi .MCPBackendHeader ) != "backend1" {
165- w .WriteHeader (http .StatusBadRequest )
166- return
167- }
168- w .Header ().Set ("Content-Type" , "text/event-stream" )
169- for _ , b := range []byte (body ) {
170- _ , _ = w .Write ([]byte {b })
171- if f , ok := w .(http.Flusher ); ok {
172- f .Flush ()
154+ tests := []struct {
155+ name string
156+ eventInterval time.Duration
157+ deadline time.Duration
158+ heartbeatInterval time.Duration
159+ wantHeartbeats bool
160+ }{
161+ // the default heartbeat interval is 1 second, but the events will come faster, so
162+ // we don't expect any heartbeats.
163+ {"fast events" , 10 * time .Millisecond , 5 * time .Second , 10 * time .Second , false },
164+ // configure a heartbeat interval faster than the event interval, so we expect heartbeats.
165+ {"slow events" , 20 * time .Millisecond , 5 * time .Second , 10 * time .Millisecond , true },
166+ // disable heartbeats. Even though events come in slowly, we don't expect heartbeats.
167+ {"no heartbeats" , 50 * time .Millisecond , 25 * time .Second , 0 , false },
168+ }
169+
170+ for _ , tc := range tests {
171+ t .Run (tc .name , func (t * testing.T ) {
172+ // Override the default heartbeat interval for testing.
173+ originalHeartbeatInterval := heartbeatInterval
174+ heartbeatInterval = tc .heartbeatInterval
175+ t .Cleanup (func () { heartbeatInterval = originalHeartbeatInterval })
176+
177+ // Single backend streaming two events with valid messages.
178+ id1 , _ := jsonrpc .MakeID ("1" )
179+ id2 , _ := jsonrpc .MakeID ("2" )
180+ msg1 , _ := jsonrpc .EncodeMessage (& jsonrpc.Request {Method : "a1" , ID : id1 })
181+ msg2 , _ := jsonrpc .EncodeMessage (& jsonrpc.Request {Method : "a2" , ID : id2 })
182+ body := "event: a1\n " + "data: " + string (msg1 ) + "\n \n " + "event: a2\n " + "data: " + string (msg2 ) + "\n \n "
183+ srv := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
184+ if r .Method != http .MethodGet {
185+ w .WriteHeader (http .StatusBadRequest )
186+ return
187+ }
188+ if r .Header .Get (internalapi .MCPBackendHeader ) != "backend1" {
189+ w .WriteHeader (http .StatusBadRequest )
190+ return
191+ }
192+ w .Header ().Set ("Content-Type" , "text/event-stream" )
193+ for _ , b := range []byte (body ) {
194+ _ , _ = w .Write ([]byte {b })
195+ if f , ok := w .(http.Flusher ); ok {
196+ f .Flush ()
197+ }
198+ time .Sleep (tc .eventInterval )
199+ }
200+ }))
201+ defer srv .Close ()
202+ proxy := newTestMCPProxy ()
203+ proxy .backendListenerAddr = srv .URL
204+
205+ s := & session {
206+ proxy : proxy ,
207+ perBackendSessions : map [filterapi.MCPBackendName ]* compositeSessionEntry {
208+ "backend1" : {
209+ sessionID : "s1" ,
210+ },
211+ },
212+ route : "test-route" ,
173213 }
174- time .Sleep (10 * time .Millisecond )
175- }
176- }))
177- defer srv .Close ()
178- proxy := newTestMCPProxy ()
179- proxy .backendListenerAddr = srv .URL
214+ rr := httptest .NewRecorder ()
215+ ctx , cancel := context .WithTimeout (t .Context (), tc .deadline )
216+ defer cancel ()
217+ err2 := s .streamNotifications (ctx , rr )
218+ require .NoError (t , err2 )
219+ out := rr .Body .String ()
220+ require .Contains (t , out , "event: a1" )
221+ require .Contains (t , out , "event: a2" )
222+ heartbeatCount := strings .Count (out , `"method":"ping"` )
180223
181- s := & session {
182- proxy : proxy ,
183- perBackendSessions : map [filterapi.MCPBackendName ]* compositeSessionEntry {
184- "backend1" : {
185- sessionID : "s1" ,
186- },
187- },
188- route : "test-route" ,
224+ if tc .wantHeartbeats {
225+ require .Greater (t , heartbeatCount , 1 , "expected some heartbeats after the initial one" )
226+ } else {
227+ require .Equal (t , 1 , heartbeatCount , "expected only the initial heartbeat" )
228+ }
229+ })
189230 }
190- rr := httptest .NewRecorder ()
191- ctx , cancel := context .WithTimeout (t .Context (), 5 * time .Second )
192- defer cancel ()
193- err2 := s .streamNotifications (ctx , rr )
194- require .NoError (t , err2 )
195- out := rr .Body .String ()
196- require .Contains (t , out , "event: a1" )
197- require .Contains (t , out , "event: a2" )
198231}
199232
200233func TestSendRequestPerBackend_ErrorStatus (t * testing.T ) {
@@ -231,3 +264,27 @@ func TestSendRequestPerBackend_EOF(t *testing.T) {
231264 }, http .MethodGet , nil )
232265 require .True (t , err2 == nil || errors .Is (err2 , io .EOF ), "unexpected error: %v" , err2 )
233266}
267+
268+ func TestGetHeartbeatInterval (t * testing.T ) {
269+ defaultInterval := 1 * time .Minute
270+
271+ tests := []struct {
272+ name string
273+ env string
274+ want time.Duration
275+ }{
276+ {"unset" , "" , defaultInterval },
277+ {"invalid" , "invalid" , defaultInterval },
278+ {"zero" , "0s" , 0 },
279+ {"value" , "5m" , 5 * time .Minute },
280+ }
281+
282+ for _ , tt := range tests {
283+ t .Run (tt .name , func (t * testing.T ) {
284+ if tt .env != "" {
285+ t .Setenv ("MCP_PROXY_HEARTBEAT_INTERVAL" , tt .env )
286+ }
287+ require .Equal (t , tt .want , getHeartbeatInterval (defaultInterval ))
288+ })
289+ }
290+ }
0 commit comments