Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,7 @@ func TestElicitationCapabilityDeclaration(t *testing.T) {
RequestedSchema: &jsonschema.Schema{Type: "object"},
})
if err != nil {
t.Errorf("elicitation should work when capability is declared, got error: %v", err)
t.Fatalf("elicitation should work when capability is declared, got error: %v", err)
}
if result.Action != "cancel" {
t.Errorf("got action %q, want %q", result.Action, "cancel")
Expand Down
38 changes: 37 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,33 @@ func (ss *ServerSession) updateState(mut func(*ServerSessionState)) {
}
}

// hasInitialized reports whether the server has received the initialized
// notification.
//
// TODO(findleyr): use this to prevent change notifications.
func (ss *ServerSession) hasInitialized() bool {
ss.mu.Lock()
defer ss.mu.Unlock()
return ss.state.InitializedParams != nil
}

// checkInitialized returns a formatted error if the server has not yet
// received the initialized notification.
func (ss *ServerSession) checkInitialized(method string) error {
if !ss.hasInitialized() {
// TODO(rfindley): enable this check.
// Right now is is flaky, because server tests don't await the initialized notification.
// Perhaps requests should simply block until they have received the initialized notification

// if strings.HasPrefix(method, "notifications/") {
// return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized)
// } else {
// return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized)
// }
}
return nil
}

func (ss *ServerSession) ID() string {
if c, ok := ss.mcpConn.(hasSessionID); ok {
return c.SessionID()
Expand All @@ -859,11 +886,17 @@ func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {

// ListRoots lists the client roots.
func (ss *ServerSession) ListRoots(ctx context.Context, params *ListRootsParams) (*ListRootsResult, error) {
if err := ss.checkInitialized(methodListRoots); err != nil {
return nil, err
}
return handleSend[*ListRootsResult](ctx, methodListRoots, newServerRequest(ss, orZero[Params](params)))
}

// CreateMessage sends a sampling request to the client.
func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessageParams) (*CreateMessageResult, error) {
if err := ss.checkInitialized(methodCreateMessage); err != nil {
return nil, err
}
if params == nil {
params = &CreateMessageParams{Messages: []*SamplingMessage{}}
}
Expand All @@ -877,6 +910,9 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag

// Elicit sends an elicitation request to the client asking for user input.
func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) {
if err := ss.checkInitialized(methodElicit); err != nil {
return nil, err
}
return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params)))
}

Expand Down Expand Up @@ -978,7 +1014,7 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
// handle invokes the method described by the given JSON RPC request.
func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
ss.mu.Lock()
initialized := ss.state.InitializedParams != nil
initialized := ss.state.InitializeParams != nil
ss.mu.Unlock()

// From the spec:
Expand Down
5 changes: 4 additions & 1 deletion mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,9 +346,12 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P)
if sessions == nil {
return
}
// TODO: make this timeout configurable, or call Notify asynchronously.
// TODO: make this timeout configurable, or call handleNotify asynchronously.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// TODO: there's a potential spec violation here, when the feature list
// changes before the session (client or server) is initialized.
for _, s := range sessions {
req := newRequest(s, params)
if err := handleNotify(ctx, method, req); err != nil {
Expand Down
14 changes: 11 additions & 3 deletions mcp/testdata/conformance/server/lifecycle.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ See also modelcontextprotocol/go-sdk#225.

-- client --
{ "jsonrpc":"2.0", "method": "notifications/initialized" }
{ "jsonrpc": "2.0", "id": 2, "method": "tools/list" }
{
"jsonrpc": "2.0",
"id": 1,
Expand All @@ -21,6 +22,14 @@ See also modelcontextprotocol/go-sdk#225.
{ "jsonrpc": "2.0", "id": 3, "method": "tools/list" }

-- server --
{
"jsonrpc": "2.0",
"id": 2,
"error": {
"code": 0,
"message": "method \"tools/list\" is invalid during session initialization"
}
}
{
"jsonrpc": "2.0",
"id": 1,
Expand All @@ -43,9 +52,8 @@ See also modelcontextprotocol/go-sdk#225.
{
"jsonrpc": "2.0",
"id": 2,
"error": {
"code": 0,
"message": "method \"tools/list\" is invalid during session initialization"
"result": {
"tools": []
}
}
{
Expand Down