Skip to content

Commit 859588d

Browse files
authored
fix: call onConnectionLost when SSE connection closes, not only on NO_ERROR (#709)
1 parent b2fb8ba commit 859588d

File tree

3 files changed

+23
-148
lines changed

3 files changed

+23
-148
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ Key examples include:
650650

651651
### Transports
652652

653-
MCP-Go supports stdio, SSE and streamable-HTTP transport layers. For SSE transport, you can use `SetConnectionLostHandler()` to detect and handle HTTP/2 idle timeout disconnections (NO_ERROR) for implementing reconnection logic.
653+
MCP-Go supports stdio, SSE and streamable-HTTP transport layers. For SSE transport, you can use `SetConnectionLostHandler()` to detect and handle disconnections for implementing reconnection logic.
654654

655655
### Session Management
656656

client/transport/sse.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -245,22 +245,14 @@ func (c *SSE) readSSE(reader io.ReadCloser) {
245245
}
246246
c.handleSSEEvent(event, data)
247247
}
248-
break
249248
}
250-
// Checking whether the connection was terminated due to NO_ERROR in HTTP2 based on RFC9113
251-
// Only handle NO_ERROR specially if onConnectionLost handler is set to maintain backward compatibility
252-
if strings.Contains(err.Error(), "NO_ERROR") {
253-
c.connectionLostMu.RLock()
254-
handler := c.onConnectionLost
255-
c.connectionLostMu.RUnlock()
256-
257-
if handler != nil {
258-
// This is not actually an error - HTTP2 idle timeout disconnection
259-
handler(err)
260-
return
261-
}
262-
}
263-
if !c.closed.Load() {
249+
c.connectionLostMu.RLock()
250+
handler := c.onConnectionLost
251+
c.connectionLostMu.RUnlock()
252+
if handler != nil {
253+
// Notify that the connection will be closed due to an error
254+
handler(err)
255+
} else if err == io.EOF && !c.closed.Load() {
264256
c.logger.Errorf("SSE stream error: %v", err)
265257
}
266258
return

client/transport/sse_test.go

Lines changed: 15 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -539,14 +539,14 @@ func TestSSE(t *testing.T) {
539539
}
540540
})
541541

