diff --git a/middleware/authzMiddlewares.go b/middleware/authzMiddlewares.go index 84a202e..37be850 100644 --- a/middleware/authzMiddlewares.go +++ b/middleware/authzMiddlewares.go @@ -4,13 +4,10 @@ import ( "net/http" ) -// Middleware defines a function that wraps an http.Handler. -type Middleware func(http.Handler) http.Handler - -// CreateAuthMiddleware returns a slice of Middleware functions for authentication and authorization. +// CreateAuthMiddleware returns a slice of middleware functions for authentication and authorization. // The returned middlewares are: StoreWebToken, StoreAuthHeader, and StoreSpiffeHeader. -func CreateAuthMiddleware() []Middleware { - return []Middleware{ +func CreateAuthMiddleware() []func(http.Handler) http.Handler { + return []func(http.Handler) http.Handler{ StoreWebToken(), StoreAuthHeader(), StoreSpiffeHeader(), diff --git a/middleware/middleware.go b/middleware/middleware.go index 0e49a42..6f57e66 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -6,13 +6,20 @@ import ( "github.com/platform-mesh/golang-commons/logger" ) -// attaches a request-scoped logger (using the provided logger), assigns a request ID, and propagates that ID into the logger. -func CreateMiddleware(log *logger.Logger) []func(http.Handler) http.Handler { - return []func(http.Handler) http.Handler{ +// CreateMiddleware creates a middleware chain with logging, tracing, and optional authentication. +// It attaches a request-scoped logger (using the provided logger), assigns a request ID, and propagates that ID into the logger. +// When auth is true, authentication middlewares (StoreWebToken, StoreAuthHeader, StoreSpiffeHeader) are included. +func CreateMiddleware(log *logger.Logger, auth bool) []func(http.Handler) http.Handler { + mws := []func(http.Handler) http.Handler{ SetOtelTracingContext(), SentryRecoverer, StoreLoggerMiddleware(log), SetRequestId(), SetRequestIdInLogger(), } + + if auth { + mws = append(mws, CreateAuthMiddleware()...) + } + return mws } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 53d3855..2c002ff 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -9,11 +9,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCreateMiddleware(t *testing.T) { +func TestCreateMiddleware_WithoutAuth(t *testing.T) { log := testlogger.New() - middlewares := CreateMiddleware(log.Logger) + middlewares := CreateMiddleware(log.Logger, false) - // Should return 5 middlewares + // Should return 5 middlewares when auth is false assert.Len(t, middlewares, 5) // Each middleware should be a valid function @@ -39,3 +39,34 @@ func TestCreateMiddleware(t *testing.T) { assert.Equal(t, http.StatusOK, recorder.Code) } + +func TestCreateMiddleware_WithAuth(t *testing.T) { + log := testlogger.New() + middlewares := CreateMiddleware(log.Logger, true) + + // Should return 8 middlewares when auth is true (5 base + 3 auth) + assert.Len(t, middlewares, 8) + + // Each middleware should be a valid function + for _, mw := range middlewares { + assert.NotNil(t, mw) + } + + // Test that middlewares can be chained + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Apply all middlewares + var finalHandler http.Handler = handler + for i := len(middlewares) - 1; i >= 0; i-- { + finalHandler = middlewares[i](finalHandler) + } + + req := httptest.NewRequest("GET", "http://testing", nil) + recorder := httptest.NewRecorder() + + finalHandler.ServeHTTP(recorder, req) + + assert.Equal(t, http.StatusOK, recorder.Code) +}