From 717adc51adee8b2a41719554b88ddb87cffc0b3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20Echterh=C3=B6lter?= Date: Sat, 13 Sep 2025 13:59:10 +0200 Subject: [PATCH 1/4] feat(middleware): implement authentication and context storage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Adds middleware for storing authorization headers and tokens in request context • Introduces local middleware for testing with JWT and tenant information --- middleware/auth.go | 19 ++++++++ middleware/authzMiddlewares.go | 22 ++++++++++ middleware/logger.go | 17 ++++++++ middleware/middleware.go | 17 ++++++++ middleware/otel.go | 17 ++++++++ middleware/requestid.go | 48 +++++++++++++++++++++ middleware/requestid_test.go | 44 +++++++++++++++++++ middleware/sentry.go | 31 +++++++++++++ middleware/spiffe.go | 22 ++++++++++ middleware/test_support/local_middleware.go | 28 ++++++++++++ middleware/token.go | 29 +++++++++++++ 11 files changed, 294 insertions(+) create mode 100644 middleware/auth.go create mode 100644 middleware/authzMiddlewares.go create mode 100644 middleware/logger.go create mode 100644 middleware/middleware.go create mode 100644 middleware/otel.go create mode 100644 middleware/requestid.go create mode 100644 middleware/requestid_test.go create mode 100644 middleware/sentry.go create mode 100644 middleware/spiffe.go create mode 100644 middleware/test_support/local_middleware.go create mode 100644 middleware/token.go diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..9b50870 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "net/http" + + "github.com/go-http-utils/headers" + "github.com/platform-mesh/golang-commons/context" +) + +// StoreAuthHeader stores the Authorization header within the request context +func StoreAuthHeader() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + auth := request.Header.Get(headers.Authorization) + ctx := context.AddAuthHeaderToContext(request.Context(), auth) + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} diff --git a/middleware/authzMiddlewares.go b/middleware/authzMiddlewares.go new file mode 100644 index 0000000..81e96f2 --- /dev/null +++ b/middleware/authzMiddlewares.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + + "github.com/platform-mesh/golang-commons/policy_services" +) + +// Deprecated: CreateMiddlewares use CreateAuthMiddleware instead. +func CreateMiddlewares(retriever policy_services.TenantRetriever) []func(http.Handler) http.Handler { + return CreateAuthMiddleware(retriever) +} + +func CreateAuthMiddleware(retriever policy_services.TenantRetriever) []func(http.Handler) http.Handler { + mws := make([]func(http.Handler) http.Handler, 0, 5) + + mws = append(mws, StoreWebToken()) + mws = append(mws, StoreAuthHeader()) + mws = append(mws, StoreSpiffeHeader()) + + return mws +} diff --git a/middleware/logger.go b/middleware/logger.go new file mode 100644 index 0000000..4e40a0b --- /dev/null +++ b/middleware/logger.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "net/http" + + "github.com/platform-mesh/golang-commons/logger" +) + +// StoreLoggerMiddleware is a middleware that stores a given Logger in the request context +func StoreLoggerMiddleware(log *logger.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := logger.SetLoggerInContext(r.Context(), log) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..1ddc24d --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "net/http" + + "github.com/platform-mesh/golang-commons/logger" +) + +func CreateMiddleware(log *logger.Logger) []func(http.Handler) http.Handler { + return []func(http.Handler) http.Handler{ + SetOtelTracingContext(), + SentryRecoverer, + StoreLoggerMiddleware(log), + SetRequestId(), + SetRequestIdInLogger(), + } +} diff --git a/middleware/otel.go b/middleware/otel.go new file mode 100644 index 0000000..f134652 --- /dev/null +++ b/middleware/otel.go @@ -0,0 +1,17 @@ +package middleware + +import ( + "net/http" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" +) + +func SetOtelTracingContext() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := otel.GetTextMapPropagator().Extract(request.Context(), propagation.HeaderCarrier(request.Header)) + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} diff --git a/middleware/requestid.go b/middleware/requestid.go new file mode 100644 index 0000000..59694dd --- /dev/null +++ b/middleware/requestid.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/google/uuid" + "github.com/platform-mesh/golang-commons/context/keys" + "github.com/platform-mesh/golang-commons/logger" +) + +const requestIdHeader = "X-Request-Id" + +func SetRequestId() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := request.Context() + var requestId string + if ids, ok := request.Header[requestIdHeader]; ok && len(ids) == 1 { + requestId = ids[0] + } else { + // Generate a new request id, header was not received. + requestId = uuid.New().String() + } + ctx = context.WithValue(ctx, keys.RequestIdCtxKey, requestId) + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} + +func SetRequestIdInLogger() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := request.Context() + log := logger.LoadLoggerFromContext(ctx) + log = logger.NewRequestLoggerFromZerolog(ctx, log.Logger) + ctx = logger.SetLoggerInContext(ctx, log) + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} + +func GetRequestId(ctx context.Context) string { + if val, ok := ctx.Value(keys.RequestIdCtxKey).(string); ok { + return val + } + return "" +} diff --git a/middleware/requestid_test.go b/middleware/requestid_test.go new file mode 100644 index 0000000..f22577c --- /dev/null +++ b/middleware/requestid_test.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSetRequestIdWithIncomingHeader(t *testing.T) { + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val := GetRequestId(r.Context()) + assert.Equal(t, "123", val) + }) + + // create the handler to test, using our custom "next" handler + handlerToTest := SetRequestId()(nextHandler) + + // create a mock request to use + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Add("X-Request-Id", "123") + + // call the handler using a mock response recorder (we'll not use that anyway) + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} + +func TestSetRequestIdWitoutIncomingHeader(t *testing.T) { + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val := GetRequestId(r.Context()) + assert.Len(t, val, 36) + }) + + // create the handler to test, using our custom "next" handler + handlerToTest := SetRequestId()(nextHandler) + + // create a mock request to use + req := httptest.NewRequest("GET", "http://testing", nil) + + // call the handler using a mock response recorder (we'll not use that anyway) + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} diff --git a/middleware/sentry.go b/middleware/sentry.go new file mode 100644 index 0000000..ad1bb75 --- /dev/null +++ b/middleware/sentry.go @@ -0,0 +1,31 @@ +package middleware + +import ( + "net/http" + "runtime/debug" + "time" + + "github.com/getsentry/sentry-go" + "github.com/platform-mesh/golang-commons/logger" +) + +// Recoverer implements a middleware that recover from panics, sends them to Sentry +// log the panic together with a stack trace and sends HTTP status 500 +func SentryRecoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil && err != http.ErrAbortHandler { + log := logger.LoadLoggerFromContext(r.Context()) + log.Error().Interface("panic", err).Interface("stack", debug.Stack()).Msg("recovered http panic") + sentry.CurrentHub().Recover(err) + sentry.Flush(time.Second * 5) + + w.WriteHeader(http.StatusInternalServerError) + } + }() + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} diff --git a/middleware/spiffe.go b/middleware/spiffe.go new file mode 100644 index 0000000..48e3a19 --- /dev/null +++ b/middleware/spiffe.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + + "github.com/platform-mesh/golang-commons/context" + "github.com/platform-mesh/golang-commons/jwt" +) + +func StoreSpiffeHeader() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := request.Context() + uriVal := jwt.GetSpiffeUrlValue(request.Header) + + if uriVal != nil { + ctx = context.AddSpiffeToContext(ctx, *uriVal) + } + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} diff --git a/middleware/test_support/local_middleware.go b/middleware/test_support/local_middleware.go new file mode 100644 index 0000000..7a212c4 --- /dev/null +++ b/middleware/test_support/local_middleware.go @@ -0,0 +1,28 @@ +package local_middleware + +import ( + "net/http" + + "github.com/go-jose/go-jose/v4" + "github.com/golang-jwt/jwt/v5" + "github.com/platform-mesh/golang-commons/context" +) + +func LocalMiddleware(tenantId string, userId string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := request.Context() + + claims := &jwt.RegisteredClaims{Issuer: "localhost:8080", Subject: userId, Audience: jwt.ClaimStrings{"testing"}} + token, err := jwt.NewWithClaims(jwt.SigningMethodNone, claims).SignedString(jwt.UnsafeAllowNoneSignatureType) + if err != nil { + panic(err) // This shouldn't happen, and if it does, only locally + } + + ctx = context.AddWebTokenToContext(ctx, token, []jose.SignatureAlgorithm{"none"}) + ctx = context.AddTenantToContext(ctx, tenantId) + + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} diff --git a/middleware/token.go b/middleware/token.go new file mode 100644 index 0000000..adcd49d --- /dev/null +++ b/middleware/token.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/go-http-utils/headers" + "github.com/go-jose/go-jose/v4" + "github.com/platform-mesh/golang-commons/context" +) + +const tokenAuthPrefix = "BEARER" + +var SignatureAlgorithms = []jose.SignatureAlgorithm{jose.RS256} + +// StoreWebToken retrieves the actual JWT Token within the Authorization header, and it stores it in the context as a struct +func StoreWebToken() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(responseWriter http.ResponseWriter, request *http.Request) { + ctx := request.Context() + auth := strings.Split(request.Header.Get(headers.Authorization), " ") + if len(auth) > 1 && strings.ToUpper(auth[0]) == tokenAuthPrefix { + ctx = context.AddWebTokenToContext(ctx, auth[1], SignatureAlgorithms) + } + + next.ServeHTTP(responseWriter, request.WithContext(ctx)) + }) + } +} From 096421f9a61f732b1823aef9b1e98fb4fa393784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20Echterh=C3=B6lter?= Date: Sat, 13 Sep 2025 14:09:56 +0200 Subject: [PATCH 2/4] test: increased unit test coverage --- middleware/auth_test.go | 101 ++++++++++++++++++ middleware/authzMiddlewares_test.go | 104 ++++++++++++++++++ middleware/logger_test.go | 56 ++++++++++ middleware/middleware_test.go | 41 +++++++ middleware/otel_test.go | 96 +++++++++++++++++ middleware/requestid_test.go | 49 +++++++++ middleware/sentry_test.go | 141 ++++++++++++++++++++++++ middleware/spiffe_test.go | 160 ++++++++++++++++++++++++++++ middleware/token_test.go | 154 ++++++++++++++++++++++++++ 9 files changed, 902 insertions(+) create mode 100644 middleware/auth_test.go create mode 100644 middleware/authzMiddlewares_test.go create mode 100644 middleware/logger_test.go create mode 100644 middleware/middleware_test.go create mode 100644 middleware/otel_test.go create mode 100644 middleware/sentry_test.go create mode 100644 middleware/spiffe_test.go create mode 100644 middleware/token_test.go diff --git a/middleware/auth_test.go b/middleware/auth_test.go new file mode 100644 index 0000000..b07902d --- /dev/null +++ b/middleware/auth_test.go @@ -0,0 +1,101 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-http-utils/headers" + "github.com/platform-mesh/golang-commons/context" + "github.com/stretchr/testify/assert" +) + +func TestStoreAuthHeader_WithAuthHeader(t *testing.T) { + expectedAuth := "Bearer token123" + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify auth header is stored in context + authFromContext, err := context.GetAuthHeaderFromContext(r.Context()) + assert.NoError(t, err) + assert.Equal(t, expectedAuth, authFromContext) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreAuthHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, expectedAuth) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreAuthHeader_WithoutAuthHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify empty auth header returns error when not set + _, err := context.GetAuthHeaderFromContext(r.Context()) + assert.Error(t, err) // Should return error when no auth header is set + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreAuthHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // No authorization header set + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreAuthHeader_WithEmptyAuthHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify empty auth header returns error when empty + _, err := context.GetAuthHeaderFromContext(r.Context()) + assert.Error(t, err) // Should return error when auth header is empty + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreAuthHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, "") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreAuthHeader_MultipleAuthHeaders(t *testing.T) { + // Test behavior when multiple authorization headers are present + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Should get the first/combined value + authFromContext, err := context.GetAuthHeaderFromContext(r.Context()) + assert.NoError(t, err) + assert.NotEmpty(t, authFromContext) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreAuthHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Add(headers.Authorization, "Bearer token1") + req.Header.Add(headers.Authorization, "Bearer token2") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} diff --git a/middleware/authzMiddlewares_test.go b/middleware/authzMiddlewares_test.go new file mode 100644 index 0000000..c2e6767 --- /dev/null +++ b/middleware/authzMiddlewares_test.go @@ -0,0 +1,104 @@ +package middleware + +import ( + "context" + "testing" + + "github.com/platform-mesh/golang-commons/policy_services" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockTenantRetriever is a mock implementation of TenantRetriever +type MockTenantRetriever struct { + mock.Mock +} + +func (m *MockTenantRetriever) RetrieveTenant(ctx context.Context) (string, error) { + args := m.Called(ctx) + return args.String(0), args.Error(1) +} + +func TestCreateMiddlewares_Deprecated(t *testing.T) { + mockRetriever := &MockTenantRetriever{} + + middlewares := CreateMiddlewares(mockRetriever) + + // Should return 3 middlewares (same as CreateAuthMiddleware) + assert.Len(t, middlewares, 3) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } +} + +func TestCreateAuthMiddleware(t *testing.T) { + mockRetriever := &MockTenantRetriever{} + + middlewares := CreateAuthMiddleware(mockRetriever) + + // Should return 3 middlewares: StoreWebToken, StoreAuthHeader, StoreSpiffeHeader + assert.Len(t, middlewares, 3) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } +} + +func TestCreateAuthMiddleware_WithNilRetriever(t *testing.T) { + middlewares := CreateAuthMiddleware(nil) + + // Should still return 3 middlewares even with nil retriever + assert.Len(t, middlewares, 3) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } +} + +func TestCreateAuthMiddleware_ReturnsCorrectMiddlewares(t *testing.T) { + mockRetriever := &MockTenantRetriever{} + + middlewares := CreateAuthMiddleware(mockRetriever) + + // Verify we get exactly 3 middlewares + assert.Len(t, middlewares, 3) + + // We can't easily test the exact middleware functions returned without more complex setup, + // but we can verify they're all valid middleware functions by checking their signatures + for _, mw := range middlewares { + assert.NotNil(t, mw) + // Each middleware should be a function that takes an http.Handler and returns an http.Handler + // This is implicitly tested by the fact that the function compiles and returns without error + } +} + +func TestCreateMiddlewares_Equivalence(t *testing.T) { + mockRetriever := &MockTenantRetriever{} + + deprecatedMiddlewares := CreateMiddlewares(mockRetriever) + newMiddlewares := CreateAuthMiddleware(mockRetriever) + + // Both functions should return the same number of middlewares + assert.Equal(t, len(deprecatedMiddlewares), len(newMiddlewares)) + + // Both should return 3 middlewares + assert.Len(t, deprecatedMiddlewares, 3) + assert.Len(t, newMiddlewares, 3) +} + +// Test that implements policy_services.TenantRetriever interface +func TestTenantRetrieverInterface(t *testing.T) { + mockRetriever := &MockTenantRetriever{} + + // Verify our mock implements the interface + var retriever policy_services.TenantRetriever = mockRetriever + assert.NotNil(t, retriever) + + // Test that we can use it with the middleware functions + middlewares := CreateAuthMiddleware(retriever) + assert.Len(t, middlewares, 3) +} diff --git a/middleware/logger_test.go b/middleware/logger_test.go new file mode 100644 index 0000000..c871b80 --- /dev/null +++ b/middleware/logger_test.go @@ -0,0 +1,56 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/platform-mesh/golang-commons/logger" + "github.com/platform-mesh/golang-commons/logger/testlogger" + "github.com/stretchr/testify/assert" +) + +func TestStoreLoggerMiddleware(t *testing.T) { + testLog := testlogger.New() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify logger is stored in context + logFromContext := logger.LoadLoggerFromContext(r.Context()) + assert.NotNil(t, logFromContext) + + // The logger should be the same instance we passed + assert.Equal(t, testLog.Logger, logFromContext) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreLoggerMiddleware(testLog.Logger) + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreLoggerMiddleware_NilLogger(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Even with nil logger, the middleware should not panic + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreLoggerMiddleware(nil) + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + // Should not panic + assert.NotPanics(t, func() { + handlerToTest.ServeHTTP(recorder, req) + }) + + assert.Equal(t, http.StatusOK, recorder.Code) +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..53d3855 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,41 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/platform-mesh/golang-commons/logger/testlogger" + "github.com/stretchr/testify/assert" +) + +func TestCreateMiddleware(t *testing.T) { + log := testlogger.New() + middlewares := CreateMiddleware(log.Logger) + + // Should return 5 middlewares + assert.Len(t, middlewares, 5) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } + + // Test that middlewares can be chained + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Apply all middlewares + var finalHandler http.Handler = handler + for i := len(middlewares) - 1; i >= 0; i-- { + finalHandler = middlewares[i](finalHandler) + } + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + finalHandler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} diff --git a/middleware/otel_test.go b/middleware/otel_test.go new file mode 100644 index 0000000..bc95079 --- /dev/null +++ b/middleware/otel_test.go @@ -0,0 +1,96 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" +) + +func TestSetOtelTracingContext(t *testing.T) { + // Set up a test propagator + propagator := propagation.TraceContext{} + otel.SetTextMapPropagator(propagator) + + // Create a span context to inject + tracer := trace.NewNoopTracerProvider().Tracer("test") + ctx, span := tracer.Start(context.Background(), "test-span") + span.End() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The context should have been extracted and set + assert.NotNil(t, r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := SetOtelTracingContext() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + + // Inject trace context into headers + propagator.Inject(ctx, propagation.HeaderCarrier(req.Header)) + + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestSetOtelTracingContext_NoTraceHeaders(t *testing.T) { + // Set up a test propagator + propagator := propagation.TraceContext{} + otel.SetTextMapPropagator(propagator) + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Even without trace headers, context should be set + assert.NotNil(t, r.Context()) + w.WriteHeader(http.StatusOK) + }) + + middleware := SetOtelTracingContext() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestSetOtelTracingContext_Integration(t *testing.T) { + // Test that the middleware properly integrates with the OpenTelemetry propagation system + propagator := propagation.TraceContext{} + otel.SetTextMapPropagator(propagator) + + var extractedContext context.Context + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + extractedContext = r.Context() + w.WriteHeader(http.StatusOK) + }) + + middleware := SetOtelTracingContext() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + + // Add a fake trace header to test extraction + req.Header.Set("traceparent", "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") + + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.NotNil(t, extractedContext) + + // Verify that the context is different from the original request context + assert.NotEqual(t, req.Context(), extractedContext) +} diff --git a/middleware/requestid_test.go b/middleware/requestid_test.go index f22577c..10ce175 100644 --- a/middleware/requestid_test.go +++ b/middleware/requestid_test.go @@ -1,10 +1,12 @@ package middleware import ( + "context" "net/http" "net/http/httptest" "testing" + "github.com/platform-mesh/golang-commons/context/keys" "github.com/stretchr/testify/assert" ) @@ -42,3 +44,50 @@ func TestSetRequestIdWitoutIncomingHeader(t *testing.T) { // call the handler using a mock response recorder (we'll not use that anyway) handlerToTest.ServeHTTP(httptest.NewRecorder(), req) } + +func TestSetRequestIdInLogger(t *testing.T) { + // This test verifies that SetRequestIdInLogger creates a request-aware logger + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The logger in context should be updated with request information + w.WriteHeader(http.StatusOK) + }) + + // create the handler to test + handlerToTest := SetRequestIdInLogger()(nextHandler) + + // create a mock request to use + req := httptest.NewRequest("GET", "http://testing", nil) + + // call the handler using a mock response recorder + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} + +func TestGetRequestId_WithValidContext(t *testing.T) { + requestId := "test-request-id-123" + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + retrievedId := GetRequestId(r.Context()) + assert.Equal(t, requestId, retrievedId) + }) + + handlerToTest := SetRequestId()(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Add("X-Request-Id", requestId) + + handlerToTest.ServeHTTP(httptest.NewRecorder(), req) +} + +func TestGetRequestId_WithEmptyContext(t *testing.T) { + // Test GetRequestId with a context that doesn't have a request ID + emptyCtx := context.Background() + requestId := GetRequestId(emptyCtx) + assert.Empty(t, requestId) +} + +func TestGetRequestId_WithInvalidContextValue(t *testing.T) { + // Test GetRequestId with a context that has an invalid request ID value + ctx := context.WithValue(context.Background(), keys.RequestIdCtxKey, 123) // not a string + requestId := GetRequestId(ctx) + assert.Empty(t, requestId) +} diff --git a/middleware/sentry_test.go b/middleware/sentry_test.go new file mode 100644 index 0000000..8e23036 --- /dev/null +++ b/middleware/sentry_test.go @@ -0,0 +1,141 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/platform-mesh/golang-commons/logger" + "github.com/platform-mesh/golang-commons/logger/testlogger" + "github.com/stretchr/testify/assert" +) + +func TestSentryRecoverer_NoPanic(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("success")) + }) + + handlerToTest := SentryRecoverer(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) + assert.Equal(t, "success", recorder.Body.String()) +} + +func TestSentryRecoverer_WithPanic(t *testing.T) { + log := testlogger.New().HideLogOutput() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + handlerToTest := SentryRecoverer(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // Add logger to context so the middleware can log the panic + ctx := req.Context() + ctx = logger.SetLoggerInContext(ctx, log.Logger) + req = req.WithContext(ctx) + + recorder := httptest.NewRecorder() + + // Should not panic, should recover + assert.NotPanics(t, func() { + handlerToTest.ServeHTTP(recorder, req) + }) + + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + + // Verify that the panic was logged + messages, err := log.GetLogMessages() + assert.NoError(t, err) + assert.NotEmpty(t, messages) + + // Find the panic log message + found := false + for _, msg := range messages { + if msg.Message == "recovered http panic" { + found = true + break + } + } + assert.True(t, found, "Expected to find panic log message") +} + +func TestSentryRecoverer_WithHttpErrAbortHandler(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // http.ErrAbortHandler should not be recovered + panic(http.ErrAbortHandler) + }) + + handlerToTest := SentryRecoverer(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + // The middleware should not recover from http.ErrAbortHandler + // Since the condition is `err != http.ErrAbortHandler`, it should let this panic through + // However, since the defer recover catches it but doesn't handle it, it won't re-panic + // Let's test that it doesn't crash the middleware + assert.NotPanics(t, func() { + handlerToTest.ServeHTTP(recorder, req) + }) +} + +func TestSentryRecoverer_WithStringPanic(t *testing.T) { + log := testlogger.New().HideLogOutput() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("string panic message") + }) + + handlerToTest := SentryRecoverer(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // Add logger to context + ctx := req.Context() + ctx = logger.SetLoggerInContext(ctx, log.Logger) + req = req.WithContext(ctx) + + recorder := httptest.NewRecorder() + + assert.NotPanics(t, func() { + handlerToTest.ServeHTTP(recorder, req) + }) + + assert.Equal(t, http.StatusInternalServerError, recorder.Code) + + // Verify that the panic was logged + messages, err := log.GetLogMessages() + assert.NoError(t, err) + assert.NotEmpty(t, messages) +} + +func TestSentryRecoverer_WithErrorPanic(t *testing.T) { + log := testlogger.New().HideLogOutput() + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic(assert.AnError) + }) + + handlerToTest := SentryRecoverer(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // Add logger to context + ctx := req.Context() + ctx = logger.SetLoggerInContext(ctx, log.Logger) + req = req.WithContext(ctx) + + recorder := httptest.NewRecorder() + + assert.NotPanics(t, func() { + handlerToTest.ServeHTTP(recorder, req) + }) + + assert.Equal(t, http.StatusInternalServerError, recorder.Code) +} diff --git a/middleware/spiffe_test.go b/middleware/spiffe_test.go new file mode 100644 index 0000000..84544ff --- /dev/null +++ b/middleware/spiffe_test.go @@ -0,0 +1,160 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/platform-mesh/golang-commons/context" + "github.com/platform-mesh/golang-commons/jwt" + "github.com/stretchr/testify/assert" +) + +func TestStoreSpiffeHeader_WithValidSpiffeHeader(t *testing.T) { + expectedSpiffeID := "spiffe://example.org/workload" + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify SPIFFE ID is stored in context + spiffeFromContext, err := context.GetSpiffeFromContext(r.Context()) + assert.NoError(t, err) + assert.Equal(t, expectedSpiffeID, spiffeFromContext) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // Set the SPIFFE header that jwt.GetSpiffeUrlValue expects + req.Header.Set("X-Forwarded-Client-Cert", "Subject=\"CN=test\";URI="+expectedSpiffeID) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreSpiffeHeader_WithoutSpiffeHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have SPIFFE ID when no header is present + _, err := context.GetSpiffeFromContext(r.Context()) + assert.Error(t, err) // Should return an error when no SPIFFE ID is present + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // No SPIFFE header set + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreSpiffeHeader_WithEmptySpiffeHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have SPIFFE ID when header is empty + _, err := context.GetSpiffeFromContext(r.Context()) + assert.Error(t, err) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("X-Forwarded-Client-Cert", "") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreSpiffeHeader_WithInvalidSpiffeHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have SPIFFE ID when header is invalid + _, err := context.GetSpiffeFromContext(r.Context()) + assert.Error(t, err) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("X-Forwarded-Client-Cert", "InvalidHeaderValue") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreSpiffeHeader_WithMultipleSpiffeHeaders(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // With multiple headers, the behavior depends on jwt.GetSpiffeUrlValue implementation + // It should either return the first valid one or handle concatenated headers + spiffeFromContext, err := context.GetSpiffeFromContext(r.Context()) + if err == nil { + // If we get a value, it should be a valid SPIFFE ID + assert.Contains(t, spiffeFromContext, "spiffe://") + } + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Add("X-Forwarded-Client-Cert", "Subject=\"CN=test1\";URI=spiffe://example.org/workload1") + req.Header.Add("X-Forwarded-Client-Cert", "Subject=\"CN=test2\";URI=spiffe://example.org/workload2") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreSpiffeHeader_Integration(t *testing.T) { + // Test the integration with jwt.GetSpiffeUrlValue function + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify that the middleware properly uses jwt.GetSpiffeUrlValue + spiffeValue := jwt.GetSpiffeUrlValue(r.Header) + + if spiffeValue != nil { + spiffeFromContext, err := context.GetSpiffeFromContext(r.Context()) + assert.NoError(t, err) + assert.Equal(t, *spiffeValue, spiffeFromContext) + } else { + _, err := context.GetSpiffeFromContext(r.Context()) + assert.Error(t, err) + } + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreSpiffeHeader() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // Test without header first + recorder := httptest.NewRecorder() + handlerToTest.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) + + // Test with valid header + req = httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("X-Forwarded-Client-Cert", "Subject=\"CN=test\";URI=spiffe://example.org/test") + recorder = httptest.NewRecorder() + handlerToTest.ServeHTTP(recorder, req) + assert.Equal(t, http.StatusOK, recorder.Code) +} diff --git a/middleware/token_test.go b/middleware/token_test.go new file mode 100644 index 0000000..de3f907 --- /dev/null +++ b/middleware/token_test.go @@ -0,0 +1,154 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-http-utils/headers" + "github.com/platform-mesh/golang-commons/context" + "github.com/stretchr/testify/assert" +) + +func TestStoreWebToken_WithValidBearerToken(t *testing.T) { + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + authHeader := "Bearer " + token + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Token parsing may fail due to signature validation, which is expected in tests + // The middleware should handle this gracefully + _, err := context.GetWebTokenFromContext(r.Context()) + // In a real scenario with proper JWT validation, this might fail + // For test purposes, we just verify the middleware doesn't crash + if err != nil { + // This is expected behavior when token validation fails + assert.Error(t, err) + } + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, authHeader) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreWebToken_WithoutAuthHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have a token + _, err := context.GetWebTokenFromContext(r.Context()) + assert.Error(t, err) // Should return an error when no token is present + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + // No authorization header set + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreWebToken_WithNonBearerToken(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have a token + _, err := context.GetWebTokenFromContext(r.Context()) + assert.Error(t, err) // Should return an error when no valid token is present + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, "Basic dXNlcjpwYXNz") // Basic auth, not Bearer + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreWebToken_WithEmptyBearerToken(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have a token due to empty token + _, err := context.GetWebTokenFromContext(r.Context()) + assert.Error(t, err) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, "Bearer ") + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreWebToken_WithBearerTokenLowercase(t *testing.T) { + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c" + authHeader := "bearer " + token // lowercase + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Token parsing may fail due to signature validation, which is expected in tests + _, err := context.GetWebTokenFromContext(r.Context()) + // The middleware should process lowercase bearer tokens + // but token validation may still fail due to signature issues + if err != nil { + // This is expected behavior when token validation fails + assert.Error(t, err) + } + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, authHeader) + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} + +func TestStoreWebToken_WithMalformedAuthHeader(t *testing.T) { + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Context should not have a token + _, err := context.GetWebTokenFromContext(r.Context()) + assert.Error(t, err) + + w.WriteHeader(http.StatusOK) + }) + + middleware := StoreWebToken() + handlerToTest := middleware(nextHandler) + + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set(headers.Authorization, "Bearer") // Missing space and token + recorder := httptest.NewRecorder() + + handlerToTest.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +} From 1237e2dc59610df7ca4955709dbf644e221d5599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20Echterh=C3=B6lter?= Date: Sat, 13 Sep 2025 14:15:25 +0200 Subject: [PATCH 3/4] chore: address lint warnings --- middleware/otel_test.go | 4 ++-- middleware/sentry_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/middleware/otel_test.go b/middleware/otel_test.go index bc95079..513042d 100644 --- a/middleware/otel_test.go +++ b/middleware/otel_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" - "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" ) func TestSetOtelTracingContext(t *testing.T) { @@ -18,7 +18,7 @@ func TestSetOtelTracingContext(t *testing.T) { otel.SetTextMapPropagator(propagator) // Create a span context to inject - tracer := trace.NewNoopTracerProvider().Tracer("test") + tracer := noop.NewTracerProvider().Tracer("test") ctx, span := tracer.Start(context.Background(), "test-span") span.End() diff --git a/middleware/sentry_test.go b/middleware/sentry_test.go index 8e23036..f9705b7 100644 --- a/middleware/sentry_test.go +++ b/middleware/sentry_test.go @@ -13,7 +13,7 @@ import ( func TestSentryRecoverer_NoPanic(t *testing.T) { nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("success")) + _, _ = w.Write([]byte("success")) }) handlerToTest := SentryRecoverer(nextHandler) From a54717a102b4001a5d7703269e40b86c47f1a2ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20Echterh=C3=B6lter?= Date: Sat, 13 Sep 2025 14:16:45 +0200 Subject: [PATCH 4/4] chore(middleware): exclude test support from coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit • Updates coverage configuration to exclude middleware test support files --- .testcoverage.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.testcoverage.yml b/.testcoverage.yml index 372613f..3dc1910 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -4,4 +4,5 @@ exclude: - mocks # exclude generated mock files - ^test/ - ^logger/testlogger + - middleware/test_support