diff --git a/context.go b/context.go index 837edc2..327de5c 100644 --- a/context.go +++ b/context.go @@ -4,11 +4,11 @@ import ( "context" ) -const contextRequestIDKey = "request.id" +const ContextRequestIDKey = "request.id" // requestID returns a request present on context. -func requestID(ctx context.Context) string { - value := ctx.Value(contextRequestIDKey) +func RequestID(ctx context.Context) string { + value := ctx.Value(ContextRequestIDKey) if value == nil { return "" } diff --git a/context_test.go b/context_test.go index 4e12f9d..d6a907d 100644 --- a/context_test.go +++ b/context_test.go @@ -1,17 +1,18 @@ -package httpclient +package httpclient_test import ( "context" + "github.com/globocom/httpclient" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("requestID", func() { It("returns the request id included within the values", func() { - ctx := context.WithValue(context.Background(), contextRequestIDKey, "42") + ctx := context.WithValue(context.Background(), httpclient.ContextRequestIDKey, "42") - id := requestID(ctx) + id := httpclient.RequestID(ctx) Expect(id).To(Equal("42")) }) @@ -19,7 +20,7 @@ var _ = Describe("requestID", func() { It("returns blank string if request id is not present on the context", func() { ctx := context.Background() - id := requestID(ctx) + id := httpclient.RequestID(ctx) Expect(id).To(Equal("")) }) diff --git a/register_metrics_test.go b/register_metrics_test.go new file mode 100644 index 0000000..0a46484 --- /dev/null +++ b/register_metrics_test.go @@ -0,0 +1,70 @@ +package httpclient_test + +import ( + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/globocom/httpclient" + "github.com/stretchr/testify/assert" +) + +type mockMetrics struct { + pushToSeriesCalls []string + incrCounterCalls []string + incrCounterWithAttrsCalls []struct { + key string + attrs map[string]string + } + lock sync.Mutex +} + +func (m *mockMetrics) PushToSeries(key string, value float64) { + m.lock.Lock() + defer m.lock.Unlock() + m.pushToSeriesCalls = append(m.pushToSeriesCalls, key) +} + +func (m *mockMetrics) IncrCounter(key string) { + m.lock.Lock() + defer m.lock.Unlock() + m.incrCounterCalls = append(m.incrCounterCalls, key) +} + +func (m *mockMetrics) IncrCounterWithAttrs(key string, attrs map[string]string) { + m.lock.Lock() + defer m.lock.Unlock() + m.incrCounterWithAttrsCalls = append(m.incrCounterWithAttrsCalls, struct { + key string + attrs map[string]string + }{key, attrs}) +} + +func TestHTTPClient_MetricsIntegration(t *testing.T) { + metrics := &mockMetrics{} + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) + rw.Write([]byte("OK")) + })) + defer server.Close() + + client := httpclient.NewHTTPClient( + &httpclient.LoggerAdapter{Writer: io.Discard}, + httpclient.WithHostURL(server.URL), + httpclient.WithMetrics(metrics), + ) + + req := client.NewRequest() + resp, err := req.Get("/") + assert.NoError(t, err) + assert.NotNil(t, resp) + + assert.Eventually(t, func() bool { + metrics.lock.Lock() + defer metrics.lock.Unlock() + return len(metrics.pushToSeriesCalls) > 0 && len(metrics.incrCounterCalls) > 0 && len(metrics.incrCounterWithAttrsCalls) > 0 + }, time.Second, 100*time.Millisecond) +} diff --git a/request.go b/request.go index 5577517..9c2cb13 100644 --- a/request.go +++ b/request.go @@ -36,6 +36,12 @@ func (r *Request) HostURL() *url.URL { return r.hostURL } +// SetHostURL sets the host url for the request. +func (r *Request) SetHostURL(url *url.URL) *Request { + r.hostURL = url + return r +} + // SetAlias sets the alias to replace the hostname in metrics. func (r *Request) SetAlias(alias string) *Request { r.alias = alias @@ -144,8 +150,12 @@ func registerMetrics(key string, metrics Metrics, f func() (*Response, error)) ( if metrics != nil { go func(resp *Response, err error) { - attrs := map[string]string{} + var attrs map[string]string if resp != nil { + attrs = map[string]string{ + "host": resp.Request().HostURL().Host, + "path": resp.Request().HostURL().Path, + } metrics.PushToSeries(fmt.Sprintf("%s.%s", key, "response_time"), resp.ResponseTime().Seconds()) if resp.statusCode != 0 { metrics.IncrCounter(fmt.Sprintf("%s.status.%d", key, resp.StatusCode())) diff --git a/request_test.go b/request_test.go index b926436..c2db8b1 100644 --- a/request_test.go +++ b/request_test.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/globocom/httpclient" @@ -30,6 +31,7 @@ func TestRequest(t *testing.T) { "SetBody": testSetBody, "SetHeader": testSetHeader, "SetBasicAuth": testSetBasicAuth, + "SetHostURL": testSetHostURL, "Get": testGet, "Post": testPost, "Put": testPut, @@ -131,3 +133,30 @@ func testDelete(target *httpclient.Request) func(*testing.T) { assert.Equal(t, "DELETE", gReq.Method) } } + +func testSetHostURL(target *httpclient.Request) func(*testing.T) { + return func(t *testing.T) { + // Create a new URL to set + newURL, err := url.Parse("https://example.com:8080") + assert.NoError(t, err) + + // Test setting the host URL + result := target.SetHostURL(newURL) + + // Verify the method returns the request instance (for chaining) + assert.Equal(t, target, result) + + // Verify the host URL was set correctly + hostURL := target.HostURL() + assert.NotNil(t, hostURL) + assert.Equal(t, "https://example.com:8080", hostURL.String()) + assert.Equal(t, "example.com", hostURL.Hostname()) + assert.Equal(t, "8080", hostURL.Port()) + assert.Equal(t, "https", hostURL.Scheme) + + // Test with nil URL + result2 := target.SetHostURL(nil) + assert.Equal(t, target, result2) + assert.Nil(t, target.HostURL()) + } +} diff --git a/suite_test.go b/suite_test.go index bfd2ae9..f2b1457 100644 --- a/suite_test.go +++ b/suite_test.go @@ -1,4 +1,4 @@ -package httpclient +package httpclient_test import ( "testing" diff --git a/transport.go b/transport.go index 3d3543b..6a55792 100644 --- a/transport.go +++ b/transport.go @@ -24,7 +24,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { } func (t *Transport) setRequestIDHeader(ctx context.Context, req *http.Request) { - rID := requestID(ctx) + rID := RequestID(ctx) if rID == "" { return }