@@ -220,6 +220,43 @@ func TestClient_Connect(t *testing.T) {
220220 return atomic .LoadInt32 (& onClosedCalled ) == 1
221221 }, 100 * time .Millisecond , 20 * time .Millisecond , "onClose should be called" )
222222 })
223+
224+ t .Run ("when OnCloseCtx returns error, we still close the connection" , func (t * testing.T ) {
225+ server , err := NewTestServer ()
226+ require .NoError (t , err )
227+ defer server .Close ()
228+
229+ var onClosedCalled int32
230+ onCloseCtx := func (ctx context.Context , c * connection.Connection ) error {
231+ // increase the counter
232+ atomic .AddInt32 (& onClosedCalled , 1 )
233+ return errors .New ("error from on close handler" )
234+ }
235+
236+ var onErrCalled int32
237+ errHandler := func (err error ) {
238+ atomic .AddInt32 (& onErrCalled , 1 )
239+ require .Contains (t , err .Error (), "error from on close handler" )
240+ }
241+
242+ c , err := connection .New (
243+ server .Addr ,
244+ testSpec ,
245+ readMessageLength ,
246+ writeMessageLength ,
247+ connection .ErrorHandler (errHandler ),
248+ connection .OnCloseCtx (onCloseCtx ),
249+ )
250+ require .NoError (t , err )
251+
252+ err = c .CloseCtx (context .Background ())
253+ require .NoError (t , err )
254+
255+ // eventually the onClosedCalled should be 1
256+ require .Eventually (t , func () bool {
257+ return atomic .LoadInt32 (& onClosedCalled ) == 1
258+ }, 100 * time .Millisecond , 20 * time .Millisecond , "onClose should be called" )
259+ })
223260}
224261
225262func TestClient_Write (t * testing.T ) {
0 commit comments