Skip to content

Commit 1b4d2f6

Browse files
authored
Merge branch 'main' into mcp-server-req-auth-2
2 parents 58d9bbe + 1a54234 commit 1b4d2f6

File tree

16 files changed

+867
-472
lines changed

16 files changed

+867
-472
lines changed

.github/workflows/test.yml

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ on:
33
# Manual trigger
44
workflow_dispatch:
55
push:
6-
branches: main
6+
branches: [main]
77
pull_request:
88

99
permissions:
@@ -13,43 +13,51 @@ jobs:
1313
lint:
1414
runs-on: ubuntu-latest
1515
steps:
16-
- name: Check out code
17-
uses: actions/checkout@v4
18-
- name: Set up Go
19-
uses: actions/setup-go@v5
20-
- name: Check formatting
21-
run: |
22-
unformatted=$(gofmt -l .)
23-
if [ -n "$unformatted" ]; then
24-
echo "The following files are not properly formatted:"
25-
echo "$unformatted"
26-
exit 1
27-
fi
28-
echo "All Go files are properly formatted"
16+
- name: Check out code
17+
uses: actions/checkout@v4
18+
- name: Set up Go
19+
uses: actions/setup-go@v5
20+
with:
21+
go-version: "^1.23"
22+
- name: Check formatting
23+
run: |
24+
unformatted=$(gofmt -l .)
25+
if [ -n "$unformatted" ]; then
26+
echo "The following files are not properly formatted:"
27+
echo "$unformatted"
28+
exit 1
29+
fi
30+
echo "All Go files are properly formatted"
31+
- name: Run Go vet
32+
run: go vet ./...
33+
- name: Run staticcheck
34+
uses: dominikh/staticcheck-action@v1
35+
with:
36+
version: "latest"
2937

3038
test:
3139
runs-on: ubuntu-latest
3240
strategy:
3341
matrix:
34-
go: [ '1.23', '1.24', '1.25.0-rc.3' ]
42+
go: ["1.23", "1.24", "1.25.0-rc.3"]
3543
steps:
36-
- name: Check out code
37-
uses: actions/checkout@v4
38-
- name: Set up Go
39-
uses: actions/setup-go@v5
40-
with:
41-
go-version: ${{ matrix.go }}
42-
- name: Test
43-
run: go test -v ./...
44+
- name: Check out code
45+
uses: actions/checkout@v4
46+
- name: Set up Go
47+
uses: actions/setup-go@v5
48+
with:
49+
go-version: ${{ matrix.go }}
50+
- name: Test
51+
run: go test -v ./...
4452

4553
race-test:
4654
runs-on: ubuntu-latest
4755
steps:
48-
- name: Check out code
49-
uses: actions/checkout@v4
50-
- name: Set up Go
51-
uses: actions/setup-go@v5
52-
with:
53-
go-version: '1.24'
54-
- name: Test with -race
55-
run: go test -v -race ./...
56+
- name: Check out code
57+
uses: actions/checkout@v4
58+
- name: Set up Go
59+
uses: actions/setup-go@v5
60+
with:
61+
go-version: "1.24"
62+
- name: Test with -race
63+
run: go test -v -race ./...

internal/jsonrpc2/conn.go

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,46 @@ func (c *Connection) Call(ctx context.Context, method string, params any) *Async
374374
return ac
375375
}
376376

