Skip to content

Commit b3f9138

Browse files
committed
mcp/server: make Run return on context cancel
Make `Server.Run` return when the provided context is canceled. Add tests for when the `Run` context is cancelled and also for when the server process receives a signal. Fixes #107
1 parent fbff31a commit b3f9138

File tree

2 files changed

+118
-13
lines changed

2 files changed

+118
-13
lines changed

mcp/cmd_test.go

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ package mcp_test
66

77
import (
88
"context"
9+
"errors"
910
"log"
1011
"os"
1112
"os/exec"
1213
"runtime"
1314
"testing"
15+
"time"
1416

1517
"github.com/google/go-cmp/cmp"
1618
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -37,24 +39,89 @@ func runServer() {
3739
}
3840
}
3941

40-
func TestCmdTransport(t *testing.T) {
41-
// Conservatively, limit to major OS where we know that os.Exec is
42-
// supported.
43-
switch runtime.GOOS {
44-
case "darwin", "linux", "windows":
45-
default:
46-
t.Skip("unsupported OS")
42+
func TestServerRunContextCancel(t *testing.T) {
43+
server := mcp.NewServer("greeter", "v0.0.1", nil)
44+
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
45+
46+
ctx, cancel := context.WithCancel(context.Background())
47+
defer cancel()
48+
49+
serverTransport, clientTransport := mcp.NewInMemoryTransports()
50+
51+
// run the server and capture the exit error
52+
onServerExit := make(chan error)
53+
go func() {
54+
onServerExit <- server.Run(ctx, serverTransport)
55+
}()
56+
57+
// send a ping to the server to ensure it's running
58+
client := mcp.NewClient("client", "v0.0.1", nil)
59+
session, err := client.Connect(ctx, clientTransport)
60+
if err != nil {
61+
log.Fatal(err)
62+
}
63+
if err := session.Ping(context.Background(), nil); err != nil {
64+
log.Fatal(err)
65+
}
66+
67+
// cancel the context to stop the server
68+
cancel()
69+
70+
// wait for the server to exit
71+
select {
72+
case <-time.After(5 * time.Second):
73+
t.Fatal("server did not exit after context cancellation")
74+
case err := <-onServerExit:
75+
if !errors.Is(err, context.Canceled) {
76+
log.Fatalf("server did not exit after context cancellation, got error: %v", err)
77+
}
4778
}
79+
}
80+
81+
func TestServerInterrupt(t *testing.T) {
82+
requireExec(t)
4883

4984
ctx, cancel := context.WithCancel(context.Background())
5085
defer cancel()
5186

52-
exe, err := os.Executable()
87+
cmd := createServerCommand(t)
88+
89+
client := mcp.NewClient("client", "v0.0.1", nil)
90+
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
5391
if err != nil {
54-
t.Fatal(err)
92+
log.Fatal(err)
5593
}
56-
cmd := exec.Command(exe)
57-
cmd.Env = append(os.Environ(), runAsServer+"=true")
94+
95+
// get a signal when the server process exits
96+
onExit := make(chan struct{})
97+
go func() {
98+
cmd.Process.Wait()
99+
close(onExit)
100+
}()
101+
102+
// send a signal to the server process to terminate it
103+
if runtime.GOOS == "windows" {
104+
// Windows does not support os.Interrupt
105+
session.Close()
106+
} else {
107+
cmd.Process.Signal(os.Interrupt)
108+
}
109+
110+
// wait for the server to exit
111+
select {
112+
case <-time.After(5 * time.Second):
113+
t.Fatal("server did not exit after SIGTERM")
114+
case <-onExit:
115+
}
116+
}
117+
118+
func TestCmdTransport(t *testing.T) {
119+
requireExec(t)
120+
121+
ctx, cancel := context.WithCancel(context.Background())
122+
defer cancel()
123+
124+
cmd := createServerCommand(t)
58125

59126
client := mcp.NewClient("client", "v0.0.1", nil)
60127
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
@@ -80,3 +147,28 @@ func TestCmdTransport(t *testing.T) {
80147
t.Fatalf("closing server: %v", err)
81148
}
82149
}
150+
151+
func createServerCommand(t *testing.T) *exec.Cmd {
152+
t.Helper()
153+
154+
exe, err := os.Executable()
155+
if err != nil {
156+
t.Fatal(err)
157+
}
158+
cmd := exec.Command(exe)
159+
cmd.Env = append(os.Environ(), runAsServer+"=true")
160+
161+
return cmd
162+
}
163+
164+
func requireExec(t *testing.T) {
165+
t.Helper()
166+
167+
// Conservatively, limit to major OS where we know that os.Exec is
168+
// supported.
169+
switch runtime.GOOS {
170+
case "darwin", "linux", "windows":
171+
default:
172+
t.Skip("unsupported OS")
173+
}
174+
}

mcp/server.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,26 @@ func fileResourceHandler(dir string) ResourceHandler {
404404

405405
// Run runs the server over the given transport, which must be persistent.
406406
//
407-
// Run blocks until the client terminates the connection.
407+
// Run blocks until the client terminates the connection or the provided
408+
// context is cancelled. If the context is cancelled, Run closes the connection.
408409
func (s *Server) Run(ctx context.Context, t Transport) error {
409410
ss, err := s.Connect(ctx, t)
410411
if err != nil {
411412
return err
412413
}
413-
return ss.Wait()
414+
415+
ssClosed := make(chan error)
416+
go func() {
417+
ssClosed <- ss.Wait()
418+
}()
419+
420+
select {
421+
case <-ctx.Done():
422+
ss.Close()
423+
return ctx.Err()
424+
case err := <-ssClosed:
425+
return err
426+
}
414427
}
415428

416429
// bind implements the binder[*ServerSession] interface, so that Servers can

0 commit comments

Comments
 (0)