542-
t.Run("NO_ERROR_WithoutConnectionLostHandler", func(t *testing.T) {
543-
// Test that NO_ERROR without connection lost handler maintains backward compatibility
544-
// When no connection lost handler is set, NO_ERROR should be treated as a regular error
542+
t.Run("WithoutConnectionLostHandler", func(t *testing.T) {
543+
// Test that ERROR without connection lost handler maintains backward compatibility
544+
// When no connection lost handler is set, ERROR should be treated as a regular error
545545

546-
// Create a mock Reader that simulates NO_ERROR
546+
// Create a mock Reader that simulates ERROR
547547
mockReader := &mockReaderWithError{
548548
data: []byte("event: endpoint\ndata: /message\n\n"),
549-
err: errors.New("connection closed: NO_ERROR"),
549+
err: errors.New("context deadline exceeded (Client.Timeout or context cancellation while reading body)"),
550550
}
551551

552552
// Create SSE transport
@@ -571,23 +571,22 @@ func TestSSE(t *testing.T) {
571571
time.Sleep(100 * time.Millisecond)
572572

573573
// The test passes if readSSE completes without panicking or hanging
574-
// In backward compatibility mode, NO_ERROR should be treated as a regular error
575-
t.Log("Backward compatibility test passed: NO_ERROR handled as regular error when no handler is set")
574+
// In backward compatibility mode, ERROR should be treated as a regular error
575+
t.Log("Backward compatibility test passed: ERROR handled as regular error when no handler is set")
576576
})
577577

578-
t.Run("NO_ERROR_ConnectionLost", func(t *testing.T) {
579-
// Test that NO_ERROR in HTTP/2 connection loss is properly handled
580-
// This test verifies that when a connection is lost in a way that produces
581-
// an error message containing "NO_ERROR", the connection lost handler is called
578+
t.Run("WithConnectionLost", func(t *testing.T) {
579+
// Test error handling on connection loss
580+
// Verify the connection loss handler is triggered by an error message
582581

583582
var connectionLostCalled bool
584583
var connectionLostError error
585584
var mu sync.Mutex
586585

587-
// Create a mock Reader that simulates connection loss with NO_ERROR
586+
// Create a mock Reader that simulates connection loss with ERROR
588587
mockReader := &mockReaderWithError{
589588
data: []byte("event: endpoint\ndata: /message\n\n"),
590-
err: errors.New("http2: stream closed with error code NO_ERROR"),
589+
err: errors.New("context deadline exceeded (Client.Timeout or context cancellation while reading body)"),
591590
}
592591

593592
// Create SSE transport
@@ -607,7 +606,7 @@ func TestSSE(t *testing.T) {
607606
connectionLostError = err
608607
})
609608

610-
// Directly test the readSSE method with our mock reader that simulates NO_ERROR
609+
// Directly test the readSSE method with our mock reader that simulates ERROR
611610
go trans.readSSE(mockReader)
612611

613612
// Wait for connection lost handler to be called
@@ -618,138 +617,22 @@ func TestSSE(t *testing.T) {
618617
for {
619618
select {
620619
case <-timeout:
621-
t.Fatal("Connection lost handler was not called within timeout for NO_ERROR connection loss")
620+
t.Fatal("Connection lost handler was not called within timeout for connection loss")
622621
case <-ticker.C:
623622
mu.Lock()
624623
called := connectionLostCalled
625624
err := connectionLostError
626625
mu.Unlock()
627-
628626
if called {
629627
if err == nil {
630628
t.Fatal("Expected connection lost error, got nil")
631629
}
632-
633-
// Verify that the error contains "NO_ERROR" string
634-
if !strings.Contains(err.Error(), "NO_ERROR") {
635-
t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err)
636-
}
637-
638-
t.Logf("Connection lost handler called with NO_ERROR: %v", err)
630+
t.Logf("Connection lost handler called with error: %v", err)
639631
return
640632
}
641633
}
642634
}
643635
})
644-
645-
t.Run("NO_ERROR_Handling", func(t *testing.T) {
646-
// Test specific NO_ERROR string handling in readSSE method
647-
// This tests the code path at line 209 where NO_ERROR is checked
648-
649-
// Create a mock Reader that simulates an error containing "NO_ERROR"
650-
mockReader := &mockReaderWithError{
651-
data: []byte("event: endpoint\ndata: /message\n\n"),
652-
err: errors.New("connection closed: NO_ERROR"),
653-
}
654-
655-
// Create SSE transport
656-
url, closeF := startMockSSEEchoServer()
657-
defer closeF()
658-
659-
trans, err := NewSSE(url)
660-
if err != nil {
661-
t.Fatal(err)
662-
}
663-
664-
var connectionLostCalled bool
665-
var connectionLostError error
666-
var mu sync.Mutex
667-
668-
// Set connection lost handler to verify it's called for NO_ERROR
669-
trans.SetConnectionLostHandler(func(err error) {
670-
mu.Lock()
671-
defer mu.Unlock()
672-
connectionLostCalled = true
673-
connectionLostError = err
674-
})
675-
676-
// Directly test the readSSE method with our mock reader
677-
go trans.readSSE(mockReader)
678-
679-
// Wait for connection lost handler to be called
680-
timeout := time.After(1 * time.Second)
681-
ticker := time.NewTicker(10 * time.Millisecond)
682-
defer ticker.Stop()
683-
684-
for {
685-
select {
686-
case <-timeout:
687-
t.Fatal("Connection lost handler was not called within timeout for NO_ERROR")
688-
case <-ticker.C:
689-
mu.Lock()
690-
called := connectionLostCalled
691-
err := connectionLostError
692-
mu.Unlock()
693-
694-
if called {
695-
if err == nil {
696-
t.Fatal("Expected connection lost error with NO_ERROR, got nil")
697-
}
698-
699-
// Verify that the error contains "NO_ERROR" string
700-
if !strings.Contains(err.Error(), "NO_ERROR") {
701-
t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err)
702-
}
703-
704-
t.Logf("Successfully handled NO_ERROR: %v", err)
705-
return
706-
}
707-
}
708-
}
709-
})
710-
711-
t.Run("RegularError_DoesNotTriggerConnectionLost", func(t *testing.T) {
712-
// Test that regular errors (not containing NO_ERROR) do not trigger connection lost handler
713-
714-
// Create a mock Reader that simulates a regular error
715-
mockReader := &mockReaderWithError{
716-
data: []byte("event: endpoint\ndata: /message\n\n"),
717-
err: errors.New("regular connection error"),
718-
}
719-
720-
// Create SSE transport
721-
url, closeF := startMockSSEEchoServer()
722-
defer closeF()
723-
724-
trans, err := NewSSE(url)
725-
if err != nil {
726-
t.Fatal(err)
727-
}
728-
729-
var connectionLostCalled bool
730-
var mu sync.Mutex
731-
732-
// Set connection lost handler - this should NOT be called for regular errors
733-
trans.SetConnectionLostHandler(func(err error) {
734-
mu.Lock()
735-
defer mu.Unlock()
736-
connectionLostCalled = true
737-
})
738-
739-
// Directly test the readSSE method with our mock reader
740-
go trans.readSSE(mockReader)
741-
742-
// Wait and verify connection lost handler is NOT called
743-
time.Sleep(200 * time.Millisecond)
744-
745-
mu.Lock()
746-
called := connectionLostCalled
747-
mu.Unlock()
748-
749-
if called {
750-
t.Error("Connection lost handler should not be called for regular errors")
751-
}
752-
})
753636
}
754637

755638
func TestSSEErrors(t *testing.T) {

0 commit comments

Comments
 (0)