@@ -234,7 +234,7 @@ func TestEndToEnd(t *testing.T) {
234234 & TextContent {Text : "hi user" },
235235 },
236236 }
237- if diff := cmp .Diff (wantHi , gotHi ); diff != "" {
237+ if diff := cmp .Diff (wantHi , gotHi , ctrCmpOpts ... ); diff != "" {
238238 t .Errorf ("tools/call 'greet' mismatch (-want +got):\n %s" , diff )
239239 }
240240
@@ -253,7 +253,7 @@ func TestEndToEnd(t *testing.T) {
253253 & TextContent {Text : errTestFailure .Error ()},
254254 },
255255 }
256- if diff := cmp .Diff (wantFail , gotFail ); diff != "" {
256+ if diff := cmp .Diff (wantFail , gotFail , ctrCmpOpts ... ); diff != "" {
257257 t .Errorf ("tools/call 'fail' mismatch (-want +got):\n %s" , diff )
258258 }
259259
@@ -1717,7 +1717,7 @@ func TestPointerArgEquivalence(t *testing.T) {
17171717 if err != nil {
17181718 t .Fatal (err )
17191719 }
1720- if diff := cmp .Diff (r0 , r1 ); diff != "" {
1720+ if diff := cmp .Diff (r0 , r1 , ctrCmpOpts ... ); diff != "" {
17211721 t .Errorf ("CallTool(%v) with no arguments mismatch (-%s +%s):\n %s" , args , t0 .Name , t1 .Name , diff )
17221722 }
17231723 }
@@ -1733,7 +1733,7 @@ func TestPointerArgEquivalence(t *testing.T) {
17331733 if err != nil {
17341734 t .Fatal (err )
17351735 }
1736- if diff := cmp .Diff (r0 , r1 ); diff != "" {
1736+ if diff := cmp .Diff (r0 , r1 , ctrCmpOpts ... ); diff != "" {
17371737 t .Errorf ("CallTool({\" In\" : %q}) mismatch (-%s +%s):\n %s" , in , t0 .Name , t1 .Name , diff )
17381738 }
17391739 })
@@ -1837,3 +1837,70 @@ func TestEmbeddedStructResponse(t *testing.T) {
18371837 t .Errorf ("CallTool() failed: %v" , err )
18381838 }
18391839}
1840+
1841+ func TestToolErrorMiddleware (t * testing.T ) {
1842+ ctx := context .Background ()
1843+ ct , st := NewInMemoryTransports ()
1844+
1845+ s := NewServer (testImpl , nil )
1846+ AddTool (s , & Tool {
1847+ Name : "greet" ,
1848+ Description : "say hi" ,
1849+ }, sayHi )
1850+ AddTool (s , & Tool {Name : "fail" , InputSchema : & jsonschema.Schema {Type : "object" }},
1851+ func (context.Context , * CallToolRequest , map [string ]any ) (* CallToolResult , any , error ) {
1852+ return nil , nil , errTestFailure
1853+ })
1854+
1855+ var middleErr error
1856+ s .AddReceivingMiddleware (func (h MethodHandler ) MethodHandler {
1857+ return func (ctx context.Context , method string , req Request ) (Result , error ) {
1858+ res , err := h (ctx , method , req )
1859+ if err == nil {
1860+ if ctr , ok := res .(* CallToolResult ); ok {
1861+ middleErr = ctr .getError ()
1862+ }
1863+ }
1864+ return res , err
1865+ }
1866+ })
1867+ _ , err := s .Connect (ctx , st , nil )
1868+ if err != nil {
1869+ t .Fatal (err )
1870+ }
1871+ client := NewClient (& Implementation {Name : "test-client" }, nil )
1872+ clientSession , err := client .Connect (ctx , ct , nil )
1873+ if err != nil {
1874+ t .Fatal (err )
1875+ }
1876+ defer clientSession .Close ()
1877+
1878+ _ , err = clientSession .CallTool (ctx , & CallToolParams {
1879+ Name : "greet" ,
1880+ Arguments : map [string ]any {"Name" : "al" },
1881+ })
1882+ if err != nil {
1883+ t .Errorf ("CallTool() failed: %v" , err )
1884+ }
1885+ if middleErr != nil {
1886+ t .Errorf ("middleware got error %v, want nil" , middleErr )
1887+ }
1888+ res , err := clientSession .CallTool (ctx , & CallToolParams {
1889+ Name : "fail" ,
1890+ })
1891+ if err != nil {
1892+ t .Errorf ("CallTool() failed: %v" , err )
1893+ }
1894+ if ! res .IsError {
1895+ t .Fatal ("want error, got none" )
1896+ }
1897+ // Clients can't see the error, because it isn't marshaled.
1898+ if err := res .getError (); err != nil {
1899+ t .Fatalf ("got %v, want nil" , err )
1900+ }
1901+ if middleErr != errTestFailure {
1902+ t .Errorf ("middleware got err %v, want errTestFailure" , middleErr )
1903+ }
1904+ }
1905+
1906+ var ctrCmpOpts = []cmp.Option {cmp .AllowUnexported (CallToolResult {})}
0 commit comments