@@ -17,6 +17,7 @@ import (
1717 "slices"
1818 "strings"
1919 "sync"
20+ "sync/atomic"
2021 "testing"
2122 "time"
2223
@@ -549,31 +550,47 @@ func errorCode(err error) int64 {
549550//
550551// The caller should cancel either the client connection or server connection
551552// when the connections are no longer needed.
552- func basicConnection (t * testing.T , config func (* Server )) (* ServerSession , * ClientSession ) {
553+ func basicConnection (t * testing.T , config func (* Server )) (* ClientSession , * ServerSession ) {
554+ return basicClientServerConnection (t , nil , nil , config )
555+ }
556+
557+ // basicClientServerConnection creates a basic connection between client and
558+ // server. If either client or server is nil, empty implementations are used.
559+ //
560+ // The provided function may be used to configure features on the resulting
561+ // server, prior to connection.
562+ //
563+ // The caller should cancel either the client connection or server connection
564+ // when the connections are no longer needed.
565+ func basicClientServerConnection (t * testing.T , client * Client , server * Server , config func (* Server )) (* ClientSession , * ServerSession ) {
553566 t .Helper ()
554567
555568 ctx := context .Background ()
556569 ct , st := NewInMemoryTransports ()
557570
558- s := NewServer (testImpl , nil )
571+ if server == nil {
572+ server = NewServer (testImpl , nil )
573+ }
559574 if config != nil {
560- config (s )
575+ config (server )
561576 }
562- ss , err := s .Connect (ctx , st , nil )
577+ ss , err := server .Connect (ctx , st , nil )
563578 if err != nil {
564579 t .Fatal (err )
565580 }
566581
567- c := NewClient (testImpl , nil )
568- cs , err := c .Connect (ctx , ct , nil )
582+ if client == nil {
583+ client = NewClient (testImpl , nil )
584+ }
585+ cs , err := client .Connect (ctx , ct , nil )
569586 if err != nil {
570587 t .Fatal (err )
571588 }
572- return ss , cs
589+ return cs , ss
573590}
574591
575592func TestServerClosing (t * testing.T ) {
576- cc , cs := basicConnection (t , func (s * Server ) {
593+ cs , ss := basicConnection (t , func (s * Server ) {
577594 AddTool (s , greetTool (), sayHi )
578595 })
579596 defer cs .Close ()
@@ -593,7 +610,7 @@ func TestServerClosing(t *testing.T) {
593610 }); err != nil {
594611 t .Fatalf ("after connecting: %v" , err )
595612 }
596- cc .Close ()
613+ ss .Close ()
597614 wg .Wait ()
598615 if _ , err := cs .CallTool (ctx , & CallToolParams {
599616 Name : "greet" ,
@@ -656,7 +673,7 @@ func TestCancellation(t *testing.T) {
656673 }
657674 return nil , nil
658675 }
659- _ , cs := basicConnection (t , func (s * Server ) {
676+ cs , _ := basicConnection (t , func (s * Server ) {
660677 AddTool (s , & Tool {Name : "slow" }, slowRequest )
661678 })
662679 defer cs .Close ()
@@ -940,7 +957,7 @@ func TestKeepAliveFailure(t *testing.T) {
940957func TestAddTool_DuplicateNoPanicAndNoDuplicate (t * testing.T ) {
941958 // Adding the same tool pointer twice should not panic and should not
942959 // produce duplicates in the server's tool list.
943- _ , cs := basicConnection (t , func (s * Server ) {
960+ cs , _ := basicConnection (t , func (s * Server ) {
944961 // Use two distinct Tool instances with the same name but different
945962 // descriptions to ensure the second replaces the first
946963 // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
@@ -972,4 +989,98 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
972989 }
973990}
974991
992+ func TestSynchronousNotifications (t * testing.T ) {
993+ var toolsChanged atomic.Bool
994+ clientOpts := & ClientOptions {
995+ ToolListChangedHandler : func (ctx context.Context , req * ClientRequest [* ToolListChangedParams ]) {
996+ toolsChanged .Store (true )
997+ },
998+ CreateMessageHandler : func (ctx context.Context , req * ClientRequest [* CreateMessageParams ]) (* CreateMessageResult , error ) {
999+ if ! toolsChanged .Load () {
1000+ return nil , fmt .Errorf ("didn't get a tools changed notification" )
1001+ }
1002+ // TODO(rfindley): investigate the error returned from this test if
1003+ // CreateMessageResult is new(CreateMessageResult): it's a mysterious
1004+ // unmarshalling error that we should improve.
1005+ return & CreateMessageResult {Content : & TextContent {}}, nil
1006+ },
1007+ }
1008+ client := NewClient (testImpl , clientOpts )
1009+
1010+ var rootsChanged atomic.Bool
1011+ serverOpts := & ServerOptions {
1012+ RootsListChangedHandler : func (_ context.Context , req * ServerRequest [* RootsListChangedParams ]) {
1013+ rootsChanged .Store (true )
1014+ },
1015+ }
1016+ server := NewServer (testImpl , serverOpts )
1017+ cs , ss := basicClientServerConnection (t , client , server , func (s * Server ) {
1018+ AddTool (s , & Tool {Name : "tool" }, func (ctx context.Context , req * ServerRequest [* CallToolParams ]) (* CallToolResult , error ) {
1019+ if ! rootsChanged .Load () {
1020+ return nil , fmt .Errorf ("didn't get root change notification" )
1021+ }
1022+ return new (CallToolResult ), nil
1023+ })
1024+ })
1025+
1026+ t .Run ("from client" , func (t * testing.T ) {
1027+ client .AddRoots (& Root {Name : "myroot" , URI : "file://foo" })
1028+ res , err := cs .CallTool (context .Background (), & CallToolParams {Name : "tool" })
1029+ if err != nil {
1030+ t .Fatalf ("CallTool failed: %v" , err )
1031+ }
1032+ if res .IsError {
1033+ t .Errorf ("tool error: %v" , res .Content [0 ].(* TextContent ).Text )
1034+ }
1035+ })
1036+
1037+ t .Run ("from server" , func (t * testing.T ) {
1038+ server .RemoveTools ("tool" )
1039+ if _ , err := ss .CreateMessage (context .Background (), new (CreateMessageParams )); err != nil {
1040+ t .Errorf ("CreateMessage failed: %v" , err )
1041+ }
1042+ })
1043+ }
1044+
1045+ func TestNoDistributedDeadlock (t * testing.T ) {
1046+ // This test verifies that calls are asynchronous, and so it's not possible
1047+ // to have a distributed deadlock.
1048+ //
1049+ // The setup creates potential deadlock for both the client and server: the
1050+ // client sends a call to tool1, which itself calls createMessage, which in
1051+ // turn calls tool2, which calls ping.
1052+ //
1053+ // If the server were not asynchronous, the call to tool2 would hang. If the
1054+ // client were not asynchronous, the call to ping would hang.
1055+ //
1056+ // Such a scenario is unlikely in practice, but is still theoretically
1057+ // possible, and in any case making tool calls asynchronous by default
1058+ // delegates synchronization to the user.
1059+ clientOpts := & ClientOptions {
1060+ CreateMessageHandler : func (ctx context.Context , req * ClientRequest [* CreateMessageParams ]) (* CreateMessageResult , error ) {
1061+ req .Session .CallTool (ctx , & CallToolParams {Name : "tool2" })
1062+ return & CreateMessageResult {Content : & TextContent {}}, nil
1063+ },
1064+ }
1065+ client := NewClient (testImpl , clientOpts )
1066+ cs , _ := basicClientServerConnection (t , client , nil , func (s * Server ) {
1067+ AddTool (s , & Tool {Name : "tool1" }, func (ctx context.Context , req * ServerRequest [* CallToolParams ]) (* CallToolResult , error ) {
1068+ req .Session .CreateMessage (ctx , new (CreateMessageParams ))
1069+ return new (CallToolResult ), nil
1070+ })
1071+ AddTool (s , & Tool {Name : "tool2" }, func (ctx context.Context , req * ServerRequest [* CallToolParams ]) (* CallToolResult , error ) {
1072+ req .Session .Ping (ctx , nil )
1073+ return new (CallToolResult ), nil
1074+ })
1075+ })
1076+ defer cs .Close ()
1077+
1078+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
1079+ defer cancel ()
1080+ if _ , err := cs .CallTool (ctx , & CallToolParams {Name : "tool1" }); err != nil {
1081+ // should not deadlock
1082+ t .Fatalf ("CallTool failed: %v" , err )
1083+ }
1084+ }
1085+
9751086var testImpl = & Implementation {Name : "test" , Version : "v1.0.0" }
0 commit comments