@@ -592,7 +592,10 @@ func errorCode(err error) int64 {
592592//
593593// The caller should cancel either the client connection or server connection
594594// when the connections are no longer needed.
595- func basicConnection (t * testing.T , config func (* Server )) (* ClientSession , * ServerSession ) {
595+ //
596+ // The returned func cleans up by closing the client and waiting for the server
597+ // to shut down.
598+ func basicConnection (t * testing.T , config func (* Server )) (* ClientSession , * ServerSession , func ()) {
596599 return basicClientServerConnection (t , nil , nil , config )
597600}
598601
@@ -604,7 +607,10 @@ func basicConnection(t *testing.T, config func(*Server)) (*ClientSession, *Serve
604607//
605608// The caller should cancel either the client connection or server connection
606609// when the connections are no longer needed.
607- func basicClientServerConnection (t * testing.T , client * Client , server * Server , config func (* Server )) (* ClientSession , * ServerSession ) {
610+ //
611+ // The returned func cleans up by closing the client and waiting for the server
612+ // to shut down.
613+ func basicClientServerConnection (t * testing.T , client * Client , server * Server , config func (* Server )) (* ClientSession , * ServerSession , func ()) {
608614 t .Helper ()
609615
610616 ctx := context .Background ()
@@ -628,14 +634,17 @@ func basicClientServerConnection(t *testing.T, client *Client, server *Server, c
628634 if err != nil {
629635 t .Fatal (err )
630636 }
631- return cs , ss
637+ return cs , ss , func () {
638+ cs .Close ()
639+ ss .Wait ()
640+ }
632641}
633642
634643func TestServerClosing (t * testing.T ) {
635- cs , ss := basicConnection (t , func (s * Server ) {
644+ cs , ss , cleanup := basicConnection (t , func (s * Server ) {
636645 AddTool (s , greetTool (), sayHi )
637646 })
638- defer cs . Close ()
647+ defer cleanup ()
639648
640649 ctx := context .Background ()
641650 var wg sync.WaitGroup
@@ -715,10 +724,10 @@ func TestCancellation(t *testing.T) {
715724 }
716725 return nil , nil , nil
717726 }
718- cs , _ := basicConnection (t , func (s * Server ) {
727+ cs , _ , cleanup := basicConnection (t , func (s * Server ) {
719728 AddTool (s , & Tool {Name : "slow" , InputSchema : & jsonschema.Schema {Type : "object" }}, slowTool )
720729 })
721- defer cs . Close ()
730+ defer cleanup ()
722731
723732 ctx , cancel := context .WithCancel (context .Background ())
724733 go cs .CallTool (ctx , & CallToolParams {Name : "slow" })
@@ -741,13 +750,10 @@ func TestMiddleware(t *testing.T) {
741750 t .Fatal (err )
742751 }
743752 // Wait for the server to exit after the client closes its connection.
744- var clientWG sync.WaitGroup
745- clientWG .Add (1 )
746- go func () {
753+ defer func () {
747754 if err := ss .Wait (); err != nil {
748755 t .Errorf ("server failed: %v" , err )
749756 }
750- clientWG .Done ()
751757 }()
752758
753759 var sbuf , cbuf bytes.Buffer
@@ -767,6 +773,8 @@ func TestMiddleware(t *testing.T) {
767773 if err != nil {
768774 t .Fatal (err )
769775 }
776+ defer cs .Close ()
777+
770778 if _ , err := cs .ListTools (ctx , nil ); err != nil {
771779 t .Fatal (err )
772780 }
@@ -1511,7 +1519,7 @@ func TestKeepAliveFailure(t *testing.T) {
15111519func TestAddTool_DuplicateNoPanicAndNoDuplicate (t * testing.T ) {
15121520 // Adding the same tool pointer twice should not panic and should not
15131521 // produce duplicates in the server's tool list.
1514- cs , _ := basicConnection (t , func (s * Server ) {
1522+ cs , _ , cleanup := basicConnection (t , func (s * Server ) {
15151523 // Use two distinct Tool instances with the same name but different
15161524 // descriptions to ensure the second replaces the first
15171525 // This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
@@ -1520,7 +1528,7 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
15201528 s .AddTool (t1 , nopHandler )
15211529 s .AddTool (t2 , nopHandler )
15221530 })
1523- defer cs . Close ()
1531+ defer cleanup ()
15241532
15251533 ctx := context .Background ()
15261534 res , err := cs .ListTools (ctx , nil )
@@ -1568,14 +1576,15 @@ func TestSynchronousNotifications(t *testing.T) {
15681576 },
15691577 }
15701578 server := NewServer (testImpl , serverOpts )
1571- cs , ss := basicClientServerConnection (t , client , server , func (s * Server ) {
1579+ cs , ss , cleanup := basicClientServerConnection (t , client , server , func (s * Server ) {
15721580 AddTool (s , & Tool {Name : "tool" }, func (ctx context.Context , req * CallToolRequest , args any ) (* CallToolResult , any , error ) {
15731581 if ! rootsChanged .Load () {
15741582 return nil , nil , fmt .Errorf ("didn't get root change notification" )
15751583 }
15761584 return new (CallToolResult ), nil , nil
15771585 })
15781586 })
1587+ defer cleanup ()
15791588
15801589 t .Run ("from client" , func (t * testing.T ) {
15811590 client .AddRoots (& Root {Name : "myroot" , URI : "file://foo" })
@@ -1617,7 +1626,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
16171626 },
16181627 }
16191628 client := NewClient (testImpl , clientOpts )
1620- cs , _ := basicClientServerConnection (t , client , nil , func (s * Server ) {
1629+ cs , _ , cleanup := basicClientServerConnection (t , client , nil , func (s * Server ) {
16211630 AddTool (s , & Tool {Name : "tool1" }, func (ctx context.Context , req * CallToolRequest , args any ) (* CallToolResult , any , error ) {
16221631 req .Session .CreateMessage (ctx , new (CreateMessageParams ))
16231632 return new (CallToolResult ), nil , nil
@@ -1627,7 +1636,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
16271636 return new (CallToolResult ), nil , nil
16281637 })
16291638 })
1630- defer cs . Close ()
1639+ defer cleanup ()
16311640
16321641 ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
16331642 defer cancel ()
@@ -1651,7 +1660,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16511660 type output struct {
16521661 Out string
16531662 }
1654- cs , _ := basicConnection (t , func (s * Server ) {
1663+ cs , _ , cleanup := basicConnection (t , func (s * Server ) {
16551664 // Add two equivalent tools, one of which operates in the 'pointer' realm,
16561665 // the other of which does not.
16571666 //
@@ -1686,7 +1695,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16861695 }
16871696 })
16881697 })
1689- defer cs . Close ()
1698+ defer cleanup ()
16901699
16911700 ctx := context .Background ()
16921701 tools , err := cs .ListTools (ctx , nil )
@@ -1758,7 +1767,9 @@ func TestComplete(t *testing.T) {
17581767 },
17591768 }
17601769 server := NewServer (testImpl , serverOpts )
1761- cs , _ := basicClientServerConnection (t , nil , server , func (s * Server ) {})
1770+ cs , _ , cleanup := basicClientServerConnection (t , nil , server , func (s * Server ) {})
1771+ defer cleanup ()
1772+
17621773 result , err := cs .Complete (context .Background (), & CompleteParams {
17631774 Argument : CompleteParamsArgument {
17641775 Name : "language" ,
0 commit comments