377+
// Async, signals that the current jsonrpc2 request may be handled
378+
// asynchronously to subsequent requests, when ctx is the request context.
379+
//
380+
// Async must be called at most once on each request's context (and its
381+
// descendants).
382+
func Async(ctx context.Context) {
383+
if r, ok := ctx.Value(asyncKey).(*releaser); ok {
384+
r.release(false)
385+
}
386+
}
387+
388+
type asyncKeyType struct{}
389+
390+
var asyncKey = asyncKeyType{}
391+
392+
// A releaser implements concurrency safe 'releasing' of async requests. (A
393+
// request is released when it is allowed to run concurrent with other
394+
// requests, via a call to [Async].)
395+
type releaser struct {
396+
mu sync.Mutex
397+
ch chan struct{}
398+
released bool
399+
}
400+
401+
// release closes the associated channel. If soft is set, multiple calls to
402+
// release are allowed.
403+
func (r *releaser) release(soft bool) {
404+
r.mu.Lock()
405+
defer r.mu.Unlock()
406+
407+
if r.released {
408+
if !soft {
409+
panic("jsonrpc2.Async called multiple times")
410+
}
411+
} else {
412+
close(r.ch)
413+
r.released = true
414+
}
415+
}
416+
377417
type AsyncCall struct {
378418
id ID
379419
ready chan struct{} // closed after response has been set
@@ -425,28 +465,6 @@ func (ac *AsyncCall) Await(ctx context.Context, result any) error {
425465
return json.Unmarshal(ac.response.Result, result)
426466
}
427467

428-
// Respond delivers a response to an incoming Call.
429-
//
430-
// Respond must be called exactly once for any message for which a handler
431-
// returns ErrAsyncResponse. It must not be called for any other message.
432-
func (c *Connection) Respond(id ID, result any, err error) error {
433-
var req *incomingRequest
434-
c.updateInFlight(func(s *inFlightState) {
435-
req = s.incomingByID[id]
436-
})
437-
if req == nil {
438-
return c.internalErrorf("Request not found for ID %v", id)
439-
}
440-
441-
if err == ErrAsyncResponse {
442-
// Respond is supposed to supply the asynchronous response, so it would be
443-
// confusing to call Respond with an error that promises to call Respond
444-
// again.
445-
err = c.internalErrorf("Respond called with ErrAsyncResponse for %q", req.Method)
446-
}
447-
return c.processResult("Respond", req, result, err)
448-
}
449-
450468
// Cancel cancels the Context passed to the Handle call for the inbound message
451469
// with the given ID.
452470
//
@@ -576,11 +594,6 @@ func (c *Connection) acceptRequest(ctx context.Context, msg *Request, preempter
576594
if preempter != nil {
577595
result, err := preempter.Preempt(req.ctx, req.Request)
578596

579-
if req.IsCall() && errors.Is(err, ErrAsyncResponse) {
580-
// This request will remain in flight until Respond is called for it.
581-
return
582-
}
583-
584597
if !errors.Is(err, ErrNotHandled) {
585598
c.processResult("Preempt", req, result, err)
586599
return
@@ -655,19 +668,20 @@ func (c *Connection) handleAsync() {
655668
continue
656669
}
657670

658-
result, err := c.handler.Handle(req.ctx, req.Request)
659-
c.processResult(c.handler, req, result, err)
671+
releaser := &releaser{ch: make(chan struct{})}
672+
ctx := context.WithValue(req.ctx, asyncKey, releaser)
673+
go func() {
674+
defer releaser.release(true)
675+
result, err := c.handler.Handle(ctx, req.Request)
676+
c.processResult(c.handler, req, result, err)
677+
}()
678+
<-releaser.ch
660679
}
661680
}
662681

663682
// processResult processes the result of a request and, if appropriate, sends a response.
664683
func (c *Connection) processResult(from any, req *incomingRequest, result any, err error) error {
665684
switch err {
666-
case ErrAsyncResponse:
667-
if !req.IsCall() {
668-
return c.internalErrorf("%#v returned ErrAsyncResponse for a %q Request without an ID", from, req.Method)
669-
}
670-
return nil // This request is still in flight, so don't record the result yet.
671685
case ErrNotHandled, ErrMethodNotFound:
672686
// Add detail describing the unhandled method.
673687
err = fmt.Errorf("%w: %q", ErrMethodNotFound, req.Method)
@@ -705,10 +719,10 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
705719
} else if err != nil {
706720
err = fmt.Errorf("%w: %q notification failed: %v", ErrInternal, req.Method, err)
707721
}
708-
if err != nil {
709-
// TODO: can/should we do anything with this error beyond writing it to the event log?
710-
// (Is this the right label to attach to the log?)
711-
}
722+
}
723+
if err != nil {
724+
// TODO: can/should we do anything with this error beyond writing it to the event log?
725+
// (Is this the right label to attach to the log?)
712726
}
713727

714728
// Cancel the request to free any associated resources.
@@ -725,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
725739
// write is used by all things that write outgoing messages, including replies.
726740
// it makes sure that writes are atomic
727741
func (c *Connection) write(ctx context.Context, msg Message) error {
728-
err := c.writer.Write(ctx, msg)
742+
var err error
743+
// Fail writes immediately if the connection is shutting down.
744+
//
745+
// TODO(rfindley): should we allow cancellation notifications through? It
746+
// could be the case that writes can still succeed.
747+
c.updateInFlight(func(s *inFlightState) {
748+
err = s.shuttingDown(ErrServerClosing)
749+
})
750+
if err == nil {
751+
err = c.writer.Write(ctx, msg)
752+
}
753+
754+
// For rejected requests, we don't set the writeErr (which would break the
755+
// connection). They can just be returned to the caller.
756+
if errors.Is(err, ErrRejected) {
757+
return err
758+
}
729759

730760
if err != nil && ctx.Err() == nil {
731761
// The call to Write failed, and since ctx.Err() is nil we can't attribute

internal/jsonrpc2/jsonrpc2.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@ var (
2222
// If a Handler returns ErrNotHandled, the server replies with
2323
// ErrMethodNotFound.
2424
ErrNotHandled = errors.New("JSON RPC not handled")
25-
26-
// ErrAsyncResponse is returned from a handler to indicate it will generate a
27-
// response asynchronously.
28-
//
29-
// ErrAsyncResponse must not be returned for notifications,
30-
// which do not receive responses.
31-
ErrAsyncResponse = errors.New("JSON RPC asynchronous response")
3225
)
3326

3427
// Preempter handles messages on a connection before they are queued to the main

internal/jsonrpc2/jsonrpc2_test.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -371,16 +371,14 @@ func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error
371371
if err := json.Unmarshal(req.Params, &name); err != nil {
372372
return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
373373
}
374+
jsonrpc2.Async(ctx)
374375
waitFor := h.waiter(name)
375-
go func() {
376-
select {
377-
case <-waitFor:
378-
h.conn.Respond(req.ID, true, nil)
379-
case <-ctx.Done():
380-
h.conn.Respond(req.ID, nil, ctx.Err())
381-
}
382-
}()
383-
return nil, jsonrpc2.ErrAsyncResponse
376+
select {
377+
case <-waitFor:
378+
return true, nil
379+
case <-ctx.Done():
380+
return nil, ctx.Err()
381+
}
384382
default:
385383
return nil, jsonrpc2.ErrNotHandled
386384
}

internal/jsonrpc2/wire.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ var (
3737
ErrServerClosing = NewError(-32004, "server is closing")
3838
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
3939
ErrClientClosing = NewError(-32003, "client is closing")
40+
41+
// The following errors have special semantics for MCP transports
42+
43+
// ErrRejected may be wrapped to return errors from calls to Writer.Write
44+
// that signal that the request was rejected by the transport layer as
45+
// invalid.
46+
//
47+
// Such failures do not indicate that the connection is broken, but rather
48+
// should be returned to the caller to indicate that the specific request is
49+
// invalid in the current context.
50+
ErrRejected = NewError(-32004, "rejected by transport")
4051
)
4152

4253
const wireVersion = "2.0"

mcp/client.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ func (e unsupportedProtocolVersionError) Error() string {
103103
}
104104

105105
// ClientSessionOptions is reserved for future use.
106-
type ClientSessionOptions struct {
107-
}
106+
type ClientSessionOptions struct{}
108107

109108
// Connect begins an MCP session by connecting to a server over the given
110109
// transport, and initializing the session.
@@ -177,6 +176,8 @@ type clientSessionState struct {
177176
InitializeResult *InitializeResult
178177
}
179178

179+
func (cs *ClientSession) InitializeResult() *InitializeResult { return cs.state.InitializeResult }
180+
180181
func (cs *ClientSession) ID() string {
181182
if c, ok := cs.mcpConn.(hasSessionID); ok {
182183
return c.SessionID()
@@ -323,16 +324,19 @@ func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo {
323324
}
324325

325326
func (cs *ClientSession) handle(ctx context.Context, req *jsonrpc.Request) (any, error) {
327+
if req.IsCall() {
328+
jsonrpc2.Async(ctx)
329+
}
326330
return handleReceive(ctx, cs, req)
327331
}
328332

329-
func (cs *ClientSession) sendingMethodHandler() methodHandler {
333+
func (cs *ClientSession) sendingMethodHandler() MethodHandler {
330334
cs.client.mu.Lock()
331335
defer cs.client.mu.Unlock()
332336
return cs.client.sendingMethodHandler_
333337
}
334338

335-
func (cs *ClientSession) receivingMethodHandler() methodHandler {
339+
func (cs *ClientSession) receivingMethodHandler() MethodHandler {
336340
cs.client.mu.Lock()
337341
defer cs.client.mu.Unlock()
338342
return cs.client.receivingMethodHandler_
@@ -392,7 +396,7 @@ func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (
392396
return handleSend[*CallToolResult](ctx, methodCallTool, newClientRequest(cs, orZero[Params](params)))
393397
}
394398

395-
func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error {
399+
func (cs *ClientSession) SetLoggingLevel(ctx context.Context, params *SetLoggingLevelParams) error {
396400
_, err := handleSend[*emptyResult](ctx, methodSetLevel, newClientRequest(cs, orZero[Params](params)))
397401
return err
398402
}

mcp/conformance_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func runServerTest(t *testing.T, test *conformanceTest) {
183183
return nil, err, false
184184
}
185185
serverMessages = append(serverMessages, msg)
186-
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() {
186+
if req, ok := msg.(*jsonrpc.Request); ok && req.IsCall() {
187187
// Pair up the next outgoing response with this request.
188188
// We assume requests arrive in the same order every time.
189189
if len(outResponses) == 0 {
@@ -201,8 +201,8 @@ func runServerTest(t *testing.T, test *conformanceTest) {
201201
// Synthetic peer interacts with real peer.
202202
for _, req := range outRequests {
203203
writeMsg(req)
204-
if req.ID.IsValid() {
205-
// A request (as opposed to a notification). Wait for the response.
204+
if req.IsCall() {
205+
// A call (as opposed to a notification). Wait for the response.
206206
res, err, ok := nextResponse()
207207
if err != nil {
208208
t.Fatalf("reading server messages failed: %v", err)

mcp/content.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e
253253

254254
func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) {
255255
if wire == nil {
256-
return nil, fmt.Errorf("content wire is nil")
256+
return nil, fmt.Errorf("nil content")
257257
}
258258
if allow != nil && !allow[wire.Type] {
259259
return nil, fmt.Errorf("invalid content type %q", wire.Type)

0 commit comments

Comments
 (0)