Skip to content

Commit a1a3510

Browse files
authored
mcp/server: make Run return on context cancel (#111)
Make `Server.Run` return when the provided context is canceled. Add tests for when the `Run` context is cancelled and also for when the server process receives a signal. Fixes #107.
1 parent 221febd commit a1a3510

File tree

2 files changed

+121
-14
lines changed

2 files changed

+121
-14
lines changed

mcp/cmd_test.go

Lines changed: 106 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ package mcp_test
66

77
import (
88
"context"
9+
"errors"
910
"log"
1011
"os"
1112
"os/exec"
1213
"runtime"
1314
"testing"
15+
"time"
1416

1517
"github.com/google/go-cmp/cmp"
1618
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -37,36 +39,103 @@ func runServer() {
3739
}
3840
}
3941

40-
func TestCmdTransport(t *testing.T) {
41-
// Conservatively, limit to major OS where we know that os.Exec is
42-
// supported.
43-
switch runtime.GOOS {
44-
case "darwin", "linux", "windows":
45-
default:
46-
t.Skip("unsupported OS")
42+
func TestServerRunContextCancel(t *testing.T) {
43+
server := mcp.NewServer("greeter", "v0.0.1", nil)
44+
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
45+
46+
ctx, cancel := context.WithCancel(context.Background())
47+
defer cancel()
48+
49+
serverTransport, clientTransport := mcp.NewInMemoryTransports()
50+
51+
// run the server and capture the exit error
52+
onServerExit := make(chan error)
53+
go func() {
54+
onServerExit <- server.Run(ctx, serverTransport)
55+
}()
56+
57+
// send a ping to the server to ensure it's running
58+
client := mcp.NewClient("client", "v0.0.1", nil)
59+
session, err := client.Connect(ctx, clientTransport)
60+
if err != nil {
61+
t.Fatal(err)
62+
}
63+
if err := session.Ping(context.Background(), nil); err != nil {
64+
t.Fatal(err)
4765
}
4866

67+
// cancel the context to stop the server
68+
cancel()
69+
70+
// wait for the server to exit
71+
// TODO: use synctest when availble
72+
select {
73+
case <-time.After(5 * time.Second):
74+
t.Fatal("server did not exit after context cancellation")
75+
case err := <-onServerExit:
76+
if !errors.Is(err, context.Canceled) {
77+
t.Fatalf("server did not exit after context cancellation, got error: %v", err)
78+
}
79+
}
80+
}
81+
82+
func TestServerInterrupt(t *testing.T) {
83+
requireExec(t)
84+
4985
ctx, cancel := context.WithCancel(context.Background())
5086
defer cancel()
5187

52-
exe, err := os.Executable()
88+
cmd := createServerCommand(t)
89+
90+
client := mcp.NewClient("client", "v0.0.1", nil)
91+
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
5392
if err != nil {
5493
t.Fatal(err)
5594
}
56-
cmd := exec.Command(exe)
57-
cmd.Env = append(os.Environ(), runAsServer+"=true")
95+
96+
// get a signal when the server process exits
97+
onExit := make(chan struct{})
98+
go func() {
99+
cmd.Process.Wait()
100+
close(onExit)
101+
}()
102+
103+
// send a signal to the server process to terminate it
104+
if runtime.GOOS == "windows" {
105+
// Windows does not support os.Interrupt
106+
session.Close()
107+
} else {
108+
cmd.Process.Signal(os.Interrupt)
109+
}
110+
111+
// wait for the server to exit
112+
// TODO: use synctest when availble
113+
select {
114+
case <-time.After(5 * time.Second):
115+
t.Fatal("server did not exit after SIGTERM")
116+
case <-onExit:
117+
}
118+
}
119+
120+
func TestCmdTransport(t *testing.T) {
121+
requireExec(t)
122+
123+
ctx, cancel := context.WithCancel(context.Background())
124+
defer cancel()
125+
126+
cmd := createServerCommand(t)
58127

59128
client := mcp.NewClient("client", "v0.0.1", nil)
60129
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
61130
if err != nil {
62-
log.Fatal(err)
131+
t.Fatal(err)
63132
}
64133
got, err := session.CallTool(ctx, &mcp.CallToolParams{
65134
Name: "greet",
66135
Arguments: map[string]any{"name": "user"},
67136
})
68137
if err != nil {
69-
log.Fatal(err)
138+
t.Fatal(err)
70139
}
71140
want := &mcp.CallToolResult{
72141
Content: []mcp.Content{
@@ -80,3 +149,28 @@ func TestCmdTransport(t *testing.T) {
80149
t.Fatalf("closing server: %v", err)
81150
}
82151
}
152+
153+
func createServerCommand(t *testing.T) *exec.Cmd {
154+
t.Helper()
155+
156+
exe, err := os.Executable()
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
cmd := exec.Command(exe)
161+
cmd.Env = append(os.Environ(), runAsServer+"=true")
162+
163+
return cmd
164+
}
165+
166+
func requireExec(t *testing.T) {
167+
t.Helper()
168+
169+
// Conservatively, limit to major OS where we know that os.Exec is
170+
// supported.
171+
switch runtime.GOOS {
172+
case "darwin", "linux", "windows":
173+
default:
174+
t.Skip("unsupported OS")
175+
}
176+
}

mcp/server.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,26 @@ func fileResourceHandler(dir string) ResourceHandler {
404404

405405
// Run runs the server over the given transport, which must be persistent.
406406
//
407-
// Run blocks until the client terminates the connection.
407+
// Run blocks until the client terminates the connection or the provided
408+
// context is cancelled. If the context is cancelled, Run closes the connection.
408409
func (s *Server) Run(ctx context.Context, t Transport) error {
409410
ss, err := s.Connect(ctx, t)
410411
if err != nil {
411412
return err
412413
}
413-
return ss.Wait()
414+
415+
ssClosed := make(chan error)
416+
go func() {
417+
ssClosed <- ss.Wait()
418+
}()
419+
420+
select {
421+
case <-ctx.Done():
422+
ss.Close()
423+
return ctx.Err()
424+
case err := <-ssClosed:
425+
return err
426+
}
414427
}
415428

416429
// bind implements the binder[*ServerSession] interface, so that Servers can

0 commit comments

Comments
 (0)