@@ -40,6 +40,25 @@ func TestStreamableTransports(t *testing.T) {
4040 // 1. Create a server with a simple "greet" tool.
4141 server := NewServer (testImpl , nil )
4242 AddTool (server , & Tool {Name : "greet" , Description : "say hi" }, sayHi )
43+ AddTool (server , & Tool {Name : "sample" }, func (ctx context.Context , req * ServerRequest [* CallToolParamsFor [any ]]) (* CallToolResultFor [any ], error ) {
44+ // Test that we can make sampling requests during tool handling.
45+ //
46+ // Try this on both the request context and a background context, so
47+ // that messages may be delivered on either the POST or GET connection.
48+ for _ , ctx := range map [string ]context.Context {
49+ "request context" : ctx ,
50+ "background context" : context .Background (),
51+ } {
52+ res , err := req .Session .CreateMessage (ctx , & CreateMessageParams {})
53+ if err != nil {
54+ return nil , err
55+ }
56+ if g , w := res .Model , "aModel" ; g != w {
57+ return nil , fmt .Errorf ("got %q, want %q" , g , w )
58+ }
59+ }
60+ return & CallToolResultFor [any ]{}, nil
61+ })
4362
4463 // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a
4564 // cookie-checking middleware.
@@ -81,7 +100,11 @@ func TestStreamableTransports(t *testing.T) {
81100 Endpoint : httpServer .URL ,
82101 HTTPClient : httpClient ,
83102 }
84- client := NewClient (testImpl , nil )
103+ client := NewClient (testImpl , & ClientOptions {
104+ CreateMessageHandler : func (context.Context , * ClientRequest [* CreateMessageParams ]) (* CreateMessageResult , error ) {
105+ return & CreateMessageResult {Model : "aModel" , Content : & TextContent {}}, nil
106+ },
107+ })
85108 session , err := client .Connect (ctx , transport , nil )
86109 if err != nil {
87110 t .Fatalf ("client.Connect() failed: %v" , err )
@@ -119,6 +142,19 @@ func TestStreamableTransports(t *testing.T) {
119142 if diff := cmp .Diff (want , got ); diff != "" {
120143 t .Errorf ("CallTool() returned unexpected content (-want +got):\n %s" , diff )
121144 }
145+
146+ // 6. Run the "sampling" tool and verify that the streamable server can
147+ // call tools.
148+ result , err := session .CallTool (ctx , & CallToolParams {
149+ Name : "sample" ,
150+ Arguments : map [string ]any {},
151+ })
152+ if err != nil {
153+ t .Fatal (err )
154+ }
155+ if result .IsError {
156+ t .Fatalf ("tool failed: %s" , result .Content [0 ].(* TextContent ).Text )
157+ }
122158 })
123159 }
124160}
0 commit comments