diff --git a/backend/httpclient/http_client.go b/backend/httpclient/http_client.go index 3a5403bdf..7355d1b67 100644 --- a/backend/httpclient/http_client.go +++ b/backend/httpclient/http_client.go @@ -212,6 +212,7 @@ func DefaultMiddlewares() []Middleware { CustomHeadersMiddleware(), ContextualMiddleware(), ErrorSourceMiddleware(), + ResponseLimitMiddleware(0), } } diff --git a/backend/httpclient/http_client_test.go b/backend/httpclient/http_client_test.go index 43cc9f965..a59bb0fe4 100644 --- a/backend/httpclient/http_client_test.go +++ b/backend/httpclient/http_client_test.go @@ -55,13 +55,14 @@ func TestNewClient(t *testing.T) { require.NoError(t, err) require.NotNil(t, client) - require.Len(t, usedMiddlewares, 6) + require.Len(t, usedMiddlewares, 7) require.Equal(t, TracingMiddlewareName, usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, usedMiddlewares[3].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, usedMiddlewares[4].(MiddlewareName).MiddlewareName()) require.Equal(t, ErrorSourceMiddlewareName, usedMiddlewares[5].(MiddlewareName).MiddlewareName()) + require.Equal(t, ResponseLimitMiddlewareName, usedMiddlewares[6].(MiddlewareName).MiddlewareName()) }) t.Run("New() with opts middleware should return expected http.Client", func(t *testing.T) { diff --git a/backend/httpclient/provider_test.go b/backend/httpclient/provider_test.go index 2ca879184..e1023a141 100644 --- a/backend/httpclient/provider_test.go +++ b/backend/httpclient/provider_test.go @@ -25,13 +25,14 @@ func TestProvider(t *testing.T) { client, err := ctx.provider.New() require.NoError(t, err) require.NotNil(t, client) - require.Len(t, ctx.usedMiddlewares, 6) + require.Len(t, ctx.usedMiddlewares, 7) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[3].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, ctx.usedMiddlewares[4].(MiddlewareName).MiddlewareName()) require.Equal(t, ErrorSourceMiddlewareName, ctx.usedMiddlewares[5].(MiddlewareName).MiddlewareName()) + require.Equal(t, ResponseLimitMiddlewareName, ctx.usedMiddlewares[6].(MiddlewareName).MiddlewareName()) }) t.Run("Transport should use default middlewares", func(t *testing.T) { @@ -39,13 +40,14 @@ func TestProvider(t *testing.T) { transport, err := ctx.provider.GetTransport() require.NoError(t, err) require.NotNil(t, transport) - require.Len(t, ctx.usedMiddlewares, 6) + require.Len(t, ctx.usedMiddlewares, 7) require.Equal(t, TracingMiddlewareName, ctx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, DataSourceMetricsMiddlewareName, ctx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, BasicAuthenticationMiddlewareName, ctx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) require.Equal(t, CustomHeadersMiddlewareName, ctx.usedMiddlewares[3].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, ctx.usedMiddlewares[4].(MiddlewareName).MiddlewareName()) require.Equal(t, ErrorSourceMiddlewareName, ctx.usedMiddlewares[5].(MiddlewareName).MiddlewareName()) + require.Equal(t, ResponseLimitMiddlewareName, ctx.usedMiddlewares[6].(MiddlewareName).MiddlewareName()) }) t.Run("New() with options and no middleware should return expected http client and transport", func(t *testing.T) { @@ -86,7 +88,7 @@ func TestProvider(t *testing.T) { require.Equal(t, DefaultTimeoutOptions.Timeout, client.Timeout) t.Run("Should use configured middlewares and implement MiddlewareName", func(t *testing.T) { - require.Len(t, pCtx.usedMiddlewares, 9) + require.Len(t, pCtx.usedMiddlewares, 10) require.Equal(t, "mw1", pCtx.usedMiddlewares[0].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw2", pCtx.usedMiddlewares[1].(MiddlewareName).MiddlewareName()) require.Equal(t, "mw3", pCtx.usedMiddlewares[2].(MiddlewareName).MiddlewareName()) @@ -96,6 +98,7 @@ func TestProvider(t *testing.T) { require.Equal(t, CustomHeadersMiddlewareName, pCtx.usedMiddlewares[6].(MiddlewareName).MiddlewareName()) require.Equal(t, ContextualMiddlewareName, pCtx.usedMiddlewares[7].(MiddlewareName).MiddlewareName()) require.Equal(t, ErrorSourceMiddlewareName, pCtx.usedMiddlewares[8].(MiddlewareName).MiddlewareName()) + require.Equal(t, ResponseLimitMiddlewareName, pCtx.usedMiddlewares[9].(MiddlewareName).MiddlewareName()) }) t.Run("When roundtrip should call expected middlewares", func(t *testing.T) { diff --git a/backend/httpclient/response_limit_middleware.go b/backend/httpclient/response_limit_middleware.go index d36c01551..f03f94939 100644 --- a/backend/httpclient/response_limit_middleware.go +++ b/backend/httpclient/response_limit_middleware.go @@ -2,12 +2,32 @@ package httpclient import ( "net/http" + "os" + "strconv" + + "github.com/grafana/grafana-plugin-sdk-go/backend/log" +) + +const ( + ResponseLimitEnvVar = "GF_DATAPROXY_RESPONSE_LIMIT" ) // ResponseLimitMiddlewareName is the middleware name used by ResponseLimitMiddleware. const ResponseLimitMiddlewareName = "response-limit" func ResponseLimitMiddleware(limit int64) Middleware { + if limit <= 0 { + envLimit, ok := os.LookupEnv(ResponseLimitEnvVar) + if ok && envLimit != "" { + limitInt, err := strconv.ParseInt(envLimit, 10, 64) + if err == nil && limitInt > 0 { + limit = limitInt + } + + log.DefaultLogger.Error("failed to parse GF_DATAPROXY_RESPONSE_LIMIT", "error", err) + } + } + return NamedMiddlewareFunc(ResponseLimitMiddlewareName, func(_ Options, next http.RoundTripper) http.RoundTripper { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { res, err := next.RoundTrip(req) diff --git a/backend/httpclient/response_limit_middleware_test.go b/backend/httpclient/response_limit_middleware_test.go index ec5b3b810..2531cf690 100644 --- a/backend/httpclient/response_limit_middleware_test.go +++ b/backend/httpclient/response_limit_middleware_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "strings" "testing" @@ -18,17 +19,27 @@ func TestResponseLimitMiddleware(t *testing.T) { expectedBodyLength int expectedBody string err error + envLimit string }{ - {limit: 1, expectedBodyLength: 1, expectedBody: "d", err: errors.New("error: http: response body too large, response limit is set to: 1")}, - {limit: 1000000, expectedBodyLength: 5, expectedBody: "dummy", err: nil}, - {limit: 0, expectedBodyLength: 5, expectedBody: "dummy", err: nil}, + // Test that the limit is set from arguments + {limit: 1, expectedBodyLength: 1, expectedBody: "d", err: errors.New("error: http: response body too large, response limit is set to: 1"), envLimit: ""}, + {limit: 1000000, expectedBodyLength: 5, expectedBody: "dummy", err: nil, envLimit: ""}, + {limit: 0, expectedBodyLength: 5, expectedBody: "dummy", err: nil, envLimit: ""}, + // Test that the limit is set from the environment variable + {limit: 0, expectedBodyLength: 1, expectedBody: "d", err: errors.New("error: http: response body too large, response limit is set to: 1"), envLimit: "1"}, + {limit: 0, expectedBodyLength: 5, expectedBody: "dummy", err: nil, envLimit: "1000000"}, + {limit: 0, expectedBodyLength: 5, expectedBody: "dummy", err: nil, envLimit: "-1"}, + {limit: 0, expectedBodyLength: 5, expectedBody: "dummy", err: nil, envLimit: "0"}, } for _, tc := range tcs { - t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d", tc.limit), func(t *testing.T) { + t.Run(fmt.Sprintf("Test ResponseLimitMiddleware with limit: %d and envLimit: %s", tc.limit, tc.envLimit), func(t *testing.T) { finalRoundTripper := RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return &http.Response{StatusCode: http.StatusOK, Request: req, Body: io.NopCloser(strings.NewReader("dummy"))}, nil }) + os.Setenv(ResponseLimitEnvVar, tc.envLimit) + defer os.Unsetenv(ResponseLimitEnvVar) + mw := ResponseLimitMiddleware(tc.limit) rt := mw.CreateMiddleware(Options{}, finalRoundTripper) require.NotNil(t, rt)