diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 496694a5..bfae0c60 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -6,11 +6,13 @@ package mcp_test import ( "context" + "errors" "log" "os" "os/exec" "runtime" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -37,36 +39,103 @@ func runServer() { } } -func TestCmdTransport(t *testing.T) { - // Conservatively, limit to major OS where we know that os.Exec is - // supported. - switch runtime.GOOS { - case "darwin", "linux", "windows": - default: - t.Skip("unsupported OS") +func TestServerRunContextCancel(t *testing.T) { + server := mcp.NewServer("greeter", "v0.0.1", nil) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + + // run the server and capture the exit error + onServerExit := make(chan error) + go func() { + onServerExit <- server.Run(ctx, serverTransport) + }() + + // send a ping to the server to ensure it's running + client := mcp.NewClient("client", "v0.0.1", nil) + session, err := client.Connect(ctx, clientTransport) + if err != nil { + t.Fatal(err) + } + if err := session.Ping(context.Background(), nil); err != nil { + t.Fatal(err) } + // cancel the context to stop the server + cancel() + + // wait for the server to exit + // TODO: use synctest when availble + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after context cancellation") + case err := <-onServerExit: + if !errors.Is(err, context.Canceled) { + t.Fatalf("server did not exit after context cancellation, got error: %v", err) + } + } +} + +func TestServerInterrupt(t *testing.T) { + requireExec(t) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - exe, err := os.Executable() + cmd := createServerCommand(t) + + client := mcp.NewClient("client", "v0.0.1", nil) + session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { t.Fatal(err) } - cmd := exec.Command(exe) - cmd.Env = append(os.Environ(), runAsServer+"=true") + + // get a signal when the server process exits + onExit := make(chan struct{}) + go func() { + cmd.Process.Wait() + close(onExit) + }() + + // send a signal to the server process to terminate it + if runtime.GOOS == "windows" { + // Windows does not support os.Interrupt + session.Close() + } else { + cmd.Process.Signal(os.Interrupt) + } + + // wait for the server to exit + // TODO: use synctest when availble + select { + case <-time.After(5 * time.Second): + t.Fatal("server did not exit after SIGTERM") + case <-onExit: + } +} + +func TestCmdTransport(t *testing.T) { + requireExec(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cmd := createServerCommand(t) client := mcp.NewClient("client", "v0.0.1", nil) session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd)) if err != nil { - log.Fatal(err) + t.Fatal(err) } got, err := session.CallTool(ctx, &mcp.CallToolParams{ Name: "greet", Arguments: map[string]any{"name": "user"}, }) if err != nil { - log.Fatal(err) + t.Fatal(err) } want := &mcp.CallToolResult{ Content: []mcp.Content{ @@ -80,3 +149,28 @@ func TestCmdTransport(t *testing.T) { t.Fatalf("closing server: %v", err) } } + +func createServerCommand(t *testing.T) *exec.Cmd { + t.Helper() + + exe, err := os.Executable() + if err != nil { + t.Fatal(err) + } + cmd := exec.Command(exe) + cmd.Env = append(os.Environ(), runAsServer+"=true") + + return cmd +} + +func requireExec(t *testing.T) { + t.Helper() + + // Conservatively, limit to major OS where we know that os.Exec is + // supported. + switch runtime.GOOS { + case "darwin", "linux", "windows": + default: + t.Skip("unsupported OS") + } +} diff --git a/mcp/server.go b/mcp/server.go index cd8f808b..0aa054fa 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -404,13 +404,26 @@ func fileResourceHandler(dir string) ResourceHandler { // Run runs the server over the given transport, which must be persistent. // -// Run blocks until the client terminates the connection. +// Run blocks until the client terminates the connection or the provided +// context is cancelled. If the context is cancelled, Run closes the connection. func (s *Server) Run(ctx context.Context, t Transport) error { ss, err := s.Connect(ctx, t) if err != nil { return err } - return ss.Wait() + + ssClosed := make(chan error) + go func() { + ssClosed <- ss.Wait() + }() + + select { + case <-ctx.Done(): + ss.Close() + return ctx.Err() + case err := <-ssClosed: + return err + } } // bind implements the binder[*ServerSession] interface, so that Servers can