diff --git a/client.go b/client.go index ec6141f8..2437ea15 100644 --- a/client.go +++ b/client.go @@ -95,6 +95,9 @@ type ( // SuccessHook type is for reacting to request success SuccessHook func(*Client, *Response) + // CloseHook type is for reacting to client closing + CloseHook func() + // RequestFunc type is for extended manipulation of the Request instance RequestFunc func(*Request) *Request @@ -215,6 +218,7 @@ type Client struct { invalidHooks []ErrorHook panicHooks []ErrorHook successHooks []SuccessHook + closeHooks []CloseHook contentTypeEncoders map[string]ContentTypeEncoder contentTypeDecoders map[string]ContentTypeDecoder contentDecompresserKeys []string @@ -838,6 +842,15 @@ func (c *Client) OnPanic(h ErrorHook) *Client { return c } +// OnClose method adds a callback that will be run whenever the client is closed. +// The hooks are executed in the order they were registered. +func (c *Client) OnClose(h CloseHook) *Client { + c.lock.Lock() + defer c.lock.Unlock() + c.closeHooks = append(c.closeHooks, h) + return c +} + // ContentTypeEncoders method returns all the registered content type encoders. func (c *Client) ContentTypeEncoders() map[string]ContentTypeEncoder { c.lock.RLock() @@ -2221,10 +2234,14 @@ func (c *Client) Clone(ctx context.Context) *Client { // Close method performs cleanup and closure activities on the client instance func (c *Client) Close() error { + // Execute close hooks first + c.onCloseHooks() + if c.LoadBalancer() != nil { silently(c.LoadBalancer().Close()) } close(c.certWatcherStopChan) + return nil } @@ -2377,6 +2394,15 @@ func (c *Client) onInvalidHooks(req *Request, err error) { } } +// Helper to run closeHooks hooks. +func (c *Client) onCloseHooks() { + c.lock.RLock() + defer c.lock.RUnlock() + for _, h := range c.closeHooks { + h() + } +} + func (c *Client) debugf(format string, v ...any) { if c.IsDebug() { c.Logger().Debugf(format, v...) diff --git a/client_test.go b/client_test.go index 8a5eef50..e753099c 100644 --- a/client_test.go +++ b/client_test.go @@ -1515,3 +1515,35 @@ func TestClientCircuitBreaker(t *testing.T) { assertError(t, err) assertEqual(t, uint32(1), c.circuitBreaker.failureCount.Load()) } + +func TestClientOnClose(t *testing.T) { + var hookExecuted bool + + c := dcnl() + c.OnClose(func() { + hookExecuted = true + }) + + err := c.Close() + assertNil(t, err) + assertEqual(t, true, hookExecuted) +} + +func TestClientOnCloseMultipleHooks(t *testing.T) { + var executionOrder []string + + c := dcnl() + c.OnClose(func() { + executionOrder = append(executionOrder, "first") + }) + c.OnClose(func() { + executionOrder = append(executionOrder, "second") + }) + c.OnClose(func() { + executionOrder = append(executionOrder, "third") + }) + + err := c.Close() + assertNil(t, err) + assertEqual(t, []string{"first", "second", "third"}, executionOrder) +}