@@ -844,6 +844,33 @@ func (ss *ServerSession) updateState(mut func(*ServerSessionState)) {
844844 }
845845}
846846
847+ // hasInitialized reports whether the server has received the initialized
848+ // notification.
849+ //
850+ // TODO(findleyr): use this to prevent change notifications.
851+ func (ss * ServerSession ) hasInitialized () bool {
852+ ss .mu .Lock ()
853+ defer ss .mu .Unlock ()
854+ return ss .state .InitializedParams != nil
855+ }
856+
857+ // checkInitialized returns a formatted error if the server has not yet
858+ // received the initialized notification.
859+ func (ss * ServerSession ) checkInitialized (method string ) error {
860+ if ! ss .hasInitialized () {
861+ // TODO(rfindley): enable this check.
862+ // Right now is is flaky, because server tests don't await the initialized notification.
863+ // Perhaps requests should simply block until they have received the initialized notification
864+
865+ // if strings.HasPrefix(method, "notifications/") {
866+ // return fmt.Errorf("must not send %q before %q is received", method, notificationInitialized)
867+ // } else {
868+ // return fmt.Errorf("cannot call %q before %q is received", method, notificationInitialized)
869+ // }
870+ }
871+ return nil
872+ }
873+
847874func (ss * ServerSession ) ID () string {
848875 if c , ok := ss .mcpConn .(hasSessionID ); ok {
849876 return c .SessionID ()
@@ -859,11 +886,17 @@ func (ss *ServerSession) Ping(ctx context.Context, params *PingParams) error {
859886
860887// ListRoots lists the client roots.
861888func (ss * ServerSession ) ListRoots (ctx context.Context , params * ListRootsParams ) (* ListRootsResult , error ) {
889+ if err := ss .checkInitialized (methodListRoots ); err != nil {
890+ return nil , err
891+ }
862892 return handleSend [* ListRootsResult ](ctx , methodListRoots , newServerRequest (ss , orZero [Params ](params )))
863893}
864894
865895// CreateMessage sends a sampling request to the client.
866896func (ss * ServerSession ) CreateMessage (ctx context.Context , params * CreateMessageParams ) (* CreateMessageResult , error ) {
897+ if err := ss .checkInitialized (methodCreateMessage ); err != nil {
898+ return nil , err
899+ }
867900 if params == nil {
868901 params = & CreateMessageParams {Messages : []* SamplingMessage {}}
869902 }
@@ -877,6 +910,9 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag
877910
878911// Elicit sends an elicitation request to the client asking for user input.
879912func (ss * ServerSession ) Elicit (ctx context.Context , params * ElicitParams ) (* ElicitResult , error ) {
913+ if err := ss .checkInitialized (methodElicit ); err != nil {
914+ return nil , err
915+ }
880916 return handleSend [* ElicitResult ](ctx , methodElicit , newServerRequest (ss , orZero [Params ](params )))
881917}
882918
@@ -978,7 +1014,7 @@ func (ss *ServerSession) getConn() *jsonrpc2.Connection { return ss.conn }
9781014// handle invokes the method described by the given JSON RPC request.
9791015func (ss * ServerSession ) handle (ctx context.Context , req * jsonrpc.Request ) (any , error ) {
9801016 ss .mu .Lock ()
981- initialized := ss .state .InitializedParams != nil
1017+ initialized := ss .state .InitializeParams != nil
9821018 ss .mu .Unlock ()
9831019
9841020 // From the spec:
0 commit comments