Skip to content

Commit e3e8aaf

Browse files
authored
mcp: align PromptHandler args with other handlers (#348)
The second arg is a GetPromptRequest. Fixes #300.
1 parent 5fd06ae commit e3e8aaf

File tree

5 files changed

+8
-8
lines changed

5 files changed

+8
-8
lines changed

examples/server/hello/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.Cal
3030
}, nil, nil
3131
}
3232

33-
func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {
33+
func PromptHi(ctx context.Context, req *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
3434
return &mcp.GetPromptResult{
3535
Description: "Code review prompt",
3636
Messages: []*mcp.PromptMessage{
37-
{Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + params.Arguments["name"]}},
37+
{Role: "user", Content: &mcp.TextContent{Text: "Say hi to " + req.Params.Arguments["name"]}},
3838
},
3939
}, nil
4040
}

mcp/client_list_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,6 @@ func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) {
126126
}
127127
}
128128

129-
func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {
129+
func testPromptHandler(context.Context, *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
130130
panic("not implemented")
131131
}

mcp/mcp_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ var codeReviewPrompt = &Prompt{
4747
Arguments: []*PromptArgument{{Name: "Code", Required: true}},
4848
}
4949

50-
func codReviewPromptHandler(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) {
50+
func codReviewPromptHandler(_ context.Context, req *GetPromptRequest) (*GetPromptResult, error) {
5151
return &GetPromptResult{
5252
Description: "Code review prompt",
5353
Messages: []*PromptMessage{
54-
{Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}},
54+
{Role: "user", Content: &TextContent{Text: "Please review the following code: " + req.Params.Arguments["Code"]}},
5555
},
5656
}, nil
5757
}
@@ -103,7 +103,7 @@ func TestEndToEnd(t *testing.T) {
103103
return nil, nil, errTestFailure
104104
})
105105
s.AddPrompt(codeReviewPrompt, codReviewPromptHandler)
106-
s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) {
106+
s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *GetPromptRequest) (*GetPromptResult, error) {
107107
return nil, errTestFailure
108108
})
109109
s.AddResource(resource1, readHandler)

mcp/prompt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
// A PromptHandler handles a call to prompts/get.
12-
type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error)
12+
type PromptHandler func(context.Context, *GetPromptRequest) (*GetPromptResult, error)
1313

1414
type serverPrompt struct {
1515
prompt *Prompt

mcp/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetProm
446446
Message: fmt.Sprintf("unknown prompt %q", req.Params.Name),
447447
}
448448
}
449-
return prompt.handler(ctx, req.Session, req.Params)
449+
return prompt.handler(ctx, req)
450450
}
451451

452452
func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) {

0 commit comments

Comments
 (0)