@@ -3,8 +3,10 @@ package mcp
33import (
44 "context"
55 "encoding/json"
6+ "errors"
67 "fmt"
78 "net/http"
9+ "strings"
810 "time"
911
1012 "github.com/Tencent/WeKnora/internal/logger"
@@ -124,10 +126,32 @@ func NewMCPClient(config *ClientConfig) (MCPClient, error) {
124126 return nil , ErrUnsupportedTransport
125127 }
126128
127- return & mcpGoClient {
129+ instance := & mcpGoClient {
128130 service : config .Service ,
129131 client : mcpClient ,
130- }, nil
132+ }
133+ mcpClient .OnConnectionLost (instance .onConnectionLost )
134+ return instance , nil
135+ }
136+
137+ // onConnectionLost callback when the connection is lost
138+ func (c * mcpGoClient ) onConnectionLost (err error ) {
139+ _ = c .Disconnect ()
140+ logger .Warnf (context .Background (), "MCP server connection has been lost, URL:%s, error:%v" , * c .service .URL , err )
141+ }
142+
143+ // checkErrorAndDisconnectIfNeeded Check for errors and call Disconnect when reconnection is required
144+ func (c * mcpGoClient ) checkErrorAndDisconnectIfNeeded (err error ) {
145+ var transportErr * transport.Error
146+ // In SSE transport type, connection loss does not always actively trigger onConnectionLost (a go-mcp issue).
147+ // Once the connection is lost, the session becomes invalid.
148+ // Without reconnecting, it will continuously cause "Invalid session ID" errors.
149+ if c .service .TransportType == types .MCPTransportSSE &&
150+ errors .As (err , & transportErr ) &&
151+ transportErr .Err != nil &&
152+ strings .Contains (transportErr .Err .Error (), "Invalid session ID" ) {
153+ _ = c .Disconnect ()
154+ }
131155}
132156
133157// Connect establishes connection to the MCP service
@@ -140,7 +164,6 @@ func (c *mcpGoClient) Connect(ctx context.Context) error {
140164 if err := c .client .Start (ctx ); err != nil {
141165 return fmt .Errorf ("failed to start client: %w" , err )
142166 }
143-
144167 c .connected = true
145168 if c .service .TransportType == types .MCPTransportStdio {
146169 logger .GetLogger (ctx ).Infof ("MCP stdio client connected: %s %v" ,
@@ -161,7 +184,6 @@ func (c *mcpGoClient) Disconnect() error {
161184 if c .client != nil {
162185 c .client .Close ()
163186 }
164-
165187 c .connected = false
166188 c .initialized = false
167189 return nil
@@ -187,6 +209,7 @@ func (c *mcpGoClient) Initialize(ctx context.Context) (*InitializeResult, error)
187209
188210 result , err := c .client .Initialize (ctx , req )
189211 if err != nil {
212+ c .checkErrorAndDisconnectIfNeeded (err )
190213 return nil , fmt .Errorf ("failed to initialize: %w" , err )
191214 }
192215
@@ -210,6 +233,7 @@ func (c *mcpGoClient) ListTools(ctx context.Context) ([]*types.MCPTool, error) {
210233 req := mcp.ListToolsRequest {}
211234 result , err := c .client .ListTools (ctx , req )
212235 if err != nil {
236+ c .checkErrorAndDisconnectIfNeeded (err )
213237 return nil , fmt .Errorf ("failed to list tools: %w" , err )
214238 }
215239
@@ -236,6 +260,7 @@ func (c *mcpGoClient) ListResources(ctx context.Context) ([]*types.MCPResource,
236260 req := mcp.ListResourcesRequest {}
237261 result , err := c .client .ListResources (ctx , req )
238262 if err != nil {
263+ c .checkErrorAndDisconnectIfNeeded (err )
239264 return nil , fmt .Errorf ("failed to list resources: %w" , err )
240265 }
241266
@@ -268,6 +293,7 @@ func (c *mcpGoClient) CallTool(ctx context.Context, name string, args map[string
268293
269294 result , err := c .client .CallTool (ctx , req )
270295 if err != nil {
296+ c .checkErrorAndDisconnectIfNeeded (err )
271297 return nil , fmt .Errorf ("failed to call tool: %w" , err )
272298 }
273299
@@ -308,6 +334,7 @@ func (c *mcpGoClient) ReadResource(ctx context.Context, uri string) (*ReadResour
308334
309335 result , err := c .client .ReadResource (ctx , req )
310336 if err != nil {
337+ c .checkErrorAndDisconnectIfNeeded (err )
311338 return nil , fmt .Errorf ("failed to read resource: %w" , err )
312339 }
313340
0 commit comments