Skip to content

Commit 41599ae

Browse files
committed
Merge branch 'main' into ex1
2 parents 5f2b816 + c879ac3 commit 41599ae

File tree

9 files changed

+245
-21
lines changed

9 files changed

+245
-21
lines changed

design/design.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -748,13 +748,26 @@ Server sessions also support the spec methods `ListResources` and `ListResourceT
748748
749749
#### Subscriptions
750750
751-
ClientSessions can manage change notifications on particular resources:
751+
##### Client-Side Usage
752+
753+
Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource.
752754
753755
```go
754756
func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error
755757
func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error
756758
```
757759
760+
To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to.
761+
762+
```go
763+
type ClientOptions struct {
764+
...
765+
ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams)
766+
}
767+
```
768+
769+
##### Server-Side Implementation
770+
758771
The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to.
759772
760773
If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers.
@@ -772,7 +785,7 @@ type ServerOptions struct {
772785
User code should call `ResourceUpdated` when a subscribed resource changes.
773786
774787
```go
775-
func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error
788+
func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error
776789
```
777790
778791
The server routes these notifications to the server sessions that subscribed to the resource.

examples/sse/main.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,5 @@ func main() {
6666
return nil
6767
}
6868
})
69-
7069
log.Fatal(http.ListenAndServe(addr, handler))
7170
}

mcp/client.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type ClientOptions struct {
6060
ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams)
6161
PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams)
6262
ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams)
63+
ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams)
6364
LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams)
6465
ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams)
6566
// If non-zero, defines an interval for regular "ping" requests.
@@ -293,6 +294,7 @@ var clientMethodInfos = map[string]methodInfo{
293294
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)),
294295
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)),
295296
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)),
297+
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)),
296298
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)),
297299
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)),
298300
}
@@ -386,6 +388,20 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) (
386388
return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params))
387389
}
388390

391+
// Subscribe sends a "resources/subscribe" request to the server, asking for
392+
// notifications when the specified resource changes.
393+
func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error {
394+
_, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params))
395+
return err
396+
}
397+
398+
// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling
399+
// a previous subscription.
400+
func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error {
401+
_, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params))
402+
return err
403+
}
404+
389405
func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) {
390406
return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params)
391407
}
@@ -398,6 +414,10 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio
398414
return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params)
399415
}
400416

417+
func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) {
418+
return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params)
419+
}
420+
401421
func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) {
402422
if h := c.opts.LoggingMessageHandler; h != nil {
403423
h(ctx, cs, params)

mcp/mcp_test.go

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func TestEndToEnd(t *testing.T) {
6060

6161
// Channels to check if notification callbacks happened.
6262
notificationChans := map[string]chan int{}
63-
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} {
63+
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} {
6464
notificationChans[name] = make(chan int, 1)
6565
}
6666
waitForNotification := func(t *testing.T, name string) {
@@ -78,6 +78,14 @@ func TestEndToEnd(t *testing.T) {
7878
ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) {
7979
notificationChans["progress_server"] <- 0
8080
},
81+
SubscribeHandler: func(context.Context, *SubscribeParams) error {
82+
notificationChans["subscribe"] <- 0
83+
return nil
84+
},
85+
UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error {
86+
notificationChans["unsubscribe"] <- 0
87+
return nil
88+
},
8189
}
8290
s := NewServer(testImpl, sopts)
8391
AddTool(s, &Tool{
@@ -128,6 +136,9 @@ func TestEndToEnd(t *testing.T) {
128136
ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) {
129137
notificationChans["progress_client"] <- 0
130138
},
139+
ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) {
140+
notificationChans["resource_updated"] <- 0
141+
},
131142
}
132143
c := NewClient(testImpl, opts)
133144
rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files"))
@@ -421,6 +432,37 @@ func TestEndToEnd(t *testing.T) {
421432
waitForNotification(t, "progress_server")
422433
})
423434

435+
t.Run("resource_subscriptions", func(t *testing.T) {
436+
err := cs.Subscribe(ctx, &SubscribeParams{
437+
URI: "test",
438+
})
439+
if err != nil {
440+
t.Fatal(err)
441+
}
442+
waitForNotification(t, "subscribe")
443+
s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{
444+
URI: "test",
445+
})
446+
waitForNotification(t, "resource_updated")
447+
err = cs.Unsubscribe(ctx, &UnsubscribeParams{
448+
URI: "test",
449+
})
450+
if err != nil {
451+
t.Fatal(err)
452+
}
453+
waitForNotification(t, "unsubscribe")
454+
455+
// Verify the client does not receive the update after unsubscribing.
456+
s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{
457+
URI: "test",
458+
})
459+
select {
460+
case <-notificationChans["resource_updated"]:
461+
t.Fatalf("resource updated after unsubscription")
462+
case <-time.After(time.Second):
463+
}
464+
})
465+
424466
// Disconnect.
425467
cs.Close()
426468
clientWG.Wait()
@@ -617,7 +659,7 @@ func TestCancellation(t *testing.T) {
617659
return nil, nil
618660
}
619661
_, cs := basicConnection(t, func(s *Server) {
620-
s.AddTool(&Tool{Name: "slow"}, slowRequest)
662+
s.AddTool(&Tool{Name: "slow", InputSchema: &jsonschema.Schema{}}, slowRequest)
621663
})
622664
defer cs.Close()
623665

mcp/protocol.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,38 @@ type ToolListChangedParams struct {
859859
func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) }
860860
func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) }
861861

862+
// Sent from the client to request resources/updated notifications from the
863+
// server whenever a particular resource changes.
864+
type SubscribeParams struct {
865+
// This property is reserved by the protocol to allow clients and servers to
866+
// attach additional metadata to their responses.
867+
Meta `json:"_meta,omitempty"`
868+
// The URI of the resource to subscribe to.
869+
URI string `json:"uri"`
870+
}
871+
872+
// Sent from the client to request cancellation of resources/updated
873+
// notifications from the server. This should follow a previous
874+
// resources/subscribe request.
875+
type UnsubscribeParams struct {
876+
// This property is reserved by the protocol to allow clients and servers to
877+
// attach additional metadata to their responses.
878+
Meta `json:"_meta,omitempty"`
879+
// The URI of the resource to unsubscribe from.
880+
URI string `json:"uri"`
881+
}
882+
883+
// A notification from the server to the client, informing it that a resource
884+
// has changed and may need to be read again. This should only be sent if the
885+
// client previously sent a resources/subscribe request.
886+
type ResourceUpdatedNotificationParams struct {
887+
// This property is reserved by the protocol to allow clients and servers to
888+
// attach additional metadata to their responses.
889+
Meta `json:"_meta,omitempty"`
890+
// The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.
891+
URI string `json:"uri"`
892+
}
893+
862894
// TODO(jba): add CompleteRequest and related types.
863895

864896
// TODO(jba): add ElicitRequest and related types.

mcp/server.go

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"encoding/json"
1313
"fmt"
1414
"iter"
15-
"log"
15+
"maps"
1616
"net/url"
1717
"path/filepath"
1818
"slices"
@@ -43,6 +43,7 @@ type Server struct {
4343
sessions []*ServerSession
4444
sendingMethodHandler_ MethodHandler[*ServerSession]
4545
receivingMethodHandler_ MethodHandler[*ServerSession]
46+
resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool
4647
}
4748

4849
// ServerOptions is used to configure behavior of the server.
@@ -64,6 +65,10 @@ type ServerOptions struct {
6465
// If the peer fails to respond to pings originating from the keepalive check,
6566
// the session is automatically closed.
6667
KeepAlive time.Duration
68+
// Function called when a client session subscribes to a resource.
69+
SubscribeHandler func(context.Context, *SubscribeParams) error
70+
// Function called when a client session unsubscribes from a resource.
71+
UnsubscribeHandler func(context.Context, *UnsubscribeParams) error
6772
}
6873

6974
// NewServer creates a new MCP server. The resulting server has no features:
@@ -89,7 +94,12 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
8994
if opts.PageSize == 0 {
9095
opts.PageSize = DefaultPageSize
9196
}
92-
97+
if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil {
98+
panic("SubscribeHandler requires UnsubscribeHandler")
99+
}
100+
if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil {
101+
panic("UnsubscribeHandler requires SubscribeHandler")
102+
}
93103
return &Server{
94104
impl: impl,
95105
opts: *opts,
@@ -99,6 +109,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
99109
resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }),
100110
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
101111
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
112+
resourceSubscriptions: make(map[string]map[*ServerSession]bool),
102113
}
103114
}
104115

@@ -120,13 +131,18 @@ func (s *Server) RemovePrompts(names ...string) {
120131
}
121132

122133
// AddTool adds a [Tool] to the server, or replaces one with the same name.
123-
// The tool's input schema must be non-nil.
124134
// The Tool argument must not be modified after this call.
135+
//
136+
// The tool's input schema must be non-nil. For a tool that takes no input,
137+
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
138+
// &jsonschema.Schema{}.
125139
func (s *Server) AddTool(t *Tool, h ToolHandler) {
126-
// TODO(jba): This is a breaking behavior change. Add before v0.2.0?
127140
if t.InputSchema == nil {
128-
log.Printf("mcp: tool %q has a nil input schema. This will panic in a future release.", t.Name)
129-
// panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name))
141+
// This prevents the tool author from forgetting to write a schema where
142+
// one should be provided. If we papered over this by supplying the empty
143+
// schema, then every input would be validated and the problem wouldn't be
144+
// discovered until runtime, when the LLM sent bad data.
145+
panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name))
130146
}
131147
if err := addToolErr(s, t, h); err != nil {
132148
panic(err)
@@ -225,6 +241,9 @@ func (s *Server) capabilities() *serverCapabilities {
225241
}
226242
if s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
227243
caps.Resources = &resourceCapabilities{ListChanged: true}
244+
if s.opts.SubscribeHandler != nil {
245+
caps.Resources.Subscribe = true
246+
}
228247
}
229248
return caps
230249
}
@@ -428,6 +447,57 @@ func fileResourceHandler(dir string) ResourceHandler {
428447
}
429448
}
430449

450+
// ResourceUpdated sends a notification to all clients that have subscribed to the
451+
// resource specified in params. This method is the primary way for a
452+
// server author to signal that a resource has changed.
453+
func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error {
454+
s.mu.Lock()
455+
subscribedSessions := s.resourceSubscriptions[params.URI]
456+
sessions := slices.Collect(maps.Keys(subscribedSessions))
457+
s.mu.Unlock()
458+
notifySessions(sessions, notificationResourceUpdated, params)
459+
return nil
460+
}
461+
462+
func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) {
463+
if s.opts.SubscribeHandler == nil {
464+
return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound)
465+
}
466+
if err := s.opts.SubscribeHandler(ctx, params); err != nil {
467+
return nil, err
468+
}
469+
470+
s.mu.Lock()
471+
defer s.mu.Unlock()
472+
if s.resourceSubscriptions[params.URI] == nil {
473+
s.resourceSubscriptions[params.URI] = make(map[*ServerSession]bool)
474+
}
475+
s.resourceSubscriptions[params.URI][ss] = true
476+
477+
return &emptyResult{}, nil
478+
}
479+
480+
func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) {
481+
if s.opts.UnsubscribeHandler == nil {
482+
return nil, jsonrpc2.ErrMethodNotFound
483+
}
484+
485+
if err := s.opts.UnsubscribeHandler(ctx, params); err != nil {
486+
return nil, err
487+
}
488+
489+
s.mu.Lock()
490+
defer s.mu.Unlock()
491+
if subscribedSessions, ok := s.resourceSubscriptions[params.URI]; ok {
492+
delete(subscribedSessions, ss)
493+
if len(subscribedSessions) == 0 {
494+
delete(s.resourceSubscriptions, params.URI)
495+
}
496+
}
497+
498+
return &emptyResult{}, nil
499+
}
500+
431501
// Run runs the server over the given transport, which must be persistent.
432502
//
433503
// Run blocks until the client terminates the connection or the provided
@@ -475,6 +545,10 @@ func (s *Server) disconnect(cc *ServerSession) {
475545
s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool {
476546
return cc2 == cc
477547
})
548+
549+
for _, subscribedSessions := range s.resourceSubscriptions {
550+
delete(subscribedSessions, cc)
551+
}
478552
}
479553

480554
// Connect connects the MCP server over the given transport and starts handling
@@ -540,7 +614,7 @@ func (ss *ServerSession) ID() string {
540614

541615
// Ping pings the client.
542616
func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
543-
_, err := handleSend[*emptyResult](ctx, ss, methodPing, params)
617+
_, err := handleSend[*emptyResult](ctx, ss, methodPing, orZero[Params](params))
544618
return err
545619
}
546620

@@ -616,6 +690,8 @@ var serverMethodInfos = map[string]methodInfo{
616690
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)),
617691
methodReadResource: newMethodInfo(serverMethod((*Server).readResource)),
618692
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)),
693+
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)),
694+
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)),
619695
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)),
620696
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)),
621697
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)),

0 commit comments

Comments
 (0)