Skip to content

Commit 788f6e1

Browse files
manx98lyingbug
authored andcommitted
fix: MCP Client connection state not marked as closed after connection loss in SSE transport type
1 parent 9c554c4 commit 788f6e1

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

internal/mcp/client.go

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package mcp
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"net/http"
9+
"strings"
810
"time"
911

1012
"github.com/Tencent/WeKnora/internal/logger"
@@ -124,10 +126,32 @@ func NewMCPClient(config *ClientConfig) (MCPClient, error) {
124126
return nil, ErrUnsupportedTransport
125127
}
126128

127-
return &mcpGoClient{
129+
instance := &mcpGoClient{
128130
service: config.Service,
129131
client: mcpClient,
130-
}, nil
132+
}
133+
mcpClient.OnConnectionLost(instance.onConnectionLost)
134+
return instance, nil
135+
}
136+
137+
// onConnectionLost callback when the connection is lost
138+
func (c *mcpGoClient) onConnectionLost(err error) {
139+
_ = c.Disconnect()
140+
logger.Warnf(context.Background(), "MCP server connection has been lost, URL:%s, error:%v", *c.service.URL, err)
141+
}
142+
143+
// checkErrorAndDisconnectIfNeeded Check for errors and call Disconnect when reconnection is required
144+
func (c *mcpGoClient) checkErrorAndDisconnectIfNeeded(err error) {
145+
var transportErr *transport.Error
146+
// In SSE transport type, connection loss does not always actively trigger onConnectionLost (a go-mcp issue).
147+
// Once the connection is lost, the session becomes invalid.
148+
// Without reconnecting, it will continuously cause "Invalid session ID" errors.
149+
if c.service.TransportType == types.MCPTransportSSE &&
150+
errors.As(err, &transportErr) &&
151+
transportErr.Err != nil &&
152+
strings.Contains(transportErr.Err.Error(), "Invalid session ID") {
153+
_ = c.Disconnect()
154+
}
131155
}
132156

133157
// Connect establishes connection to the MCP service
@@ -140,7 +164,6 @@ func (c *mcpGoClient) Connect(ctx context.Context) error {
140164
if err := c.client.Start(ctx); err != nil {
141165
return fmt.Errorf("failed to start client: %w", err)
142166
}
143-
144167
c.connected = true
145168
if c.service.TransportType == types.MCPTransportStdio {
146169
logger.GetLogger(ctx).Infof("MCP stdio client connected: %s %v",
@@ -161,7 +184,6 @@ func (c *mcpGoClient) Disconnect() error {
161184
if c.client != nil {
162185
c.client.Close()
163186
}
164-
165187
c.connected = false
166188
c.initialized = false
167189
return nil
@@ -187,6 +209,7 @@ func (c *mcpGoClient) Initialize(ctx context.Context) (*InitializeResult, error)
187209

188210
result, err := c.client.Initialize(ctx, req)
189211
if err != nil {
212+
c.checkErrorAndDisconnectIfNeeded(err)
190213
return nil, fmt.Errorf("failed to initialize: %w", err)
191214
}
192215

@@ -210,6 +233,7 @@ func (c *mcpGoClient) ListTools(ctx context.Context) ([]*types.MCPTool, error) {
210233
req := mcp.ListToolsRequest{}
211234
result, err := c.client.ListTools(ctx, req)
212235
if err != nil {
236+
c.checkErrorAndDisconnectIfNeeded(err)
213237
return nil, fmt.Errorf("failed to list tools: %w", err)
214238
}
215239

@@ -236,6 +260,7 @@ func (c *mcpGoClient) ListResources(ctx context.Context) ([]*types.MCPResource,
236260
req := mcp.ListResourcesRequest{}
237261
result, err := c.client.ListResources(ctx, req)
238262
if err != nil {
263+
c.checkErrorAndDisconnectIfNeeded(err)
239264
return nil, fmt.Errorf("failed to list resources: %w", err)
240265
}
241266

@@ -268,6 +293,7 @@ func (c *mcpGoClient) CallTool(ctx context.Context, name string, args map[string
268293

269294
result, err := c.client.CallTool(ctx, req)
270295
if err != nil {
296+
c.checkErrorAndDisconnectIfNeeded(err)
271297
return nil, fmt.Errorf("failed to call tool: %w", err)
272298
}
273299

@@ -308,6 +334,7 @@ func (c *mcpGoClient) ReadResource(ctx context.Context, uri string) (*ReadResour
308334

309335
result, err := c.client.ReadResource(ctx, req)
310336
if err != nil {
337+
c.checkErrorAndDisconnectIfNeeded(err)
311338
return nil, fmt.Errorf("failed to read resource: %w", err)
312339
}
313340

0 commit comments

Comments
 (0)