@@ -145,14 +145,16 @@ func TestWatcher_PIDChange(t *testing.T) {
145145func TestWatcher_PIDChangeSuccess (t * testing.T ) {
146146 // test tests for success, which only happens when no error comes in
147147 // during this time period
148- ctx , cancel := context .WithTimeout ( context . Background (), 1 * time . Second )
148+ ctx , cancel := context .WithCancel ( t . Context () )
149149 defer cancel ()
150150
151- errCh := make (chan error )
151+ // buffered channel so we can drain after we send everything
152+ errCh := make (chan error , 10 )
152153 logger , _ := loggertest .New ("watcher" )
153154 w := NewAgentWatcher (errCh , logger , 1 * time .Millisecond )
154155
155156 // error on watch (counts as lost connect)
157+ sentEverything := make (chan struct {})
156158 mockHandler := func (srv cproto.ElasticAgentControl_StateWatchServer ) error {
157159 // starts with PID 1
158160 err := srv .Send (& cproto.StateResponse {
@@ -209,6 +211,8 @@ func TestWatcher_PIDChangeSuccess(t *testing.T) {
209211 if err != nil {
210212 return err
211213 }
214+ // close the channel to signify that we sent everything
215+ close (sentEverything )
212216 // keep open until end (exiting will count as a lost connection)
213217 <- ctx .Done ()
214218 return nil
@@ -222,10 +226,19 @@ func TestWatcher_PIDChangeSuccess(t *testing.T) {
222226 go w .Run (ctx )
223227
224228 select {
225- case <- ctx .Done ():
226- require .ErrorIs (t , ctx .Err (), context .DeadlineExceeded )
227- case err := <- errCh :
228- assert .NoError (t , err , "error should not have been reported" )
229+ case <- sentEverything :
230+ case <- time .After (5 * time .Second ):
231+ t .Fatal ("timed out waiting for everything to be sent" )
232+ return
233+ }
234+
235+ for {
236+ select {
237+ case err := <- errCh :
238+ assert .NoError (t , err , "error should not have been reported" )
239+ case <- time .After (1 * time .Second ):
240+ return
241+ }
229242 }
230243}
231244
0 commit comments