Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 106 additions & 12 deletions mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package mcp_test

import (
"context"
"errors"
"log"
"os"
"os/exec"
"runtime"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/modelcontextprotocol/go-sdk/mcp"
Expand All @@ -37,36 +39,103 @@ func runServer() {
}
}

func TestCmdTransport(t *testing.T) {
// Conservatively, limit to major OS where we know that os.Exec is
// supported.
switch runtime.GOOS {
case "darwin", "linux", "windows":
default:
t.Skip("unsupported OS")
func TestServerRunContextCancel(t *testing.T) {
server := mcp.NewServer("greeter", "v0.0.1", nil)
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

serverTransport, clientTransport := mcp.NewInMemoryTransports()

// run the server and capture the exit error
onServerExit := make(chan error)
go func() {
onServerExit <- server.Run(ctx, serverTransport)
}()

// send a ping to the server to ensure it's running
client := mcp.NewClient("client", "v0.0.1", nil)
session, err := client.Connect(ctx, clientTransport)
if err != nil {
t.Fatal(err)
}
if err := session.Ping(context.Background(), nil); err != nil {
t.Fatal(err)
}

// cancel the context to stop the server
cancel()

// wait for the server to exit
// TODO: use synctest when availble
select {
case <-time.After(5 * time.Second):
t.Fatal("server did not exit after context cancellation")
case err := <-onServerExit:
if !errors.Is(err, context.Canceled) {
t.Fatalf("server did not exit after context cancellation, got error: %v", err)
}
}
}

func TestServerInterrupt(t *testing.T) {
requireExec(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

exe, err := os.Executable()
cmd := createServerCommand(t)

client := mcp.NewClient("client", "v0.0.1", nil)
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
if err != nil {
t.Fatal(err)
}
cmd := exec.Command(exe)
cmd.Env = append(os.Environ(), runAsServer+"=true")

// get a signal when the server process exits
onExit := make(chan struct{})
go func() {
cmd.Process.Wait()
close(onExit)
}()

// send a signal to the server process to terminate it
if runtime.GOOS == "windows" {
// Windows does not support os.Interrupt
session.Close()
} else {
cmd.Process.Signal(os.Interrupt)
}

// wait for the server to exit
// TODO: use synctest when availble
select {
case <-time.After(5 * time.Second):
t.Fatal("server did not exit after SIGTERM")
case <-onExit:
}
}

func TestCmdTransport(t *testing.T) {
requireExec(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cmd := createServerCommand(t)

client := mcp.NewClient("client", "v0.0.1", nil)
session, err := client.Connect(ctx, mcp.NewCommandTransport(cmd))
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
got, err := session.CallTool(ctx, &mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "user"},
})
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
want := &mcp.CallToolResult{
Content: []mcp.Content{
Expand All @@ -80,3 +149,28 @@ func TestCmdTransport(t *testing.T) {
t.Fatalf("closing server: %v", err)
}
}

func createServerCommand(t *testing.T) *exec.Cmd {
t.Helper()

exe, err := os.Executable()
if err != nil {
t.Fatal(err)
}
cmd := exec.Command(exe)
cmd.Env = append(os.Environ(), runAsServer+"=true")

return cmd
}

func requireExec(t *testing.T) {
t.Helper()

// Conservatively, limit to major OS where we know that os.Exec is
// supported.
switch runtime.GOOS {
case "darwin", "linux", "windows":
default:
t.Skip("unsupported OS")
}
}
17 changes: 15 additions & 2 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,26 @@ func fileResourceHandler(dir string) ResourceHandler {

// Run runs the server over the given transport, which must be persistent.
//
// Run blocks until the client terminates the connection.
// Run blocks until the client terminates the connection or the provided
// context is cancelled. If the context is cancelled, Run closes the connection.
func (s *Server) Run(ctx context.Context, t Transport) error {
ss, err := s.Connect(ctx, t)
if err != nil {
return err
}
return ss.Wait()

ssClosed := make(chan error)
go func() {
ssClosed <- ss.Wait()
}()

select {
case <-ctx.Done():
ss.Close()
return ctx.Err()
case err := <-ssClosed:
return err
}
}

// bind implements the binder[*ServerSession] interface, so that Servers can
Expand Down
Loading